-
Notifications
You must be signed in to change notification settings - Fork 368
Expand file tree
/
Copy pathproxied_tools.go
More file actions
466 lines (395 loc) · 14.2 KB
/
proxied_tools.go
File metadata and controls
466 lines (395 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
package mcpgrafana
import (
"context"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/grafana/grafana-openapi-client-go/client/datasources"
)
// mcpProbeTimeout bounds how long discovery waits on any single datasource's
// MCP endpoint. It is intentionally short so that unreachable datasources do
// not noticeably delay startup.
const mcpProbeTimeout = 5 * time.Second
// MCPDatasourceConfig describes how to reach the MCP endpoint of one
// datasource type that can be proxied.
type MCPDatasourceConfig struct {
	// Type is the Grafana datasource type identifier, e.g. "tempo".
	Type string
	// EndpointPath is appended to the datasource's proxy URL, e.g. "/api/mcp".
	EndpointPath string
}

// mcpEnabledDatasources maps a Grafana datasource type to its MCP settings.
// Only datasources whose type has an entry here are probed during discovery;
// add entries to proxy MCP for further datasource types.
var mcpEnabledDatasources = map[string]MCPDatasourceConfig{
	"tempo": {
		Type:         "tempo",
		EndpointPath: "/api/mcp",
	},
}

// DiscoveredDatasource is a datasource instance whose MCP endpoint answered
// the discovery probe successfully.
type DiscoveredDatasource struct {
	UID, Name, Type string

	// MCPURL is the datasource's MCP endpoint, reached through Grafana's
	// datasource proxy (.../api/datasources/proxy/uid/<uid><endpoint path>).
	MCPURL string
}
// discoverMCPDatasources finds Grafana datasources that expose a working MCP
// endpoint.
//
// It lists every datasource visible to the Grafana client in ctx, keeps those
// whose type is registered in mcpEnabledDatasources, and probes each candidate
// concurrently through Grafana's datasource proxy. A candidate is reported
// only if its probe answers 200 OK. Probe errors and other statuses are
// logged at debug level, and the candidate is skipped.
//
// It returns an error only when the Grafana client or URL is missing from ctx,
// the datasource list cannot be fetched, or the HTTP transport cannot be
// built. (nil, nil) means there were no candidates of a supported type. The
// order of the returned slice is not specified.
func discoverMCPDatasources(ctx context.Context, logger *slog.Logger) ([]DiscoveredDatasource, error) {
	gc := GrafanaClientFromContext(ctx)
	if gc == nil {
		return nil, fmt.Errorf("grafana client not found in context")
	}

	resp, err := gc.Datasources.GetDataSourcesWithParams(
		datasources.NewGetDataSourcesParamsWithContext(ctx),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to list datasources: %w", err)
	}

	config := GrafanaConfigFromContext(ctx)
	if config.URL == "" {
		return nil, fmt.Errorf("grafana url not found in context")
	}
	// A configured URL such as "http://grafana:3000/" would otherwise yield
	// ".../3000//api/datasources/...", so drop trailing slashes before joining.
	grafanaBaseURL := strings.TrimRight(config.URL, "/")

	// Keep only datasources whose type is known to support MCP.
	type candidate struct {
		uid      string
		name     string
		dsType   string
		dsConfig MCPDatasourceConfig
	}
	var candidates []candidate
	for _, ds := range resp.Payload {
		dsConfig, supported := mcpEnabledDatasources[ds.Type]
		if !supported {
			continue
		}
		candidates = append(candidates, candidate{
			uid:      ds.UID,
			name:     ds.Name,
			dsType:   ds.Type,
			dsConfig: dsConfig,
		})
	}
	if len(candidates) == 0 {
		logger.DebugContext(ctx, "no candidate MCP datasources found")
		return nil, nil
	}

	transport, err := BuildTransport(&config, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create transport: %w", err)
	}
	httpClient := &http.Client{
		Transport: transport,
		Timeout:   mcpProbeTimeout,
	}

	// Each probe goroutine sends at most one value, and the channel has room
	// for one value per candidate, so no sender can ever block. That lets us
	// simply Wait and then close before draining.
	results := make(chan DiscoveredDatasource, len(candidates))
	var wg sync.WaitGroup
	for _, c := range candidates {
		wg.Add(1)
		go func(c candidate) {
			defer wg.Done()

			// The probed URL is the same URL later used to speak MCP.
			mcpURL := fmt.Sprintf("%s/api/datasources/proxy/uid/%s%s", grafanaBaseURL, c.uid, c.dsConfig.EndpointPath)
			probeCtx, cancel := context.WithTimeoutCause(ctx, mcpProbeTimeout,
				fmt.Errorf("timed out after %s probing MCP endpoint for datasource %s (%s) at %s", mcpProbeTimeout, c.name, c.uid, mcpURL))
			defer cancel()

			// Check if the datasource instance has MCP enabled.
			// We use a DELETE request to probe the MCP endpoint since:
			// - GET would start an event stream and hang
			// - POST doesn't work with the Grafana OpenAPI client
			// - DELETE returns 200 if MCP is enabled, 404 if not
			req, err := http.NewRequestWithContext(probeCtx, http.MethodDelete, mcpURL, nil)
			if err != nil {
				logger.DebugContext(ctx, "failed to create probe request", "datasource", c.uid, "error", err)
				return
			}
			probeResp, err := httpClient.Do(req)
			if err != nil {
				logger.DebugContext(ctx, "MCP probe failed", "datasource", c.uid, "error", contextCauseOrErr(probeCtx, err))
				return
			}
			defer func() { _ = probeResp.Body.Close() }()

			if probeResp.StatusCode != http.StatusOK {
				logger.DebugContext(ctx, "MCP probe returned non-OK status", "datasource", c.uid, "status", probeResp.StatusCode, "url", mcpURL)
				return
			}
			results <- DiscoveredDatasource{
				UID:    c.uid,
				Name:   c.name,
				Type:   c.dsType,
				MCPURL: mcpURL,
			}
		}(c)
	}
	wg.Wait()
	close(results)

	var discovered []DiscoveredDatasource
	for ds := range results {
		discovered = append(discovered, ds)
	}

	logger.DebugContext(ctx, "discovered MCP datasources", "count", len(discovered), "candidates", len(candidates))
	return discovered, nil
}
// addDatasourceUidParameter returns a copy of tool adapted for proxying.
//
// The copy's name is prefixed with "<datasourceType>_" (e.g.
// "tempo_traceql-search"). Its input schema gains a required string property
// "datasourceUid", described as the UID of the datasource to query.
//
// The input tool is left unmodified. A plain struct copy would share the
// schema's Properties map and Required slice with the caller's tool, so both
// are copied before they are extended.
func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool {
	modifiedTool := tool
	modifiedTool.Name = datasourceType + "_" + tool.Name

	// Build a fresh properties map so the source tool's schema is not
	// mutated, including when the source map is nil.
	properties := make(map[string]any, len(tool.InputSchema.Properties)+1)
	for name, schema := range tool.InputSchema.Properties {
		properties[name] = schema
	}
	properties["datasourceUid"] = map[string]any{
		"type":        "string",
		"description": "UID of the " + datasourceType + " datasource to query",
	}
	modifiedTool.InputSchema.Properties = properties

	// Copy Required instead of appending in place. Appending could write into
	// a backing array shared with the source tool. Also avoid listing
	// datasourceUid twice if the upstream schema already requires it.
	required := make([]string, 0, len(tool.InputSchema.Required)+1)
	alreadyRequired := false
	for _, field := range tool.InputSchema.Required {
		if field == "datasourceUid" {
			alreadyRequired = true
		}
		required = append(required, field)
	}
	if !alreadyRequired {
		required = append(required, "datasourceUid")
	}
	modifiedTool.InputSchema.Required = required

	return modifiedTool
}
// parseProxiedToolName splits a proxied tool name of the form
// "<datasource_type>_<original_tool_name>" (the form built by
// addDatasourceUidParameter) into its two parts.
//
// The split happens at the first underscore. The original tool name may
// therefore contain underscores, but the datasource type may not. An error is
// returned when there is no underscore or when either part would be empty.
//
// Returns: datasourceType, originalToolName, error.
func parseProxiedToolName(toolName string) (string, string, error) {
	datasourceType, originalName, found := strings.Cut(toolName, "_")
	if !found || datasourceType == "" || originalName == "" {
		return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName)
	}
	return datasourceType, originalName, nil
}
// ToolManager manages proxied tools (either per-session or server-wide).
//
// Proxied tools are the tools exposed by datasources' own MCP endpoints. They
// are re-registered on this server with a "<datasourceType>_" name prefix and
// an extra "datasourceUid" argument. In server mode (stdio, single-tenant)
// the clients live in serverClients. Otherwise (HTTP/SSE) they are kept per
// session via the SessionManager.
type ToolManager struct {
	// sm provides per-session state (proxied clients, registered tools) used
	// for HTTP/SSE transports.
	sm *SessionManager
	// server is the MCP server that proxied tools are registered on.
	server *server.MCPServer
	// logger is the fallback logger. loggerFromCtx prefers a logger carried
	// in the request's GrafanaConfig.
	logger *slog.Logger

	// Whether to enable proxied tools. When false, the Initialize* methods
	// return without doing anything.
	enableProxiedTools bool

	// For stdio transport: store clients at manager level (single-tenant).
	// These will be unused for HTTP/SSE transports.
	serverMode bool // true if using server-wide tools (stdio), false for per-session (HTTP/SSE)
	// serverClients is keyed by "<datasourceType>_<datasourceUID>".
	serverClients map[string]*ProxiedClient
	// clientsMutex guards serverClients.
	clientsMutex sync.RWMutex
}
// NewToolManager builds a ToolManager bound to the given session manager and
// MCP server, then applies opts in order. Proxied tools are disabled unless
// WithProxiedTools(true) is passed. If no option supplies a logger,
// slog.Default() is used.
func NewToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...toolManagerOption) *ToolManager {
	manager := new(ToolManager)
	manager.sm = sm
	manager.server = mcpServer
	manager.serverClients = map[string]*ProxiedClient{}

	for i := range opts {
		opts[i](manager)
	}

	if manager.logger == nil {
		manager.logger = slog.Default()
	}
	return manager
}
// toolManagerOption customizes a ToolManager while NewToolManager builds it.
type toolManagerOption func(*ToolManager)

