Skip to content

Commit 232a6c5

Browse files
authored
Merge pull request #27 from flashcatcloud/fix/mcp-session-header-invalidation
fix(mcp): invalidate cached sessions when DynamicHeaders change
2 parents c8022bd + 45207d4 commit 232a6c5

File tree

7 files changed

+80
-30
lines changed

7 files changed

+80
-30
lines changed

.golangci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ linters:
4747
# Exclude G301 (directory permissions) - workspace needs readable directories
4848
# Exclude G304 (file inclusion) - paths are validated via safePath()
4949
# Exclude G306 (file permissions) - workspace files need to be readable
50+
# Exclude G706 (log injection) - we use slog structured logging which is inherently safe
5051
excludes:
5152
- G301
5253
- G304
5354
- G306
55+
- G706
5456

5557
formatters:
5658
enable:

mcp/client.go

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"fmt"
77
"log/slog"
8+
"sort"
89
"strings"
910
"sync"
1011
"time"
@@ -31,54 +32,69 @@ type ClientManager struct {
3132
mu sync.Mutex
3233
clients map[string]*sdk_mcp.Client
3334
sessions map[string]*sdk_mcp.ClientSession
35+
// headerHashes tracks the hash of headers used when creating each session,
36+
// so we can detect when DynamicHeaders change and invalidate stale sessions.
37+
headerHashes map[string]string
3438
}
3539

3640
// NewClientManager creates a new ClientManager.
3741
func NewClientManager() *ClientManager {
3842
return &ClientManager{
39-
clients: make(map[string]*sdk_mcp.Client),
40-
sessions: make(map[string]*sdk_mcp.ClientSession),
43+
clients: make(map[string]*sdk_mcp.Client),
44+
sessions: make(map[string]*sdk_mcp.ClientSession),
45+
headerHashes: make(map[string]string),
4146
}
4247
}
4348

4449
// GetSession returns or creates an MCP session for the given server.
50+
// Sessions are cached by server name. If DynamicHeaders change (e.g. a different
51+
// user or rotated credentials), the stale session is invalidated and recreated.
4552
func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServerConfig, logger *slog.Logger) (*sdk_mcp.ClientSession, error) {
4653
if logger == nil {
4754
logger = slog.With("server", server.Name)
4855
}
4956

57+
serverName := server.Name
58+
currentHash := headersCacheKey(server.Headers, server.DynamicHeaders)
59+
5060
m.mu.Lock()
5161
defer m.mu.Unlock()
5262

53-
serverName := server.Name
54-
5563
logger.Debug("mcp resolving session",
5664
"transport", server.Transport,
5765
"url", server.URL,
5866
"command", server.Command,
67+
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
5968
)
6069

61-
// Check if session exists and is still valid
70+
// Check if session exists and headers haven't changed
6271
if session, ok := m.sessions[serverName]; ok {
63-
logger.Debug("mcp reusing session")
64-
return session, nil
72+
if m.headerHashes[serverName] == currentHash {
73+
logger.Debug("mcp reusing session")
74+
return session, nil
75+
}
76+
logger.Info("mcp headers changed, invalidating cached session",
77+
"server", serverName,
78+
)
79+
_ = session.Close()
80+
delete(m.sessions, serverName)
81+
delete(m.clients, serverName)
82+
delete(m.headerHashes, serverName)
6583
}
6684

67-
// Create client if not exists
68-
client, ok := m.clients[serverName]
69-
if !ok {
70-
client = sdk_mcp.NewClient(&sdk_mcp.Implementation{
71-
Name: "flashduty-runner",
72-
Version: "1.0.0",
73-
}, nil)
74-
m.clients[serverName] = client
75-
}
85+
// Create client
86+
client := sdk_mcp.NewClient(&sdk_mcp.Implementation{
87+
Name: "flashduty-runner",
88+
Version: "1.0.0",
89+
}, nil)
90+
m.clients[serverName] = client
7691

7792
// Create transport
7893
logger.Info("mcp creating transport",
7994
"transport", server.Transport,
8095
"url", server.URL,
8196
"command", server.Command,
97+
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
8298
)
8399

84100
transport, err := createTransport(server)
@@ -103,8 +119,11 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ
103119
return nil, fmt.Errorf("failed to connect to MCP server '%s': %w", serverName, err)
104120
}
105121

