55 "context"
66 "fmt"
77 "log/slog"
8+ "sort"
89 "strings"
910 "sync"
1011 "time"
@@ -31,54 +32,69 @@ type ClientManager struct {
3132 mu sync.Mutex
3233 clients map [string ]* sdk_mcp.Client
3334 sessions map [string ]* sdk_mcp.ClientSession
35+ // headerHashes tracks the hash of headers used when creating each session,
36+ // so we can detect when DynamicHeaders change and invalidate stale sessions.
37+ headerHashes map [string ]string
3438}
3539
3640// NewClientManager creates a new ClientManager.
3741func NewClientManager () * ClientManager {
3842 return & ClientManager {
39- clients : make (map [string ]* sdk_mcp.Client ),
40- sessions : make (map [string ]* sdk_mcp.ClientSession ),
43+ clients : make (map [string ]* sdk_mcp.Client ),
44+ sessions : make (map [string ]* sdk_mcp.ClientSession ),
45+ headerHashes : make (map [string ]string ),
4146 }
4247}
4348
4449// GetSession returns or creates an MCP session for the given server.
50+ // Sessions are cached by server name. If DynamicHeaders change (e.g. a different
51+ // user or rotated credentials), the stale session is invalidated and recreated.
4552func (m * ClientManager ) GetSession (ctx context.Context , server * protocol.MCPServerConfig , logger * slog.Logger ) (* sdk_mcp.ClientSession , error ) {
4653 if logger == nil {
4754 logger = slog .With ("server" , server .Name )
4855 }
4956
57+ serverName := server .Name
58+ currentHash := headersCacheKey (server .Headers , server .DynamicHeaders )
59+
5060 m .mu .Lock ()
5161 defer m .mu .Unlock ()
5262
53- serverName := server .Name
54-
5563 logger .Debug ("mcp resolving session" ,
5664 "transport" , server .Transport ,
5765 "url" , server .URL ,
5866 "command" , server .Command ,
67+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
5968 )
6069
61- // Check if session exists and is still valid
70+ // Check if session exists and headers haven't changed
6271 if session , ok := m .sessions [serverName ]; ok {
63- logger .Debug ("mcp reusing session" )
64- return session , nil
72+ if m .headerHashes [serverName ] == currentHash {
73+ logger .Debug ("mcp reusing session" )
74+ return session , nil
75+ }
76+ logger .Info ("mcp headers changed, invalidating cached session" ,
77+ "server" , serverName ,
78+ )
79+ _ = session .Close ()
80+ delete (m .sessions , serverName )
81+ delete (m .clients , serverName )
82+ delete (m .headerHashes , serverName )
6583 }
6684
67- // Create client if not exists
68- client , ok := m .clients [serverName ]
69- if ! ok {
70- client = sdk_mcp .NewClient (& sdk_mcp.Implementation {
71- Name : "flashduty-runner" ,
72- Version : "1.0.0" ,
73- }, nil )
74- m .clients [serverName ] = client
75- }
85+ // Create client
86+ client := sdk_mcp .NewClient (& sdk_mcp.Implementation {
87+ Name : "flashduty-runner" ,
88+ Version : "1.0.0" ,
89+ }, nil )
90+ m .clients [serverName ] = client
7691
7792 // Create transport
7893 logger .Info ("mcp creating transport" ,
7994 "transport" , server .Transport ,
8095 "url" , server .URL ,
8196 "command" , server .Command ,
97+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
8298 )
8399
84100 transport , err := createTransport (server )
@@ -103,8 +119,11 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ
103119 return nil , fmt .Errorf ("failed to connect to MCP server '%s': %w" , serverName , err )
104120 }
105121
106- logger .Info ("mcp connected" )
122+ logger .Info ("mcp connected" ,
123+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
124+ )
107125 m .sessions [serverName ] = session
126+ m .headerHashes [serverName ] = currentHash
108127 return session , nil
109128}
110129
@@ -174,9 +193,34 @@ func (m *ClientManager) ListTools(ctx context.Context, server *protocol.MCPServe
174193func (m * ClientManager ) invalidateSession (serverName string ) {
175194 m .mu .Lock ()
176195 delete (m .sessions , serverName )
196+ delete (m .headerHashes , serverName )
177197 m .mu .Unlock ()
178198}
179199
200+ func headersCacheKey (headers , dynamicHeaders map [string ]string ) string {
201+ encode := func (m map [string ]string ) string {
202+ keys := make ([]string , 0 , len (m ))
203+ for k := range m {
204+ keys = append (keys , k )
205+ }
206+ sort .Strings (keys )
207+ parts := make ([]string , len (keys ))
208+ for i , k := range keys {
209+ parts [i ] = k + "=" + m [k ]
210+ }
211+ return strings .Join (parts , "\x00 " )
212+ }
213+ return "s:" + encode (headers ) + "\x01 d:" + encode (dynamicHeaders )
214+ }
215+
216+ func maskKey (key string ) string {
217+ key = strings .TrimPrefix (key , "Bearer " )
218+ if len (key ) <= 6 {
219+ return key
220+ }
221+ return key [:6 ] + "***"
222+ }
223+
180224// Close closes all active sessions and clients.
181225func (m * ClientManager ) Close () {
182226 m .mu .Lock ()
0 commit comments