// WithProxiedTools switches discovery and registration of proxied datasource
// tools on or off. They stay off unless this option enables them.
func WithProxiedTools(enabled bool) toolManagerOption {
	return func(manager *ToolManager) { manager.enableProxiedTools = enabled }
}

// WithToolManagerLogger overrides the ToolManager's logger. Passing nil lets
// NewToolManager fall back to slog.Default().
func WithToolManagerLogger(logger *slog.Logger) toolManagerOption {
	return func(manager *ToolManager) { manager.logger = logger }
}
// loggerFromCtx picks the logger for a request. It prefers the logger carried
// by the GrafanaConfig in ctx and falls back to the manager's own logger when
// that one is nil.
func (tm *ToolManager) loggerFromCtx(ctx context.Context) *slog.Logger {
	if ctxLogger := GrafanaConfigFromContext(ctx).Logger; ctxLogger != nil {
		return ctxLogger
	}
	return tm.logger
}
// InitializeAndRegisterServerTools discovers MCP-capable datasources, connects
// to each, and registers their tools server-wide (for the stdio transport).
// It should be called once at startup of a single-tenant stdio server.
//
// It does nothing when proxied tools are disabled. Otherwise it marks the
// manager as being in server mode. Datasources whose client cannot be created
// are logged and skipped. A tool offered by several datasources of the same
// type is registered once; callers pick the instance via the tool's
// datasourceUid argument.
//
// The only error returned is a failure of datasource discovery itself.
func (tm *ToolManager) InitializeAndRegisterServerTools(ctx context.Context) error {
	if !tm.enableProxiedTools {
		return nil
	}

	// Mark as server mode (stdio transport).
	tm.serverMode = true
	logger := tm.loggerFromCtx(ctx)

	discovered, err := discoverMCPDatasources(ctx, logger)
	if err != nil {
		return fmt.Errorf("failed to discover MCP datasources: %w", err)
	}
	if len(discovered) == 0 {
		logger.InfoContext(ctx, "no MCP datasources discovered")
		return nil
	}

	// Create the clients before taking clientsMutex. NewProxiedClient is
	// handed a remote MCP URL and may block on the network, and holding the
	// write lock meanwhile would stall GetServerClient callers.
	created := make(map[string]*ProxiedClient, len(discovered))
	for _, ds := range discovered {
		client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
		if err != nil {
			logger.ErrorContext(ctx, "failed to create proxied client", "datasource", ds.UID, "error", err)
			continue
		}
		created[ds.Type+"_"+ds.UID] = client
	}

	tm.clientsMutex.Lock()
	for key, client := range created {
		tm.serverClients[key] = client
	}
	clientCount := len(tm.serverClients)
	tm.clientsMutex.Unlock()

	if clientCount == 0 {
		logger.WarnContext(ctx, "no proxied clients created")
		return nil
	}
	logger.InfoContext(ctx, "connected to proxied MCP servers", "datasources", clientCount)

	// Collect each unique tool once, renamed and extended with datasourceUid.
	tm.clientsMutex.RLock()
	toolMap := make(map[string]mcp.Tool)
	for _, client := range tm.serverClients {
		for _, tool := range client.ListTools() {
			toolName := client.DatasourceType + "_" + tool.Name
			if _, seen := toolMap[toolName]; !seen {
				toolMap[toolName] = addDatasourceUidParameter(tool, client.DatasourceType)
			}
		}
	}
	tm.clientsMutex.RUnlock()

	// Register every tool in one batch, like the per-session path, to reduce
	// listChanged notifications.
	serverTools := make([]server.ServerTool, 0, len(toolMap))
	for toolName, tool := range toolMap {
		handler := NewProxiedToolHandler(tm.sm, tm, toolName)
		serverTools = append(serverTools, server.ServerTool{
			Tool:    tool,
			Handler: handler.Handle,
		})
	}
	tm.server.AddTools(serverTools...)

	logger.InfoContext(ctx, "registered proxied tools on server", "tools", len(toolMap))
	return nil
}
// InitializeAndRegisterProxiedTools discovers datasources, creates clients,
// and registers proxied tools for one client session (HTTP/SSE transports).
// It should be called in the OnBeforeListTools and OnBeforeCallTool hooks and
// tolerates being called repeatedly for the same session.
//
// If the session is not yet known to the SessionManager, it is created first.
//
// Step 1 runs at most once per session (guarded by state.initOnce). It
// discovers MCP-capable datasources and creates a ProxiedClient for each.
// Clients that cannot be created are logged and skipped.
//
// Step 2 registers the union of the clients' tools as session tools. It is
// skipped if tools were already registered for the session or if no clients
// exist.
//
// Nothing is returned; every failure is logged.
func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) {
	if !tm.enableProxiedTools {
		return
	}

	logger := tm.loggerFromCtx(ctx)
	sessionID := session.SessionID()

	state, exists := tm.sm.GetSession(sessionID)
	if !exists {
		// Session exists in server context but not in our SessionManager yet.
		tm.sm.CreateSession(ctx, session)
		state, exists = tm.sm.GetSession(sessionID)
		if !exists {
			logger.ErrorContext(ctx, "failed to create session in SessionManager", "sessionID", sessionID)
			return
		}
	}

	// Step 1: discover and connect (guaranteed to run exactly once per session).
	state.initOnce.Do(func() {
		discovered, err := discoverMCPDatasources(ctx, logger)
		if err != nil {
			logger.ErrorContext(ctx, "failed to discover MCP datasources", "error", err)
			state.mutex.Lock()
			state.proxiedToolsInitialized = true
			state.mutex.Unlock()
			return
		}

		// Build the clients before taking state.mutex. NewProxiedClient is
		// handed a remote MCP URL and may block on the network, and holding
		// the session lock meanwhile would stall other users of it.
		clients := make(map[string]*ProxiedClient, len(discovered))
		for _, ds := range discovered {
			client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
			if err != nil {
				logger.ErrorContext(ctx, "failed to create proxied client", "datasource", ds.UID, "error", err)
				continue
			}
			clients[ds.Type+"_"+ds.UID] = client
		}

		state.mutex.Lock()
		for key, client := range clients {
			state.proxiedClients[key] = client
		}
		state.proxiedToolsInitialized = true
		// Take the count while still holding the lock. Reading the map after
		// Unlock could race with other goroutines that modify it under the lock.
		clientCount := len(state.proxiedClients)
		state.mutex.Unlock()

		logger.InfoContext(ctx, "connected to proxied MCP servers", "session", sessionID, "datasources", clientCount)
	})

	// Step 2: register tools with the MCP server.
	state.mutex.Lock()
	defer state.mutex.Unlock()

	// Tools already registered for this session.
	if len(state.proxiedTools) > 0 {
		return
	}
	// No clients: discovery failed, found nothing, or every client failed.
	if len(state.proxiedClients) == 0 {
		return
	}

	// First pass: collect all unique tools and track which datasources support
	// them. Tool name format: datasourceType_originalToolName
	// (e.g. "tempo_traceql-search").
	toolMap := make(map[string]mcp.Tool)
	for key, client := range state.proxiedClients {
		for _, tool := range client.ListTools() {
			toolName := client.DatasourceType + "_" + tool.Name
			if _, seen := toolMap[toolName]; !seen {
				toolMap[toolName] = addDatasourceUidParameter(tool, client.DatasourceType)
			}
			state.toolToDatasources[toolName] = append(state.toolToDatasources[toolName], key)
		}
	}

	// Second pass: register all unique tools at once (reduces listChanged
	// notifications).
	// NOTE(review): proxiedTools is filled before AddSessionTools runs. If
	// registration fails, later calls therefore see "already registered" and
	// do not retry. Confirm this is the intended behavior.
	serverTools := make([]server.ServerTool, 0, len(toolMap))
	for toolName, tool := range toolMap {
		handler := NewProxiedToolHandler(tm.sm, tm, toolName)
		serverTools = append(serverTools, server.ServerTool{
			Tool:    tool,
			Handler: handler.Handle,
		})
		state.proxiedTools = append(state.proxiedTools, tool)
	}

	if err := tm.server.AddSessionTools(sessionID, serverTools...); err != nil {
		logger.WarnContext(ctx, "failed to add session tools", "session", sessionID, "error", err)
		return
	}
	logger.InfoContext(ctx, "registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools))
}
// GetServerClient looks up the server-level proxied client (stdio transport)
// for the given datasource type and UID.
//
// When no client matches, the returned error names the other UIDs of the same
// datasource type that do have clients, or states that none are configured.
// The lookup holds clientsMutex for reading.
func (tm *ToolManager) GetServerClient(datasourceType, datasourceUID string) (*ProxiedClient, error) {
	tm.clientsMutex.RLock()
	defer tm.clientsMutex.RUnlock()

	if client, ok := tm.serverClients[datasourceType+"_"+datasourceUID]; ok {
		return client, nil
	}

	// Not found: gather same-type UIDs so the error is actionable.
	var sameType []string
	for _, candidate := range tm.serverClients {
		if candidate.DatasourceType != datasourceType {
			continue
		}
		sameType = append(sameType, candidate.DatasourceUID)
	}

	if len(sameType) == 0 {
		return nil, fmt.Errorf("datasource '%s' not found. No %s datasources with MCP support are configured", datasourceUID, datasourceType)
	}
	return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, sameType)
}