106-
logger.Info("mcp connected")
122+
logger.Info("mcp connected",
123+
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
124+
)
107125
m.sessions[serverName] = session
126+
m.headerHashes[serverName] = currentHash
108127
return session, nil
109128
}
110129

@@ -174,9 +193,34 @@ func (m *ClientManager) ListTools(ctx context.Context, server *protocol.MCPServe
174193
func (m *ClientManager) invalidateSession(serverName string) {
175194
m.mu.Lock()
176195
delete(m.sessions, serverName)
196+
delete(m.headerHashes, serverName)
177197
m.mu.Unlock()
178198
}
179199

200+
func headersCacheKey(headers, dynamicHeaders map[string]string) string {
201+
encode := func(m map[string]string) string {
202+
keys := make([]string, 0, len(m))
203+
for k := range m {
204+
keys = append(keys, k)
205+
}
206+
sort.Strings(keys)
207+
parts := make([]string, len(keys))
208+
for i, k := range keys {
209+
parts[i] = k + "=" + m[k]
210+
}
211+
return strings.Join(parts, "\x00")
212+
}
213+
return "s:" + encode(headers) + "\x01d:" + encode(dynamicHeaders)
214+
}
215+
216+
func maskKey(key string) string {
217+
key = strings.TrimPrefix(key, "Bearer ")
218+
if len(key) <= 6 {
219+
return key
220+
}
221+
return key[:6] + "***"
222+
}
223+
180224
// Close closes all active sessions and clients.
181225
func (m *ClientManager) Close() {
182226
m.mu.Lock()

mcp/transport.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mcp
22

33
import (
4+
"context"
45
"log/slog"
56
"net/http"
67
"os"
@@ -12,7 +13,7 @@ import (
1213

1314
// NewStdioTransport creates a new stdio transport for MCP.
1415
func NewStdioTransport(command string, args []string, env map[string]string) sdk_mcp.Transport {
15-
cmd := exec.Command(command, args...)
16+
cmd := exec.CommandContext(context.Background(), command, args...) //nolint:gosec // G204: command comes from cloud-controlled MCP server config
1617
cmd.Env = buildEnv(env)
1718
return &sdk_mcp.CommandTransport{
1819
Command: cmd,
@@ -41,6 +42,7 @@ func NewSSETransport(endpoint string, headers map[string]string, dynamicHeaders
4142
"endpoint", endpoint,
4243
"headers_count", len(headers),
4344
"dynamic_headers_count", len(dynamicHeaders),
45+
"auth_key", maskKey(dynamicHeaders["Authorization"]),
4446
)
4547

4648
return &sdk_mcp.StreamableClientTransport{

workspace/large_output.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,18 @@ func (p *LargeOutputProcessor) truncateContent(content string, filePath string)
134134
// Build truncation message
135135
var sb strings.Builder
136136
sb.WriteString("<output_truncated>\n")
137-
sb.WriteString(fmt.Sprintf("Output too large (%d chars, %d lines).", len(content), totalLines))
137+
fmt.Fprintf(&sb, "Output too large (%d chars, %d lines).", len(content), totalLines)
138138

139139
if filePath != "" {
140-
sb.WriteString(fmt.Sprintf(" Full content saved to: %s\n\n", filePath))
140+
fmt.Fprintf(&sb, " Full content saved to: %s\n\n", filePath)
141141
} else {
142142
sb.WriteString(" Could not save full content.\n\n")
143143
}
144144

145-
sb.WriteString(fmt.Sprintf("Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview))
145+
fmt.Fprintf(&sb, "Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview)
146146

147147
if filePath != "" {
148-
sb.WriteString(fmt.Sprintf("To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines))
148+
fmt.Fprintf(&sb, "To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines)
149149
}
150150

151151
sb.WriteString("</output_truncated>")

workspace/workspace.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco
252252
// Build content string
253253
var sb strings.Builder
254254
for _, match := range res.Matches {
255-
sb.WriteString(fmt.Sprintf("%s:%d:%s\n", match.Path, match.LineNumber, match.Content))
255+
fmt.Fprintf(&sb, "%s:%d:%s\n", match.Path, match.LineNumber, match.Content)
256256
}
257257
content := sb.String()
258258

@@ -274,13 +274,14 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco
274274
}
275275

276276
func (w *Workspace) grepWithRipgrep(ctx context.Context, args *protocol.GrepArgs) (*protocol.GrepResult, error) {
277-
cmdArgs := []string{"--column", "--line-number", "--no-heading", "--color", "never", "--smart-case"}
277+
cmdArgs := make([]string, 0, 6+2*len(args.Include)+2)
278+
cmdArgs = append(cmdArgs, "--column", "--line-number", "--no-heading", "--color", "never", "--smart-case")
278279
for _, inc := range args.Include {
279280
cmdArgs = append(cmdArgs, "--glob", inc)
280281
}
281282
cmdArgs = append(cmdArgs, args.Pattern, ".")
282283

283-
cmd := exec.CommandContext(ctx, "rg", cmdArgs...)
284+
cmd := exec.CommandContext(ctx, "rg", cmdArgs...) //nolint:gosec // G204: args built from validated grep parameters
284285
cmd.Dir = w.root
285286

286287
var stdout strings.Builder
@@ -393,7 +394,7 @@ func (w *Workspace) executeBashCommand(ctx context.Context, command, workdir str
393394
ctx, cancel := context.WithTimeout(ctx, timeout)
394395
defer cancel()
395396

396-
cmd := exec.CommandContext(ctx, "bash", "-c", command)
397+
cmd := exec.CommandContext(ctx, "bash", "-c", command) //nolint:gosec // G204: command is user-initiated via workspace tool
397398
cmd.Dir = workdir
398399

399400
// Use a limited writer to prevent OOM from very large outputs

ws/client.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,10 @@ func getLinuxVersion() string {
518518
return ""
519519
}
520520

521-
// getCommandOutput executes a command and returns its trimmed output.
522521
func getCommandOutput(name string, args ...string) string {
523-
out, err := exec.Command(name, args...).Output()
522+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
523+
defer cancel()
524+
out, err := exec.CommandContext(ctx, name, args...).Output() //nolint:gosec // G204: args are hardcoded system info commands
524525
if err == nil {
525526
return strings.TrimSpace(string(out))
526527
}
@@ -550,7 +551,7 @@ func getDefaultShell() string {
550551
func getTotalMemoryMB() int64 {
551552
switch runtime.GOOS {
552553
case "darwin":
553-
out, err := exec.Command("sysctl", "-n", "hw.memsize").Output()
554+
out, err := exec.CommandContext(context.Background(), "sysctl", "-n", "hw.memsize").Output() //nolint:gosec // G204: hardcoded command
554555
if err == nil {
555556
var bytes int64
556557
if _, err := fmt.Sscanf(strings.TrimSpace(string(out)), "%d", &bytes); err == nil {

ws/handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (h *Handler) handleTaskRequest(ctx context.Context, msg *protocol.Message)
100100
"operation", req.Operation,
101101
)
102102

103-
taskCtx, cancel := context.WithCancel(ctx)
103+
taskCtx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in h.runningTask and called on task completion/cancellation
104104
h.mu.Lock()
105105
h.runningTask[req.TaskID] = cancel
106106
h.mu.Unlock()

0 commit comments

Comments
 (0)