diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 7b43f69ce..7d42b0448 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -50,47 +50,45 @@ func ListWorkflows(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - workflows, resp, err := client.Actions.ListWorkflows(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflows: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Set up list options + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } - r, err := json.Marshal(workflows) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + workflows, resp, err := client.Actions.ListWorkflows(ctx, owner, repo, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflows: %w", err) + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(workflows) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -176,75 +174,73 @@ func ListWorkflowRuns(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "workflow_id"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - workflowID, err := RequiredParam[string](args, "workflow_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional filtering parameters - actor, err := OptionalParam[string](args, "actor") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := OptionalParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - event, err := OptionalParam[string](args, "event") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - status, err := OptionalParam[string](args, "status") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + workflowID, err := RequiredParam[string](args, "workflow_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional filtering parameters + actor, err := OptionalParam[string](args, "actor") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := OptionalParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + event, err := OptionalParam[string](args, "event") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + status, err := OptionalParam[string](args, "status") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListWorkflowRunsOptions{ - Actor: actor, - Branch: branch, - Event: event, - Status: status, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - workflowRuns, resp, err := client.Actions.ListWorkflowRunsByFileName(ctx, owner, repo, workflowID, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflow runs: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Set up list options + opts := &github.ListWorkflowRunsOptions{ + Actor: actor, + Branch: branch, + Event: event, + Status: status, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - r, err := json.Marshal(workflowRuns) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + workflowRuns, resp, err := client.Actions.ListWorkflowRunsByFileName(ctx, owner, repo, workflowID, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflow runs: %w", err) + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(workflowRuns) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -287,76 +283,74 @@ func RunWorkflow(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "workflow_id", "ref"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - workflowID, err := RequiredParam[string](args, "workflow_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := RequiredParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional inputs parameter - var inputs map[string]interface{} - if requestInputs, ok := args["inputs"]; ok { - if inputsMap, ok := requestInputs.(map[string]interface{}); ok { - inputs = inputsMap - } - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + workflowID, err := RequiredParam[string](args, "workflow_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ref, err := RequiredParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - event := github.CreateWorkflowDispatchEventRequest{ - Ref: ref, - Inputs: inputs, + // Get optional inputs parameter + var inputs map[string]interface{} + if requestInputs, ok := args["inputs"]; ok { + if inputsMap, ok := requestInputs.(map[string]interface{}); ok { + inputs = inputsMap } + } - var resp *github.Response - var workflowType string + event := github.CreateWorkflowDispatchEventRequest{ + Ref: ref, + Inputs: inputs, + } - if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil { - resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event) - workflowType = "workflow_id" - } else { - resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event) - workflowType = "workflow_file" - } + var resp *github.Response + var workflowType string - if err != nil { - return nil, nil, fmt.Errorf("failed to run workflow: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - result := map[string]any{ - "message": "Workflow run has been queued", - "workflow_type": workflowType, - "workflow_id": workflowID, - "ref": ref, - "inputs": inputs, - "status": resp.Status, - "status_code": resp.StatusCode, - } + if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil { + resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event) + workflowType = "workflow_id" + } else { + resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event) + workflowType = "workflow_file" + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if err != nil { + return nil, nil, fmt.Errorf("failed to run workflow: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run has been queued", + "workflow_type": workflowType, + "workflow_id": workflowID, + "ref": ref, + "inputs": inputs, + "status": resp.Status, + "status_code": resp.StatusCode, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -391,40 +385,38 @@ func GetWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get workflow run: %w", err) - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - r, err := json.Marshal(workflowRun) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get workflow run: %w", err) + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(workflowRun) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -459,50 +451,48 @@ func GetWorkflowRunLogs(t translations.TranslationHelperFunc) inventory.ServerTo Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get the download URL for the logs - url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) - if err != nil { - return nil, nil, fmt.Errorf("failed to get workflow run logs: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // Create response with the logs URL and information - result := map[string]any{ - "logs_url": url.String(), - "message": "Workflow run logs are available for download", - "note": "The logs_url provides a download link for the complete workflow run logs as a ZIP archive. You can download this archive to extract and examine individual job logs.", - "warning": "This downloads ALL logs as a ZIP file which can be large and expensive. For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id instead.", - "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Get the download URL for the logs + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) + if err != nil { + return nil, nil, fmt.Errorf("failed to get workflow run logs: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the logs URL and information + result := map[string]any{ + "logs_url": url.String(), + "message": "Workflow run logs are available for download", + "note": "The logs_url provides a download link for the complete workflow run logs as a ZIP archive. You can download this archive to extract and examine individual job logs.", + "warning": "This downloads ALL logs as a ZIP file which can be large and expensive. For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id instead.", + "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -542,67 +532,65 @@ func ListWorkflowJobs(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "run_id"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional filtering parameters - filter, err := OptionalParam[string](args, "filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional filtering parameters + filter, err := OptionalParam[string](args, "filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListWorkflowJobsOptions{ - Filter: filter, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - jobs, resp, err := client.Actions.ListWorkflowJobs(ctx, owner, repo, runID, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflow jobs: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Set up list options + opts := &github.ListWorkflowJobsOptions{ + Filter: filter, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - // Add optimization tip for failed job debugging - response := map[string]any{ - "jobs": jobs, - "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", - } + jobs, resp, err := client.Actions.ListWorkflowJobs(ctx, owner, repo, runID, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflow jobs: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Add optimization tip for failed job debugging + response := map[string]any{ + "jobs": jobs, + "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -654,66 +642,64 @@ func GetJobLogs(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional parameters - jobID, err := OptionalIntParam(args, "job_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID, err := OptionalIntParam(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - failedOnly, err := OptionalParam[bool](args, "failed_only") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - returnContent, err := OptionalParam[bool](args, "return_content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tailLines, err := OptionalIntParam(args, "tail_lines") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // Default to 500 lines if not specified - if tailLines == 0 { - tailLines = 500 - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Validate parameters - if failedOnly && runID == 0 { - return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil - } - if !failedOnly && jobID == 0 { - return utils.NewToolResultError("job_id is required when failed_only is false"), nil, nil - } + // Get optional parameters + jobID, err := OptionalIntParam(args, "job_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID, err := OptionalIntParam(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + failedOnly, err := OptionalParam[bool](args, "failed_only") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + returnContent, err := OptionalParam[bool](args, "return_content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tailLines, err := OptionalIntParam(args, "tail_lines") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // Default to 500 lines if not specified + if tailLines == 0 { + tailLines = 500 + } - if failedOnly && runID > 0 { - // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) - } else if jobID > 0 { - // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) - } + // Validate parameters + if failedOnly && runID == 0 { + return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil + } + if !failedOnly && jobID == 0 { + return utils.NewToolResultError("job_id is required when failed_only is false"), nil, nil + } - return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil + if failedOnly && runID > 0 { + // Handle failed-only mode: get logs for all failed jobs in the workflow run + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) + } else if jobID > 0 { + // Handle single job mode + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) } + + return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil }, ) } @@ -902,47 +888,45 @@ func RerunWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - result := map[string]any{ - "message": "Workflow run has been queued for re-run", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + result := map[string]any{ + "message": "Workflow run has been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -977,47 +961,45 @@ func RerunFailedJobs(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - result := map[string]any{ - "message": "Failed jobs have been queued for re-run", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + result := map[string]any{ + "message": "Failed jobs have been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1052,49 +1034,47 @@ func CancelWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerToo Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) - if err != nil { - if _, ok := err.(*github.AcceptedError); !ok { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil, nil - } - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - result := map[string]any{ - "message": "Workflow run has been cancelled", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, + resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) + if err != nil { + if _, ok := err.(*github.AcceptedError); !ok { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil, nil } + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + result := map[string]any{ + "message": "Workflow run has been cancelled", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1129,52 +1109,50 @@ func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) inventory.Se Required: []string{"owner", "repo", "run_id"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - // Set up list options - opts := &github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Set up list options + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } - r, err := json.Marshal(artifacts) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(artifacts) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1209,49 +1187,47 @@ func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) inventory Required: []string{"owner", "repo", "artifact_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - artifactIDInt, err := RequiredInt(args, "artifact_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - artifactID := int64(artifactIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get the download URL for the artifact - url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Create response with the download URL and information - result := map[string]any{ - "download_url": url.String(), - "message": "Artifact is available for download", - "note": "The download_url provides a download link for the artifact as a ZIP archive. The link is temporary and expires after a short time.", - "artifact_id": artifactID, - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + artifactIDInt, err := RequiredInt(args, "artifact_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + artifactID := int64(artifactIDInt) - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Get the download URL for the artifact + url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the download URL and information + result := map[string]any{ + "download_url": url.String(), + "message": "Artifact is available for download", + "note": "The download_url provides a download link for the artifact as a ZIP archive. The link is temporary and expires after a short time.", + "artifact_id": artifactID, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1287,47 +1263,45 @@ func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) inventory.Serve Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - result := map[string]any{ - "message": "Workflow run logs have been deleted", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + result := map[string]any{ + "message": "Workflow run logs have been deleted", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1362,40 +1336,38 @@ func GetWorkflowRunUsage(t translations.TranslationHelperFunc) inventory.ServerT Required: []string{"owner", "repo", "run_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - r, err := json.Marshal(usage) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(usage) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 4d56f01aa..aca7a5d4f 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -114,7 +114,7 @@ func Test_ListWorkflows(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -203,7 +203,7 @@ func Test_RunWorkflow(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -299,7 +299,7 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -407,7 +407,7 @@ func Test_CancelWorkflowRun(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -537,7 +537,7 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -627,7 +627,7 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -713,7 +713,7 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -817,7 +817,7 @@ func Test_GetWorkflowRunUsage(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -1082,7 +1082,7 @@ func Test_GetJobLogs(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -1149,7 +1149,7 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { "return_content": true, }) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1202,7 +1202,7 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { "tail_lines": float64(1), // Requesting last 1 line }) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1254,7 +1254,7 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { "tail_lines": float64(100), }) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 632fcddf9..5e25d0501 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -44,51 +44,49 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) inventory.Server Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get alert", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil - } + alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get alert", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alert) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alert) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -137,62 +135,60 @@ func ListCodeScanningAlerts(t translations.TranslationHelperFunc) inventory.Serv Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := OptionalParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - severity, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolName, err := OptionalParam[string](args, "tool_name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list alerts", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ref, err := OptionalParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + severity, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + toolName, err := OptionalParam[string](args, "tool_name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list alerts", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alerts) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alerts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index ec1f71035..59972fe52 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -90,7 +90,7 @@ func Test_GetCodeScanningAlert(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -216,7 +216,7 @@ func Test_ListCodeScanningAlerts(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 781b1f3cd..e0df82c88 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -51,51 +51,49 @@ func GetMe(t translations.TranslationHelperFunc) inventory.ServerTool { // OpenAI strict mode requires the properties field to be present. InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - user, res, err := client.Users.Get(ctx, "") - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Create minimal user representation instead of returning full user object - minimalUser := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - Details: &UserDetails{ - Name: user.GetName(), - Company: user.GetCompany(), - Blog: user.GetBlog(), - Location: user.GetLocation(), - Email: user.GetEmail(), - Hireable: user.GetHireable(), - Bio: user.GetBio(), - TwitterUsername: user.GetTwitterUsername(), - PublicRepos: user.GetPublicRepos(), - PublicGists: user.GetPublicGists(), - Followers: user.GetFollowers(), - Following: user.GetFollowing(), - CreatedAt: user.GetCreatedAt().Time, - UpdatedAt: user.GetUpdatedAt().Time, - PrivateGists: user.GetPrivateGists(), - TotalPrivateRepos: user.GetTotalPrivateRepos(), - OwnedPrivateRepos: user.GetOwnedPrivateRepos(), - }, - } + user, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } - return MarshalledTextResult(minimalUser), nil, nil + // Create minimal user representation instead of returning full user object + minimalUser := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), + Details: &UserDetails{ + Name: user.GetName(), + Company: user.GetCompany(), + Blog: user.GetBlog(), + Location: user.GetLocation(), + Email: user.GetEmail(), + Hireable: user.GetHireable(), + Bio: user.GetBio(), + TwitterUsername: user.GetTwitterUsername(), + PublicRepos: user.GetPublicRepos(), + PublicGists: user.GetPublicGists(), + Followers: user.GetFollowers(), + Following: user.GetFollowing(), + CreatedAt: user.GetCreatedAt().Time, + UpdatedAt: user.GetUpdatedAt().Time, + PrivateGists: user.GetPrivateGists(), + TotalPrivateRepos: user.GetTotalPrivateRepos(), + OwnedPrivateRepos: user.GetOwnedPrivateRepos(), + }, } + + return MarshalledTextResult(minimalUser), nil, nil }, ) } @@ -131,81 +129,79 @@ func GetTeams(t translations.TranslationHelperFunc) inventory.ServerTool { }, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - user, err := OptionalParam[string](args, "user") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var username string - if user != "" { - username = user - } else { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + user, err := OptionalParam[string](args, "user") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - userResp, res, err := client.Users.Get(ctx, "") - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, nil - } - username = userResp.GetLogin() + var username string + if user != "" { + username = user + } else { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - gqlClient, err := deps.GetGQLClient(ctx) + userResp, res, err := client.Users.Get(ctx, "") if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil } + username = userResp.GetLogin() + } - var q struct { - User struct { - Organizations struct { - Nodes []struct { - Login githubv4.String - Teams struct { - Nodes []struct { - Name githubv4.String - Slug githubv4.String - Description githubv4.String - } - } `graphql:"teams(first: 100, userLogins: [$login])"` - } - } `graphql:"organizations(first: 100)"` - } `graphql:"user(login: $login)"` - } - vars := map[string]interface{}{ - "login": githubv4.String(username), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - var organizations []OrganizationTeams - for _, org := range q.User.Organizations.Nodes { - orgTeams := OrganizationTeams{ - Org: string(org.Login), - Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), - } + var q struct { + User struct { + Organizations struct { + Nodes []struct { + Login githubv4.String + Teams struct { + Nodes []struct { + Name githubv4.String + Slug githubv4.String + Description githubv4.String + } + } `graphql:"teams(first: 100, userLogins: [$login])"` + } + } `graphql:"organizations(first: 100)"` + } `graphql:"user(login: $login)"` + } + vars := map[string]interface{}{ + "login": githubv4.String(username), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil + } - for _, team := range org.Teams.Nodes { - orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ - Name: string(team.Name), - Slug: string(team.Slug), - Description: string(team.Description), - }) - } + var organizations []OrganizationTeams + for _, org := range q.User.Organizations.Nodes { + orgTeams := OrganizationTeams{ + Org: string(org.Login), + Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), + } - organizations = append(organizations, orgTeams) + for _, team := range org.Teams.Nodes { + orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ + Name: string(team.Name), + Slug: string(team.Slug), + Description: string(team.Description), + }) } - return MarshalledTextResult(organizations), nil, nil + organizations = append(organizations, orgTeams) } + + return MarshalledTextResult(organizations), nil, nil }, ) } @@ -235,49 +231,47 @@ func GetTeamMembers(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"org", "team_slug"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - teamSlug, err := RequiredParam[string](args, "team_slug") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gqlClient, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + teamSlug, err := RequiredParam[string](args, "team_slug") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var q struct { - Organization struct { - Team struct { - Members struct { - Nodes []struct { - Login githubv4.String - } - } `graphql:"members(first: 100)"` - } `graphql:"team(slug: $teamSlug)"` - } `graphql:"organization(login: $org)"` - } - vars := map[string]interface{}{ - "org": githubv4.String(org), - "teamSlug": githubv4.String(teamSlug), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - var members []string - for _, member := range q.Organization.Team.Members.Nodes { - members = append(members, string(member.Login)) - } + var q struct { + Organization struct { + Team struct { + Members struct { + Nodes []struct { + Login githubv4.String + } + } `graphql:"members(first: 100)"` + } `graphql:"team(slug: $teamSlug)"` + } `graphql:"organization(login: $org)"` + } + vars := map[string]interface{}{ + "org": githubv4.String(org), + "teamSlug": githubv4.String(teamSlug), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil + } - return MarshalledTextResult(members), nil, nil + var members []string + for _, member := range q.Organization.Team.Members.Nodes { + members = append(members, string(member.Login)) } + + return MarshalledTextResult(members), nil, nil }, ) } diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index e9faefc40..5f471ff37 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -113,7 +113,7 @@ func Test_GetMe(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectToolError { @@ -353,10 +353,11 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - handler := serverTool.Handler(tc.makeDeps()) + deps := tc.makeDeps() + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectToolError { @@ -499,7 +500,7 @@ func Test_GetTeamMembers(t *testing.T) { handler := serverTool.Handler(tc.deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), tc.deps), &request) require.NoError(t, err) if tc.expectToolError { diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index c1a4ce46c..db6352dab 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -45,51 +45,49 @@ func GetDependabotAlert(t translations.TranslationHelperFunc) inventory.ServerTo Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil - } + alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alert) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alert) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -130,58 +128,56 @@ func ListDependabotAlerts(t translations.TranslationHelperFunc) inventory.Server Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - severity, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + severity, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ - State: ToStringPtr(state), - Severity: ToStringPtr(severity), - }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil - } + alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ + State: ToStringPtr(state), + Severity: ToStringPtr(severity), + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alerts) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alerts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index 614c6f383..0ceac4ffa 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -88,7 +88,7 @@ func Test_GetDependabotAlert(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -237,7 +237,7 @@ func Test_ListDependabotAlerts(t *testing.T) { request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 4c634076d..d23e993c3 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -2,6 +2,7 @@ package github import ( "context" + "errors" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" @@ -12,6 +13,42 @@ import ( "github.com/shurcooL/githubv4" ) +// depsContextKey is the context key for ToolDependencies. +// Using a private type prevents collisions with other packages. +type depsContextKey struct{} + +// ErrDepsNotInContext is returned when ToolDependencies is not found in context. +var ErrDepsNotInContext = errors.New("ToolDependencies not found in context; use ContextWithDeps to inject") + +// ContextWithDeps returns a new context with the ToolDependencies stored in it. +// This is used to inject dependencies at request time rather than at registration time, +// avoiding expensive closure creation during server initialization. +// +// For the local server, this is called once at startup since deps don't change. +// For the remote server, this is called per-request with request-specific deps. +func ContextWithDeps(ctx context.Context, deps ToolDependencies) context.Context { + return context.WithValue(ctx, depsContextKey{}, deps) +} + +// DepsFromContext retrieves ToolDependencies from the context. +// Returns the deps and true if found, or nil and false if not present. +// Use MustDepsFromContext if you want to panic on missing deps (for handlers +// that require deps to function). +func DepsFromContext(ctx context.Context) (ToolDependencies, bool) { + deps, ok := ctx.Value(depsContextKey{}).(ToolDependencies) + return deps, ok +} + +// MustDepsFromContext retrieves ToolDependencies from the context. +// Panics if deps are not found - use this in handlers where deps are required. +func MustDepsFromContext(ctx context.Context) ToolDependencies { + deps, ok := DepsFromContext(ctx) + if !ok { + panic(ErrDepsNotInContext) + } + return deps +} + // ToolDependencies defines the interface for dependencies that tool handlers need. // This is an interface to allow different implementations: // - Local server: stores closures that create clients on demand @@ -105,19 +142,27 @@ func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } // GetContentWindowSize implements ToolDependencies. func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } -// NewTool creates a ServerTool with fully-typed ToolDependencies and toolset metadata. -// This helper isolates the type assertion from `any` to `ToolDependencies`, -// so tool implementations remain fully typed without assertions scattered throughout. -func NewTool[In, Out any](toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) inventory.ServerTool { - return inventory.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { - return handler(d.(ToolDependencies)) +// NewTool creates a ServerTool that retrieves ToolDependencies from context at call time. +// This avoids creating closures at registration time, which is important for performance +// in servers that create a new server instance per request (like the remote server). +// +// The handler function receives deps extracted from context via MustDepsFromContext. +// Ensure ContextWithDeps is called to inject deps before any tool handlers are invoked. +func NewTool[In, Out any](toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(ctx context.Context, deps ToolDependencies, req *mcp.CallToolRequest, args In) (*mcp.CallToolResult, Out, error)) inventory.ServerTool { + return inventory.NewServerToolWithContextHandler(tool, toolset, func(ctx context.Context, req *mcp.CallToolRequest, args In) (*mcp.CallToolResult, Out, error) { + deps := MustDepsFromContext(ctx) + return handler(ctx, deps, req, args) }) } -// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies and toolset metadata -// for handlers that conform to mcp.ToolHandler directly. -func NewToolFromHandler(toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) inventory.ServerTool { - return inventory.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { - return handler(d.(ToolDependencies)) +// NewToolFromHandler creates a ServerTool that retrieves ToolDependencies from context at call time. +// Use this when you have a handler that conforms to mcp.ToolHandler directly. +// +// The handler function receives deps extracted from context via MustDepsFromContext. +// Ensure ContextWithDeps is called to inject deps before any tool handlers are invoked. +func NewToolFromHandler(toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(ctx context.Context, deps ToolDependencies, req *mcp.CallToolRequest) (*mcp.CallToolResult, error)) inventory.ServerTool { + return inventory.NewServerToolWithRawContextHandler(tool, toolset, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + deps := MustDepsFromContext(ctx) + return handler(ctx, deps, req) }) } diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index b79d70e9b..c891ba294 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -161,117 +161,115 @@ func ListDiscussions(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // when not provided, default to the .github repository - // this will query discussions at the organisation level - if repo == "" { - repo = ".github" - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // when not provided, default to the .github repository + // this will query discussions at the organisation level + if repo == "" { + repo = ".github" + } - category, err := OptionalParam[string](args, "category") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + category, err := OptionalParam[string](args, "category") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - orderBy, err := OptionalParam[string](args, "orderBy") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + orderBy, err := OptionalParam[string](args, "orderBy") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err - } - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var categoryID *githubv4.ID - if category != "" { - id := githubv4.ID(category) - categoryID = &id - } + var categoryID *githubv4.ID + if category != "" { + id := githubv4.ID(category) + categoryID = &id + } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "first": githubv4.Int(*paginationParams.First), + } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } - // this is an extra check in case the tool description is misinterpreted, because - // we shouldn't use ordering unless both a 'field' and 'direction' are provided - useOrdering := orderBy != "" && direction != "" - if useOrdering { - vars["orderByField"] = githubv4.DiscussionOrderField(orderBy) - vars["orderByDirection"] = githubv4.OrderDirection(direction) - } + // this is an extra check in case the tool description is misinterpreted, because + // we shouldn't use ordering unless both a 'field' and 'direction' are provided + useOrdering := orderBy != "" && direction != "" + if useOrdering { + vars["orderByField"] = githubv4.DiscussionOrderField(orderBy) + vars["orderByDirection"] = githubv4.OrderDirection(direction) + } - if categoryID != nil { - vars["categoryId"] = *categoryID - } + if categoryID != nil { + vars["categoryId"] = *categoryID + } - discussionQuery := getQueryType(useOrdering, categoryID) - if err := client.Query(ctx, discussionQuery, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + discussionQuery := getQueryType(useOrdering, categoryID) + if err := client.Query(ctx, discussionQuery, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Extract and convert all discussion nodes using the common interface - var discussions []*github.Discussion - var pageInfo PageInfoFragment - var totalCount githubv4.Int - if queryResult, ok := discussionQuery.(DiscussionQueryResult); ok { - fragment := queryResult.GetDiscussionFragment() - for _, node := range fragment.Nodes { - discussions = append(discussions, fragmentToDiscussion(node)) - } - pageInfo = fragment.PageInfo - totalCount = fragment.TotalCount - } + // Extract and convert all discussion nodes using the common interface + var discussions []*github.Discussion + var pageInfo PageInfoFragment + var totalCount githubv4.Int + if queryResult, ok := discussionQuery.(DiscussionQueryResult); ok { + fragment := queryResult.GetDiscussionFragment() + for _, node := range fragment.Nodes { + discussions = append(discussions, fragmentToDiscussion(node)) + } + pageInfo = fragment.PageInfo + totalCount = fragment.TotalCount + } - // Create response with pagination info - response := map[string]interface{}{ - "discussions": discussions, - "pageInfo": map[string]interface{}{ - "hasNextPage": pageInfo.HasNextPage, - "hasPreviousPage": pageInfo.HasPreviousPage, - "startCursor": string(pageInfo.StartCursor), - "endCursor": string(pageInfo.EndCursor), - }, - "totalCount": totalCount, - } + // Create response with pagination info + response := map[string]interface{}{ + "discussions": discussions, + "pageInfo": map[string]interface{}{ + "hasNextPage": pageInfo.HasNextPage, + "hasPreviousPage": pageInfo.HasPreviousPage, + "startCursor": string(pageInfo.StartCursor), + "endCursor": string(pageInfo.EndCursor), + }, + "totalCount": totalCount, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussions: %w", err) - } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussions: %w", err) } + return utils.NewToolResultText(string(out)), nil, nil }, ) } @@ -305,78 +303,76 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "discussionNumber"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } - - var q struct { - Repository struct { - Discussion struct { - Number githubv4.Int - Title githubv4.String - Body githubv4.String - CreatedAt githubv4.DateTime - Closed githubv4.Boolean - IsAnswered githubv4.Boolean - AnswerChosenAt *githubv4.DateTime - URL githubv4.String `graphql:"url"` - Category struct { - Name githubv4.String - } `graphql:"category"` - } `graphql:"discussion(number: $discussionNumber)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - d := q.Repository.Discussion - - // Build response as map to include fields not present in go-github's Discussion struct. - // The go-github library's Discussion type lacks isAnswered and answerChosenAt fields, - // so we use map[string]interface{} for the response (consistent with other functions - // like ListDiscussions and GetDiscussionComments). - response := map[string]interface{}{ - "number": int(d.Number), - "title": string(d.Title), - "body": string(d.Body), - "url": string(d.URL), - "closed": bool(d.Closed), - "isAnswered": bool(d.IsAnswered), - "createdAt": d.CreatedAt.Time, - "category": map[string]interface{}{ - "name": string(d.Category.Name), - }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - // Add optional timestamp fields if present - if d.AnswerChosenAt != nil { - response["answerChosenAt"] = d.AnswerChosenAt.Time - } + var q struct { + Repository struct { + Discussion struct { + Number githubv4.Int + Title githubv4.String + Body githubv4.String + CreatedAt githubv4.DateTime + Closed githubv4.Boolean + IsAnswered githubv4.Boolean + AnswerChosenAt *githubv4.DateTime + URL githubv4.String `graphql:"url"` + Category struct { + Name githubv4.String + } `graphql:"category"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + d := q.Repository.Discussion + + // Build response as map to include fields not present in go-github's Discussion struct. + // The go-github library's Discussion type lacks isAnswered and answerChosenAt fields, + // so we use map[string]interface{} for the response (consistent with other functions + // like ListDiscussions and GetDiscussionComments). + response := map[string]interface{}{ + "number": int(d.Number), + "title": string(d.Title), + "body": string(d.Body), + "url": string(d.URL), + "closed": bool(d.Closed), + "isAnswered": bool(d.IsAnswered), + "createdAt": d.CreatedAt.Time, + "category": map[string]interface{}{ + "name": string(d.Category.Name), + }, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussion: %w", err) - } + // Add optional timestamp fields if present + if d.AnswerChosenAt != nil { + response["answerChosenAt"] = d.AnswerChosenAt.Time + } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussion: %w", err) } + + return utils.NewToolResultText(string(out)), nil, nil }, ) } @@ -410,101 +406,99 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve Required: []string{"owner", "repo", "discussionNumber"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Check if pagination parameters were explicitly provided - _, perPageProvided := args["perPage"] - paginationExplicit := perPageProvided + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + // Check if pagination parameters were explicitly provided + _, perPageProvided := args["perPage"] + paginationExplicit := perPageProvided - // Use default of 30 if pagination was not explicitly provided - if !paginationExplicit { - defaultFirst := int32(DefaultGraphQLPageSize) - paginationParams.First = &defaultFirst - } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + // Use default of 30 if pagination was not explicitly provided + if !paginationExplicit { + defaultFirst := int32(DefaultGraphQLPageSize) + paginationParams.First = &defaultFirst + } - var q struct { - Repository struct { - Discussion struct { - Comments struct { - Nodes []struct { - Body githubv4.String - } - PageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - TotalCount int - } `graphql:"comments(first: $first, after: $after)"` - } `graphql:"discussion(number: $discussionNumber)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var comments []*github.IssueComment - for _, c := range q.Repository.Discussion.Comments.Nodes { - comments = append(comments, &github.IssueComment{Body: github.Ptr(string(c.Body))}) - } + var q struct { + Repository struct { + Discussion struct { + Comments struct { + Nodes []struct { + Body githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } + TotalCount int + } `graphql:"comments(first: $first, after: $after)"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + "first": githubv4.Int(*paginationParams.First), + } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Create response with pagination info - response := map[string]interface{}{ - "comments": comments, - "pageInfo": map[string]interface{}{ - "hasNextPage": q.Repository.Discussion.Comments.PageInfo.HasNextPage, - "hasPreviousPage": q.Repository.Discussion.Comments.PageInfo.HasPreviousPage, - "startCursor": string(q.Repository.Discussion.Comments.PageInfo.StartCursor), - "endCursor": string(q.Repository.Discussion.Comments.PageInfo.EndCursor), - }, - "totalCount": q.Repository.Discussion.Comments.TotalCount, - } + var comments []*github.IssueComment + for _, c := range q.Repository.Discussion.Comments.Nodes { + comments = append(comments, &github.IssueComment{Body: github.Ptr(string(c.Body))}) + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal comments: %w", err) - } + // Create response with pagination info + response := map[string]interface{}{ + "comments": comments, + "pageInfo": map[string]interface{}{ + "hasNextPage": q.Repository.Discussion.Comments.PageInfo.HasNextPage, + "hasPreviousPage": q.Repository.Discussion.Comments.PageInfo.HasPreviousPage, + "startCursor": string(q.Repository.Discussion.Comments.PageInfo.StartCursor), + "endCursor": string(q.Repository.Discussion.Comments.PageInfo.EndCursor), + }, + "totalCount": q.Repository.Discussion.Comments.TotalCount, + } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal comments: %w", err) } + + return utils.NewToolResultText(string(out)), nil, nil }, ) } @@ -534,79 +528,77 @@ func ListDiscussionCategories(t translations.TranslationHelperFunc) inventory.Se Required: []string{"owner"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // when not provided, default to the .github repository - // this will query discussion categories at the organisation level - if repo == "" { - repo = ".github" - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // when not provided, default to the .github repository + // this will query discussion categories at the organisation level + if repo == "" { + repo = ".github" + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var q struct { - Repository struct { - DiscussionCategories struct { - Nodes []struct { - ID githubv4.ID - Name githubv4.String - } - PageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - TotalCount int - } `graphql:"discussionCategories(first: $first)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "first": githubv4.Int(25), - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + var q struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } + TotalCount int + } `graphql:"discussionCategories(first: $first)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "first": githubv4.Int(25), + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var categories []map[string]string - for _, c := range q.Repository.DiscussionCategories.Nodes { - categories = append(categories, map[string]string{ - "id": fmt.Sprint(c.ID), - "name": string(c.Name), - }) - } + var categories []map[string]string + for _, c := range q.Repository.DiscussionCategories.Nodes { + categories = append(categories, map[string]string{ + "id": fmt.Sprint(c.ID), + "name": string(c.Name), + }) + } - // Create response with pagination info - response := map[string]interface{}{ - "categories": categories, - "pageInfo": map[string]interface{}{ - "hasNextPage": q.Repository.DiscussionCategories.PageInfo.HasNextPage, - "hasPreviousPage": q.Repository.DiscussionCategories.PageInfo.HasPreviousPage, - "startCursor": string(q.Repository.DiscussionCategories.PageInfo.StartCursor), - "endCursor": string(q.Repository.DiscussionCategories.PageInfo.EndCursor), - }, - "totalCount": q.Repository.DiscussionCategories.TotalCount, - } + // Create response with pagination info + response := map[string]interface{}{ + "categories": categories, + "pageInfo": map[string]interface{}{ + "hasNextPage": q.Repository.DiscussionCategories.PageInfo.HasNextPage, + "hasPreviousPage": q.Repository.DiscussionCategories.PageInfo.HasPreviousPage, + "startCursor": string(q.Repository.DiscussionCategories.PageInfo.StartCursor), + "endCursor": string(q.Repository.DiscussionCategories.PageInfo.EndCursor), + }, + "totalCount": q.Repository.DiscussionCategories.TotalCount, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussion categories: %w", err) - } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussion categories: %w", err) } + return utils.NewToolResultText(string(out)), nil, nil }, ) } diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 73ae66748..0ec998280 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -451,7 +451,7 @@ func Test_ListDiscussions(t *testing.T) { handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, err := handler(context.Background(), &req) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -564,7 +564,7 @@ func Test_GetDiscussion(t *testing.T) { reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} req := createMCPRequest(reqParams) - res, err := handler(context.Background(), &req) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -649,7 +649,7 @@ func Test_GetDiscussionComments(t *testing.T) { } request := createMCPRequest(reqParams) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -795,7 +795,7 @@ func Test_ListDiscussionCategories(t *testing.T) { handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, err := handler(context.Background(), &req) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 9ec615227..5c7d31d4e 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -27,7 +27,10 @@ type DynamicToolDependencies struct { } // NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. +// Dynamic tools use a different dependency structure (DynamicToolDependencies) than regular +// tools (ToolDependencies), so they intentionally use the closure pattern. func NewDynamicTool(toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) inventory.ServerTool { + //nolint:staticcheck // SA1019: Dynamic tools use a different deps structure, closure pattern is intentional return inventory.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { return handler(d.(DynamicToolDependencies)) }) diff --git a/pkg/github/gists.go b/pkg/github/gists.go index 511d7ea89..4d741b88d 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -41,65 +41,63 @@ func ListGists(t translations.TranslationHelperFunc) inventory.ServerTool { }, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.GistListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Parse since timestamp if provided - if since != "" { - sinceTime, err := parseISOTimestamp(since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil - } - opts.Since = sinceTime - } + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - client, err := deps.GetClient(ctx) + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil } + opts.Since = sinceTime + } - gists, resp, err := client.Gists.List(ctx, username, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list gists", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list gists", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list gists", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(gists) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list gists", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(gists) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -126,39 +124,37 @@ func GetGist(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"gist_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - gist, resp, err := client.Gists.Get(ctx, gistID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get gist", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get gist", resp, body), nil, nil - } + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(gist) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get gist", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -198,71 +194,69 @@ func CreateGist(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"filename", "content"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - public, err := OptionalParam[bool](args, "public") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } + public, err := OptionalParam[bool](args, "public") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gist := &github.Gist{ - Files: files, - Public: github.Ptr(public), - Description: github.Ptr(description), - } + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } - createdGist, resp, err := client.Gists.Create(ctx, gist) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create gist", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create gist", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - minimalResponse := MinimalResponse{ - ID: createdGist.GetID(), - URL: createdGist.GetHTMLURL(), - } + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalResponse) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create gist", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + minimalResponse := MinimalResponse{ + ID: createdGist.GetID(), + URL: createdGist.GetHTMLURL(), } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -301,70 +295,68 @@ func UpdateGist(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"gist_id", "filename", "content"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gist := &github.Gist{ - Files: files, - Description: github.Ptr(description), - } + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } - updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update gist", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update gist", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - minimalResponse := MinimalResponse{ - ID: updatedGist.GetID(), - URL: updatedGist.GetHTMLURL(), - } + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalResponse) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update gist", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + minimalResponse := MinimalResponse{ + ID: updatedGist.GetID(), + URL: updatedGist.GetHTMLURL(), } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index 7c6f69833..886db4a1a 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -167,7 +167,7 @@ func Test_ListGists(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) // Verify results @@ -284,7 +284,7 @@ func Test_GetGist(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) // Verify results @@ -430,7 +430,7 @@ func Test_CreateGist(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) // Verify results @@ -589,7 +589,7 @@ func Test_UpdateGist(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) // Verify results diff --git a/pkg/github/git.go b/pkg/github/git.go index 09d63cb9f..7b93c3675 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -76,102 +76,100 @@ func GetRepositoryTree(t translations.TranslationHelperFunc) inventory.ServerToo Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - treeSHA, err := OptionalParam[string](args, "tree_sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pathFilter, err := OptionalParam[string](args, "path_filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub client"), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + treeSHA, err := OptionalParam[string](args, "tree_sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pathFilter, err := OptionalParam[string](args, "path_filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // If no tree_sha is provided, use the repository's default branch - if treeSHA == "" { - repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository info", - repoResp, - err, - ), nil, nil - } - treeSHA = *repoInfo.DefaultBranch - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub client"), nil, nil + } - // Get the tree using the GitHub Git Tree API - tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) + // If no tree_sha is provided, use the repository's default branch + if treeSHA == "" { + repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository tree", - resp, + "failed to get repository info", + repoResp, err, ), nil, nil } - defer func() { _ = resp.Body.Close() }() + treeSHA = *repoInfo.DefaultBranch + } - // Filter tree entries if path_filter is provided - var filteredEntries []*github.TreeEntry - if pathFilter != "" { - for _, entry := range tree.Entries { - if strings.HasPrefix(entry.GetPath(), pathFilter) { - filteredEntries = append(filteredEntries, entry) - } - } - } else { - filteredEntries = tree.Entries - } + // Get the tree using the GitHub Git Tree API + tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get repository tree", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - treeEntries := make([]TreeEntryResponse, len(filteredEntries)) - for i, entry := range filteredEntries { - treeEntries[i] = TreeEntryResponse{ - Path: entry.GetPath(), - Type: entry.GetType(), - Mode: entry.GetMode(), - SHA: entry.GetSHA(), - URL: entry.GetURL(), - } - if entry.Size != nil { - treeEntries[i].Size = entry.Size + // Filter tree entries if path_filter is provided + var filteredEntries []*github.TreeEntry + if pathFilter != "" { + for _, entry := range tree.Entries { + if strings.HasPrefix(entry.GetPath(), pathFilter) { + filteredEntries = append(filteredEntries, entry) } } + } else { + filteredEntries = tree.Entries + } - response := TreeResponse{ - SHA: *tree.SHA, - Truncated: *tree.Truncated, - Tree: treeEntries, - TreeSHA: treeSHA, - Owner: owner, - Repo: repo, - Recursive: recursive, - Count: len(filteredEntries), + treeEntries := make([]TreeEntryResponse, len(filteredEntries)) + for i, entry := range filteredEntries { + treeEntries[i] = TreeEntryResponse{ + Path: entry.GetPath(), + Type: entry.GetType(), + Mode: entry.GetMode(), + SHA: entry.GetSHA(), + URL: entry.GetURL(), } - - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + if entry.Size != nil { + treeEntries[i].Size = entry.Size } + } + + response := TreeResponse{ + SHA: *tree.SHA, + Truncated: *tree.Truncated, + Tree: treeEntries, + TreeSHA: treeSHA, + Owner: owner, + Repo: repo, + Recursive: recursive, + Count: len(filteredEntries), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 9fb023f4b..d60aed092 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -134,7 +134,7 @@ func Test_GetRepositoryTree(t *testing.T) { // Create the tool request request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index d9bb2818b..f06dc2d9d 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -274,57 +274,55 @@ Options are: }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - gqlClient, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil + } - switch method { - case "get": - result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) - return result, nil, err - case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) - return result, nil, err - case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) - return result, nil, err - case "get_labels": - result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil - } + switch method { + case "get": + result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) + return result, nil, err + case "get_comments": + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + return result, nil, err + case "get_sub_issues": + result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + return result, nil, err + case "get_labels": + result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } }) } @@ -567,38 +565,36 @@ func ListIssueTypes(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - issueTypes, resp, err := client.Organizations.ListIssueTypes(ctx, owner) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to list issue types", err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list issue types", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + issueTypes, resp, err := client.Organizations.ListIssueTypes(ctx, owner) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to list issue types", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(issueTypes) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal issue types", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list issue types", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(issueTypes) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal issue types", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -636,54 +632,52 @@ func AddIssueComment(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "issue_number", "body"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - body, err := RequiredParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - comment := &github.IssueComment{ - Body: github.Ptr(body), - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + body, err := RequiredParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to create comment", err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + comment := &github.IssueComment{ + Body: github.Ptr(body), + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create comment", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to create comment", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(createdComment) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create comment", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(createdComment) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -742,62 +736,60 @@ Options are: Required: []string{"method", "owner", "repo", "issue_number", "sub_issue_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - subIssueID, err := RequiredInt(args, "sub_issue_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - replaceParent, err := OptionalParam[bool](args, "replace_parent") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - afterID, err := OptionalIntParam(args, "after_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - beforeID, err := OptionalIntParam(args, "before_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + subIssueID, err := RequiredInt(args, "sub_issue_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + replaceParent, err := OptionalParam[bool](args, "replace_parent") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + afterID, err := OptionalIntParam(args, "after_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + beforeID, err := OptionalIntParam(args, "before_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - switch strings.ToLower(method) { - case "add": - result, err := AddSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, replaceParent) - return result, nil, err - case "remove": - // Call the remove sub-issue function - result, err := RemoveSubIssue(ctx, client, owner, repo, issueNumber, subIssueID) - return result, nil, err - case "reprioritize": - // Call the reprioritize sub-issue function - result, err := ReprioritizeSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, afterID, beforeID) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil - } + switch strings.ToLower(method) { + case "add": + result, err := AddSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, replaceParent) + return result, nil, err + case "remove": + // Call the remove sub-issue function + result, err := RemoveSubIssue(ctx, client, owner, repo, issueNumber, subIssueID) + return result, nil, err + case "reprioritize": + // Call the reprioritize sub-issue function + result, err := ReprioritizeSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, afterID, beforeID) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } }) } @@ -971,11 +963,9 @@ func SearchIssues(t translations.TranslationHelperFunc) inventory.ServerTool { }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, deps.GetClient, args, "issue", "failed to search issues") - return result, nil, err - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "issue", "failed to search issues") + return result, nil, err }) } @@ -1062,104 +1052,102 @@ Options are: Required: []string{"method", "owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - title, err := OptionalParam[string](args, "title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + title, err := OptionalParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Optional parameters - body, err := OptionalParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Optional parameters + body, err := OptionalParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get assignees - assignees, err := OptionalStringArrayParam(args, "assignees") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get assignees + assignees, err := OptionalStringArrayParam(args, "assignees") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get labels - labels, err := OptionalStringArrayParam(args, "labels") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get labels + labels, err := OptionalStringArrayParam(args, "labels") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get optional milestone - milestone, err := OptionalIntParam(args, "milestone") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional milestone + milestone, err := OptionalIntParam(args, "milestone") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var milestoneNum int - if milestone != 0 { - milestoneNum = milestone - } + var milestoneNum int + if milestone != 0 { + milestoneNum = milestone + } - // Get optional type - issueType, err := OptionalParam[string](args, "type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional type + issueType, err := OptionalParam[string](args, "type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Handle state, state_reason and duplicateOf parameters - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Handle state, state_reason and duplicateOf parameters + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - stateReason, err := OptionalParam[string](args, "state_reason") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + stateReason, err := OptionalParam[string](args, "state_reason") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - duplicateOf, err := OptionalIntParam(args, "duplicate_of") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if duplicateOf != 0 && stateReason != "duplicate" { - return utils.NewToolResultError("duplicate_of can only be used when state_reason is 'duplicate'"), nil, nil - } + duplicateOf, err := OptionalIntParam(args, "duplicate_of") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if duplicateOf != 0 && stateReason != "duplicate" { + return utils.NewToolResultError("duplicate_of can only be used when state_reason is 'duplicate'"), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - gqlClient, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GraphQL client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GraphQL client", err), nil, nil + } - switch method { - case "create": - result, err := CreateIssue(ctx, client, owner, repo, title, body, assignees, labels, milestoneNum, issueType) - return result, nil, err - case "update": - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - result, err := UpdateIssue(ctx, client, gqlClient, owner, repo, issueNumber, title, body, assignees, labels, milestoneNum, issueType, state, stateReason, duplicateOf) - return result, nil, err - default: - return utils.NewToolResultError("invalid method, must be either 'create' or 'update'"), nil, nil + switch method { + case "create": + result, err := CreateIssue(ctx, client, owner, repo, title, body, assignees, labels, milestoneNum, issueType) + return result, nil, err + case "update": + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } + result, err := UpdateIssue(ctx, client, gqlClient, owner, repo, issueNumber, title, body, assignees, labels, milestoneNum, issueType, state, stateReason, duplicateOf) + return result, nil, err + default: + return utils.NewToolResultError("invalid method, must be either 'create' or 'update'"), nil, nil } }) } @@ -1393,187 +1381,185 @@ func ListIssues(t translations.TranslationHelperFunc) inventory.ServerTool { }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Set optional parameters if provided - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and filter by state - state = strings.ToUpper(state) - var states []githubv4.IssueState + // Set optional parameters if provided + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - switch state { - case "OPEN", "CLOSED": - states = []githubv4.IssueState{githubv4.IssueState(state)} - default: - states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed} - } + // Normalize and filter by state + state = strings.ToUpper(state) + var states []githubv4.IssueState - // Get labels - labels, err := OptionalStringArrayParam(args, "labels") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + switch state { + case "OPEN", "CLOSED": + states = []githubv4.IssueState{githubv4.IssueState(state)} + default: + states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed} + } - orderBy, err := OptionalParam[string](args, "orderBy") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get labels + labels, err := OptionalStringArrayParam(args, "labels") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + orderBy, err := OptionalParam[string](args, "orderBy") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and validate orderBy - orderBy = strings.ToUpper(orderBy) - switch orderBy { - case "CREATED_AT", "UPDATED_AT", "COMMENTS": - // Valid, keep as is - default: - orderBy = "CREATED_AT" - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and validate direction - direction = strings.ToUpper(direction) - switch direction { - case "ASC", "DESC": - // Valid, keep as is - default: - direction = "DESC" - } + // Normalize and validate orderBy + orderBy = strings.ToUpper(orderBy) + switch orderBy { + case "CREATED_AT", "UPDATED_AT", "COMMENTS": + // Valid, keep as is + default: + orderBy = "CREATED_AT" + } - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Normalize and validate direction + direction = strings.ToUpper(direction) + switch direction { + case "ASC", "DESC": + // Valid, keep as is + default: + direction = "DESC" + } - // There are two optional parameters: since and labels. - var sinceTime time.Time - var hasSince bool - if since != "" { - sinceTime, err = parseISOTimestamp(since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil - } - hasSince = true - } - hasLabels := len(labels) > 0 + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) + // There are two optional parameters: since and labels. + var sinceTime time.Time + var hasSince bool + if since != "" { + sinceTime, err = parseISOTimestamp(since) if err != nil { - return nil, nil, err + return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil } + hasSince = true + } + hasLabels := len(labels) > 0 - // Check if someone tried to use page-based pagination instead of cursor-based - if _, pageProvided := args["page"]; pageProvided { - return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil - } + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } - // Check if pagination parameters were explicitly provided - _, perPageProvided := args["perPage"] - paginationExplicit := perPageProvided + // Check if someone tried to use page-based pagination instead of cursor-based + if _, pageProvided := args["page"]; pageProvided { + return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil + } - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + // Check if pagination parameters were explicitly provided + _, perPageProvided := args["perPage"] + paginationExplicit := perPageProvided - // Use default of 30 if pagination was not explicitly provided - if !paginationExplicit { - defaultFirst := int32(DefaultGraphQLPageSize) - paginationParams.First = &defaultFirst - } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + // Use default of 30 if pagination was not explicitly provided + if !paginationExplicit { + defaultFirst := int32(DefaultGraphQLPageSize) + paginationParams.First = &defaultFirst + } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "states": states, - "orderBy": githubv4.IssueOrderField(orderBy), - "direction": githubv4.OrderDirection(direction), - "first": githubv4.Int(*paginationParams.First), - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - // Used within query, therefore must be set to nil and provided as $after - vars["after"] = (*githubv4.String)(nil) - } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "states": states, + "orderBy": githubv4.IssueOrderField(orderBy), + "direction": githubv4.OrderDirection(direction), + "first": githubv4.Int(*paginationParams.First), + } - // Ensure optional parameters are set - if hasLabels { - // Use query with labels filtering - convert string labels to githubv4.String slice - labelStrings := make([]githubv4.String, len(labels)) - for i, label := range labels { - labelStrings[i] = githubv4.String(label) - } - vars["labels"] = labelStrings - } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + // Used within query, therefore must be set to nil and provided as $after + vars["after"] = (*githubv4.String)(nil) + } - if hasSince { - vars["since"] = githubv4.DateTime{Time: sinceTime} + // Ensure optional parameters are set + if hasLabels { + // Use query with labels filtering - convert string labels to githubv4.String slice + labelStrings := make([]githubv4.String, len(labels)) + for i, label := range labels { + labelStrings[i] = githubv4.String(label) } + vars["labels"] = labelStrings + } - issueQuery := getIssueQueryType(hasLabels, hasSince) - if err := client.Query(ctx, issueQuery, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + if hasSince { + vars["since"] = githubv4.DateTime{Time: sinceTime} + } - // Extract and convert all issue nodes using the common interface - var issues []*github.Issue - var pageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - var totalCount int + issueQuery := getIssueQueryType(hasLabels, hasSince) + if err := client.Query(ctx, issueQuery, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryResult, ok := issueQuery.(IssueQueryResult); ok { - fragment := queryResult.GetIssueFragment() - for _, issue := range fragment.Nodes { - issues = append(issues, fragmentToIssue(issue)) - } - pageInfo = fragment.PageInfo - totalCount = fragment.TotalCount - } + // Extract and convert all issue nodes using the common interface + var issues []*github.Issue + var pageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } + var totalCount int - // Create response with issues - response := map[string]interface{}{ - "issues": issues, - "pageInfo": map[string]interface{}{ - "hasNextPage": pageInfo.HasNextPage, - "hasPreviousPage": pageInfo.HasPreviousPage, - "startCursor": string(pageInfo.StartCursor), - "endCursor": string(pageInfo.EndCursor), - }, - "totalCount": totalCount, - } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal issues: %w", err) + if queryResult, ok := issueQuery.(IssueQueryResult); ok { + fragment := queryResult.GetIssueFragment() + for _, issue := range fragment.Nodes { + issues = append(issues, fragmentToIssue(issue)) } - return utils.NewToolResultText(string(out)), nil, nil + pageInfo = fragment.PageInfo + totalCount = fragment.TotalCount } + + // Create response with issues + response := map[string]interface{}{ + "issues": issues, + "pageInfo": map[string]interface{}{ + "hasNextPage": pageInfo.HasNextPage, + "hasPreviousPage": pageInfo.HasPreviousPage, + "startCursor": string(pageInfo.StartCursor), + "endCursor": string(pageInfo.EndCursor), + }, + "totalCount": totalCount, + } + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal issues: %w", err) + } + return utils.NewToolResultText(string(out)), nil, nil }) } @@ -1648,133 +1634,131 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server Required: []string{"owner", "repo", "issueNumber"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params struct { - Owner string - Repo string - IssueNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params struct { + Owner string + Repo string + IssueNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Firstly, we try to find the copilot bot in the suggested actors for the repository. - // Although as I write this, we would expect copilot to be at the top of the list, in future, maybe - // it will not be on the first page of responses, thus we will keep paginating until we find it. - type botAssignee struct { - ID githubv4.ID - Login string - TypeName string `graphql:"__typename"` - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - type suggestedActorsQuery struct { - Repository struct { - SuggestedActors struct { - Nodes []struct { - Bot botAssignee `graphql:"... on Bot"` - } - PageInfo struct { - HasNextPage bool - EndCursor string - } - } `graphql:"suggestedActors(first: 100, after: $endCursor, capabilities: CAN_BE_ASSIGNED)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + // Firstly, we try to find the copilot bot in the suggested actors for the repository. + // Although as I write this, we would expect copilot to be at the top of the list, in future, maybe + // it will not be on the first page of responses, thus we will keep paginating until we find it. + type botAssignee struct { + ID githubv4.ID + Login string + TypeName string `graphql:"__typename"` + } - variables := map[string]any{ - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "endCursor": (*githubv4.String)(nil), - } + type suggestedActorsQuery struct { + Repository struct { + SuggestedActors struct { + Nodes []struct { + Bot botAssignee `graphql:"... on Bot"` + } + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"suggestedActors(first: 100, after: $endCursor, capabilities: CAN_BE_ASSIGNED)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - var copilotAssignee *botAssignee - for { - var query suggestedActorsQuery - err := client.Query(ctx, &query, variables) - if err != nil { - return nil, nil, err - } + variables := map[string]any{ + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "endCursor": (*githubv4.String)(nil), + } - // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the - // same name on each host. We need this in order to get the ID for later assignment. - for _, node := range query.Repository.SuggestedActors.Nodes { - if node.Bot.Login == "copilot-swe-agent" { - copilotAssignee = &node.Bot - break - } - } + var copilotAssignee *botAssignee + for { + var query suggestedActorsQuery + err := client.Query(ctx, &query, variables) + if err != nil { + return nil, nil, err + } - if !query.Repository.SuggestedActors.PageInfo.HasNextPage { + // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the + // same name on each host. We need this in order to get the ID for later assignment. + for _, node := range query.Repository.SuggestedActors.Nodes { + if node.Bot.Login == "copilot-swe-agent" { + copilotAssignee = &node.Bot break } - variables["endCursor"] = githubv4.String(query.Repository.SuggestedActors.PageInfo.EndCursor) } - // If we didn't find the copilot bot, we can't proceed any further. - if copilotAssignee == nil { - // The e2e tests depend upon this specific message to skip the test. - return utils.NewToolResultError("copilot isn't available as an assignee for this issue. Please inform the user to visit https://docs.github.com/en/copilot/using-github-copilot/using-copilot-coding-agent-to-work-on-tasks/about-assigning-tasks-to-copilot for more information."), nil, nil + if !query.Repository.SuggestedActors.PageInfo.HasNextPage { + break } + variables["endCursor"] = githubv4.String(query.Repository.SuggestedActors.PageInfo.EndCursor) + } - // Next let's get the GQL Node ID and current assignees for this issue because the only way to - // assign copilot is to use replaceActorsForAssignable which requires the full list. - var getIssueQuery struct { - Repository struct { - Issue struct { - ID githubv4.ID - Assignees struct { - Nodes []struct { - ID githubv4.ID - } - } `graphql:"assignees(first: 100)"` - } `graphql:"issue(number: $number)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + // If we didn't find the copilot bot, we can't proceed any further. + if copilotAssignee == nil { + // The e2e tests depend upon this specific message to skip the test. + return utils.NewToolResultError("copilot isn't available as an assignee for this issue. Please inform the user to visit https://docs.github.com/en/copilot/using-github-copilot/using-copilot-coding-agent-to-work-on-tasks/about-assigning-tasks-to-copilot for more information."), nil, nil + } - variables = map[string]any{ - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "number": githubv4.Int(params.IssueNumber), - } + // Next let's get the GQL Node ID and current assignees for this issue because the only way to + // assign copilot is to use replaceActorsForAssignable which requires the full list. + var getIssueQuery struct { + Repository struct { + Issue struct { + ID githubv4.ID + Assignees struct { + Nodes []struct { + ID githubv4.ID + } + } `graphql:"assignees(first: 100)"` + } `graphql:"issue(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - if err := client.Query(ctx, &getIssueQuery, variables); err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil, nil - } + variables = map[string]any{ + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "number": githubv4.Int(params.IssueNumber), + } - // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already - // assigned to seems to have no impact (which is a good thing). - var assignCopilotMutation struct { - ReplaceActorsForAssignable struct { - Typename string `graphql:"__typename"` // Not required but we need a selector or GQL errors - } `graphql:"replaceActorsForAssignable(input: $input)"` - } + if err := client.Query(ctx, &getIssueQuery, variables); err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil, nil + } - actorIDs := make([]githubv4.ID, len(getIssueQuery.Repository.Issue.Assignees.Nodes)+1) - for i, node := range getIssueQuery.Repository.Issue.Assignees.Nodes { - actorIDs[i] = node.ID - } - actorIDs[len(getIssueQuery.Repository.Issue.Assignees.Nodes)] = copilotAssignee.ID - - if err := client.Mutate( - ctx, - &assignCopilotMutation, - ReplaceActorsForAssignableInput{ - AssignableID: getIssueQuery.Repository.Issue.ID, - ActorIDs: actorIDs, - }, - nil, - ); err != nil { - return nil, nil, fmt.Errorf("failed to replace actors for assignable: %w", err) - } + // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already + // assigned to seems to have no impact (which is a good thing). + var assignCopilotMutation struct { + ReplaceActorsForAssignable struct { + Typename string `graphql:"__typename"` // Not required but we need a selector or GQL errors + } `graphql:"replaceActorsForAssignable(input: $input)"` + } - return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil + actorIDs := make([]githubv4.ID, len(getIssueQuery.Repository.Issue.Assignees.Nodes)+1) + for i, node := range getIssueQuery.Repository.Issue.Assignees.Nodes { + actorIDs[i] = node.ID } + actorIDs[len(getIssueQuery.Repository.Issue.Assignees.Nodes)] = copilotAssignee.ID + + if err := client.Mutate( + ctx, + &assignCopilotMutation, + ReplaceActorsForAssignableInput{ + AssignableID: getIssueQuery.Repository.Issue.ID, + ActorIDs: actorIDs, + }, + nil, + ); err != nil { + return nil, nil, fmt.Errorf("failed to replace actors for assignable: %w", err) + } + + return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil }) } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 4c686cc57..b810cede3 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -339,7 +339,7 @@ func Test_GetIssue(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectHandlerError { require.Error(t, err) @@ -456,7 +456,7 @@ func Test_AddIssueComment(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -790,7 +790,7 @@ func Test_SearchIssues(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -962,7 +962,7 @@ func Test_CreateIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1274,7 +1274,7 @@ func Test_ListIssues(t *testing.T) { handler := serverTool.Handler(deps) req := createMCPRequest(tc.reqParams) - res, err := handler(context.Background(), &req) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -1779,7 +1779,7 @@ func Test_UpdateIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -2028,7 +2028,7 @@ func Test_GetIssueComments(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2145,7 +2145,7 @@ func Test_GetIssueLabels(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -2569,7 +2569,7 @@ func TestAssignCopilotToIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2800,7 +2800,7 @@ func Test_AddSubIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3047,7 +3047,7 @@ func Test_GetSubIssues(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3284,7 +3284,7 @@ func Test_RemoveSubIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3573,7 +3573,7 @@ func Test_ReprioritizeSubIssue(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3707,7 +3707,7 @@ func Test_ListIssueTypes(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/labels.go b/pkg/github/labels.go index a56956f6c..2811cf66e 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -45,67 +45,65 @@ func GetLabel(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "name"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var query struct { - Repository struct { - Label struct { - ID githubv4.ID - Name githubv4.String - Color githubv4.String - Description githubv4.String - } `graphql:"label(name: $name)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "name": githubv4.String(name), - } + var query struct { + Repository struct { + Label struct { + ID githubv4.ID + Name githubv4.String + Color githubv4.String + Description githubv4.String + } `graphql:"label(name: $name)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "name": githubv4.String(name), + } - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if query.Repository.Label.Name == "" { - return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil - } + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil + } - label := map[string]any{ - "id": fmt.Sprintf("%v", query.Repository.Label.ID), - "name": string(query.Repository.Label.Name), - "color": string(query.Repository.Label.Color), - "description": string(query.Repository.Label.Description), - } + if query.Repository.Label.Name == "" { + return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil + } - out, err := json.Marshal(label) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal label: %w", err) - } + label := map[string]any{ + "id": fmt.Sprintf("%v", query.Repository.Label.ID), + "name": string(query.Repository.Label.Name), + "color": string(query.Repository.Label.Color), + "description": string(query.Repository.Label.Description), + } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(label) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal label: %w", err) } + + return utils.NewToolResultText(string(out)), nil, nil }, ) } @@ -144,68 +142,66 @@ func ListLabels(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var query struct { - Repository struct { - Labels struct { - Nodes []struct { - ID githubv4.ID - Name githubv4.String - Color githubv4.String - Description githubv4.String - } - TotalCount githubv4.Int - } `graphql:"labels(first: 100)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - } + var query struct { + Repository struct { + Labels struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + Color githubv4.String + Description githubv4.String + } + TotalCount githubv4.Int + } `graphql:"labels(first: 100)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil - } + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + } - labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) - for i, labelNode := range query.Repository.Labels.Nodes { - labels[i] = map[string]any{ - "id": fmt.Sprintf("%v", labelNode.ID), - "name": string(labelNode.Name), - "color": string(labelNode.Color), - "description": string(labelNode.Description), - } - } + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil + } - response := map[string]any{ - "labels": labels, - "totalCount": int(query.Repository.Labels.TotalCount), + labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) + for i, labelNode := range query.Repository.Labels.Nodes { + labels[i] = map[string]any{ + "id": fmt.Sprintf("%v", labelNode.ID), + "name": string(labelNode.Name), + "color": string(labelNode.Color), + "description": string(labelNode.Description), } + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) - } + response := map[string]any{ + "labels": labels, + "totalCount": int(query.Repository.Labels.TotalCount), + } - return utils.NewToolResultText(string(out)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) } + + return utils.NewToolResultText(string(out)), nil, nil }, ) } @@ -257,147 +253,145 @@ func LabelWrite(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"method", "owner", "repo", "name"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Get and validate required parameters - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Get and validate required parameters + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + method = strings.ToLower(method) + + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + // Get optional parameters + newName, _ := OptionalParam[string](args, "new_name") + color, _ := OptionalParam[string](args, "color") + description, _ := OptionalParam[string](args, "description") + + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + switch method { + case "create": + // Validate required params for create + if color == "" { + return utils.NewToolResultError("color is required for create"), nil, nil } - method = strings.ToLower(method) - owner, err := RequiredParam[string](args, "owner") + // Get repository ID + repoID, err := getRepositoryID(ctx, client, owner, repo) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil } - repo, err := RequiredParam[string](args, "repo") + input := githubv4.CreateLabelInput{ + RepositoryID: repoID, + Name: githubv4.String(name), + Color: githubv4.String(color), + } + if description != "" { + d := githubv4.String(description) + input.Description = &d + } + + var mutation struct { + CreateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"createLabel(input: $input)"` + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil + + case "update": + // Validate required params for update + if newName == "" && color == "" && description == "" { + return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil + } + + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - name, err := RequiredParam[string](args, "name") + input := githubv4.UpdateLabelInput{ + ID: labelID, + } + if newName != "" { + n := githubv4.String(newName) + input.Name = &n + } + if color != "" { + c := githubv4.String(color) + input.Color = &c + } + if description != "" { + d := githubv4.String(description) + input.Description = &d + } + + var mutation struct { + UpdateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"updateLabel(input: $input)"` + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil + + case "delete": + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - // Get optional parameters - newName, _ := OptionalParam[string](args, "new_name") - color, _ := OptionalParam[string](args, "color") - description, _ := OptionalParam[string](args, "description") + input := githubv4.DeleteLabelInput{ + ID: labelID, + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + var mutation struct { + DeleteLabel struct { + ClientMutationID githubv4.String + } `graphql:"deleteLabel(input: $input)"` } - switch method { - case "create": - // Validate required params for create - if color == "" { - return utils.NewToolResultError("color is required for create"), nil, nil - } - - // Get repository ID - repoID, err := getRepositoryID(ctx, client, owner, repo) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil - } - - input := githubv4.CreateLabelInput{ - RepositoryID: repoID, - Name: githubv4.String(name), - Color: githubv4.String(color), - } - if description != "" { - d := githubv4.String(description) - input.Description = &d - } - - var mutation struct { - CreateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID - } - } `graphql:"createLabel(input: $input)"` - } - - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil - } - - return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil - - case "update": - // Validate required params for update - if newName == "" && color == "" && description == "" { - return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil - } - - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - input := githubv4.UpdateLabelInput{ - ID: labelID, - } - if newName != "" { - n := githubv4.String(newName) - input.Name = &n - } - if color != "" { - c := githubv4.String(color) - input.Color = &c - } - if description != "" { - d := githubv4.String(description) - input.Description = &d - } - - var mutation struct { - UpdateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID - } - } `graphql:"updateLabel(input: $input)"` - } - - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil - } - - return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil - - case "delete": - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - input := githubv4.DeleteLabelInput{ - ID: labelID, - } - - var mutation struct { - DeleteLabel struct { - ClientMutationID githubv4.String - } `graphql:"deleteLabel(input: $input)"` - } - - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil - } - - return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil - - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil + + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil } }, ) diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go index fa646e884..88102ba3c 100644 --- a/pkg/github/labels_test.go +++ b/pkg/github/labels_test.go @@ -120,7 +120,7 @@ func TestGetLabel(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -218,7 +218,7 @@ func TestListLabels(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -469,7 +469,7 @@ func TestWriteLabel(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index c6e18529f..1e2011fa3 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -62,102 +62,100 @@ func ListNotifications(t translations.TranslationHelperFunc) inventory.ServerToo }, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - filter, err := OptionalParam[string](args, "filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + filter, err := OptionalParam[string](args, "filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - before, err := OptionalParam[string](args, "before") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := OptionalParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + before, err := OptionalParam[string](args, "before") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - paginationParams, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := OptionalParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Build options - opts := &github.NotificationListOptions{ - All: filter == FilterIncludeRead, - Participating: filter == FilterOnlyParticipating, - ListOptions: github.ListOptions{ - Page: paginationParams.Page, - PerPage: paginationParams.PerPage, - }, - } + paginationParams, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Parse time parameters if provided - if since != "" { - sinceTime, err := time.Parse(time.RFC3339, since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil, nil - } - opts.Since = sinceTime - } + // Build options + opts := &github.NotificationListOptions{ + All: filter == FilterIncludeRead, + Participating: filter == FilterOnlyParticipating, + ListOptions: github.ListOptions{ + Page: paginationParams.Page, + PerPage: paginationParams.PerPage, + }, + } - if before != "" { - beforeTime, err := time.Parse(time.RFC3339, before) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil, nil - } - opts.Before = beforeTime + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil, nil } + opts.Since = sinceTime + } - var notifications []*github.Notification - var resp *github.Response - - if owner != "" && repo != "" { - notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) - } else { - notifications, resp, err = client.Activity.ListNotifications(ctx, opts) - } + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list notifications", - resp, - err, - ), nil, nil + return utils.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil, nil } - defer func() { _ = resp.Body.Close() }() + opts.Before = beforeTime + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notifications", resp, body), nil, nil - } + var notifications []*github.Notification + var resp *github.Response - // Marshal response to JSON - r, err := json.Marshal(notifications) + if owner != "" && repo != "" { + notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) + } else { + notifications, resp, err = client.Activity.ListNotifications(ctx, opts) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list notifications", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notifications", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Marshal response to JSON + r, err := json.Marshal(notifications) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -189,58 +187,56 @@ func DismissNotification(t translations.TranslationHelperFunc) inventory.ServerT Required: []string{"threadID", "state"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - threadID, err := RequiredParam[string](args, "threadID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + threadID, err := RequiredParam[string](args, "threadID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + state, err := RequiredParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - state, err := RequiredParam[string](args, "state") + var resp *github.Response + switch state { + case "done": + // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint + var threadIDInt int64 + threadIDInt, err = strconv.ParseInt(threadID, 10, 64) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return utils.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil, nil } + resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) + case "read": + resp, err = client.Activity.MarkThreadRead(ctx, threadID) + default: + return utils.NewToolResultError("Invalid state. Must be one of: read, done."), nil, nil + } - var resp *github.Response - switch state { - case "done": - // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint - var threadIDInt int64 - threadIDInt, err = strconv.ParseInt(threadID, 10, 64) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil, nil - } - resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) - case "read": - resp, err = client.Activity.MarkThreadRead(ctx, threadID) - default: - return utils.NewToolResultError("Invalid state. Must be one of: read, done."), nil, nil - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to mark notification as %s", state), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to mark notification as %s", state), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to mark notification as %s", state), resp, body), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - - return utils.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to mark notification as %s", state), resp, body), nil, nil } + + return utils.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil, nil }, ) } @@ -274,66 +270,64 @@ func MarkAllNotificationsRead(t translations.TranslationHelperFunc) inventory.Se }, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - lastReadAt, err := OptionalParam[string](args, "lastReadAt") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + lastReadAt, err := OptionalParam[string](args, "lastReadAt") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := OptionalParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") + owner, err := OptionalParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + var lastReadTime time.Time + if lastReadAt != "" { + lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return utils.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil, nil } + } else { + lastReadTime = time.Now() + } - var lastReadTime time.Time - if lastReadAt != "" { - lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil, nil - } - } else { - lastReadTime = time.Now() - } + markReadOptions := github.Timestamp{ + Time: lastReadTime, + } - markReadOptions := github.Timestamp{ - Time: lastReadTime, - } + var resp *github.Response + if owner != "" && repo != "" { + resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) + } else { + resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to mark all notifications as read", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - var resp *github.Response - if owner != "" && repo != "" { - resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) - } else { - resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) - } + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to mark all notifications as read", - resp, - err, - ), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to mark all notifications as read", resp, body), nil, nil - } - - return utils.NewToolResultText("All notifications marked as read"), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to mark all notifications as read", resp, body), nil, nil } + + return utils.NewToolResultText("All notifications marked as read"), nil, nil }, ) } @@ -360,43 +354,41 @@ func GetNotificationDetails(t translations.TranslationHelperFunc) inventory.Serv Required: []string{"notificationID"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - notificationID, err := RequiredParam[string](args, "notificationID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - thread, resp, err := client.Activity.GetThread(ctx, notificationID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + notificationID, err := RequiredParam[string](args, "notificationID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notification details", resp, body), nil, nil - } + thread, resp, err := client.Activity.GetThread(ctx, notificationID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(thread) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notification details", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(thread) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -435,66 +427,64 @@ func ManageNotificationSubscription(t translations.TranslationHelperFunc) invent Required: []string{"notificationID", "action"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - notificationID, err := RequiredParam[string](args, "notificationID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - action, err := RequiredParam[string](args, "action") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + notificationID, err := RequiredParam[string](args, "notificationID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + action, err := RequiredParam[string](args, "action") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var ( - resp *github.Response - result any - apiErr error - ) - - switch action { - case NotificationActionIgnore: - sub := &github.Subscription{Ignored: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) - case NotificationActionWatch: - sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) - case NotificationActionDelete: - resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) - default: - return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil - } + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case NotificationActionIgnore: + sub := &github.Subscription{Ignored: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionWatch: + sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionDelete: + resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) + default: + return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil + } - if apiErr != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to %s notification subscription", action), - resp, - apiErr, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if apiErr != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s notification subscription", action), + resp, + apiErr, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s notification subscription", action), resp, body), nil, nil - } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s notification subscription", action), resp, body), nil, nil + } - if action == NotificationActionDelete { - // Special case for delete as there is no response body - return utils.NewToolResultText("Notification subscription deleted"), nil, nil - } + if action == NotificationActionDelete { + // Special case for delete as there is no response body + return utils.NewToolResultText("Notification subscription deleted"), nil, nil + } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil - } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -536,73 +526,71 @@ func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFu Required: []string{"owner", "repo", "action"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - action, err := RequiredParam[string](args, "action") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + action, err := RequiredParam[string](args, "action") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var ( - resp *github.Response - result any - apiErr error - ) - - switch action { - case RepositorySubscriptionActionIgnore: - sub := &github.Subscription{Ignored: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) - case RepositorySubscriptionActionWatch: - sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) - case RepositorySubscriptionActionDelete: - resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) - default: - return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil - } + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case RepositorySubscriptionActionIgnore: + sub := &github.Subscription{Ignored: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionWatch: + sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionDelete: + resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) + default: + return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil + } - if apiErr != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to %s repository subscription", action), - resp, - apiErr, - ), nil, nil - } - if resp != nil { - defer func() { _ = resp.Body.Close() }() - } + if apiErr != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s repository subscription", action), + resp, + apiErr, + ), nil, nil + } + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } - // Handle non-2xx status codes - if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { - body, _ := io.ReadAll(resp.Body) - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s repository subscription", action), resp, body), nil, nil - } + // Handle non-2xx status codes + if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + body, _ := io.ReadAll(resp.Body) + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s repository subscription", action), resp, body), nil, nil + } - if action == RepositorySubscriptionActionDelete { - // Special case for delete as there is no response body - return utils.NewToolResultText("Repository subscription deleted"), nil, nil - } + if action == RepositorySubscriptionActionDelete { + // Special case for delete as there is no response body + return utils.NewToolResultText("Repository subscription deleted"), nil, nil + } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil - } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index f730654db..1b12c911f 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -130,7 +130,7 @@ func Test_ListNotifications(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -263,7 +263,7 @@ func Test_ManageNotificationSubscription(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -426,7 +426,7 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -568,7 +568,7 @@ func Test_DismissNotification(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -693,7 +693,7 @@ func Test_MarkAllNotificationsRead(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -777,7 +777,7 @@ func Test_GetNotificationDetails(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { diff --git a/pkg/github/projects.go b/pkg/github/projects.go index d33ac5780..18c1f778b 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -67,79 +67,77 @@ func ListProjects(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner_type", "owner"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - queryStr, err := OptionalParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + queryStr, err := OptionalParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projects []*github.ProjectV2 - var queryPtr *string + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryStr != "" { - queryPtr = &queryStr - } + var resp *github.Response + var projects []*github.ProjectV2 + var queryPtr *string - minimalProjects := []MinimalProject{} - opts := &github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - Query: queryPtr, - } + if queryStr != "" { + queryPtr = &queryStr + } - if ownerType == "org" { - projects, resp, err = client.Projects.ListOrganizationProjects(ctx, owner, opts) - } else { - projects, resp, err = client.Projects.ListUserProjects(ctx, owner, opts) - } + minimalProjects := []MinimalProject{} + opts := &github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + Query: queryPtr, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list projects", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projects, resp, err = client.Projects.ListOrganizationProjects(ctx, owner, opts) + } else { + projects, resp, err = client.Projects.ListUserProjects(ctx, owner, opts) + } - for _, project := range projects { - minimalProjects = append(minimalProjects, *convertToMinimalProject(project)) - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list projects", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - response := map[string]any{ - "projects": minimalProjects, - "pageInfo": buildPageInfo(resp), - } + for _, project := range projects { + minimalProjects = append(minimalProjects, *convertToMinimalProject(project)) + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "projects": minimalProjects, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -174,62 +172,60 @@ func GetProject(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"project_number", "owner_type", "owner"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var project *github.ProjectV2 + var resp *github.Response + var project *github.ProjectV2 - if ownerType == "org" { - project, resp, err = client.Projects.GetOrganizationProject(ctx, owner, projectNumber) - } else { - project, resp, err = client.Projects.GetUserProject(ctx, owner, projectNumber) - } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project", resp, body), nil, nil - } + if ownerType == "org" { + project, resp, err = client.Projects.GetOrganizationProject(ctx, owner, projectNumber) + } else { + project, resp, err = client.Projects.GetUserProject(ctx, owner, projectNumber) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - minimalProject := convertToMinimalProject(project) - r, err := json.Marshal(minimalProject) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + minimalProject := convertToMinimalProject(project) + r, err := json.Marshal(minimalProject) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -276,68 +272,66 @@ func ListProjectFields(t translations.TranslationHelperFunc) inventory.ServerToo Required: []string{"owner_type", "owner", "project_number"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectFields []*github.ProjectV2Field + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - } + var resp *github.Response + var projectFields []*github.ProjectV2Field - if ownerType == "org" { - projectFields, resp, err = client.Projects.ListOrganizationProjectFields(ctx, owner, projectNumber, opts) - } else { - projectFields, resp, err = client.Projects.ListUserProjectFields(ctx, owner, projectNumber, opts) - } + opts := &github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list project fields", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectFields, resp, err = client.Projects.ListOrganizationProjectFields(ctx, owner, projectNumber, opts) + } else { + projectFields, resp, err = client.Projects.ListUserProjectFields(ctx, owner, projectNumber, opts) + } - response := map[string]any{ - "fields": projectFields, - "pageInfo": buildPageInfo(resp), - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list project fields", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "fields": projectFields, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -376,62 +370,60 @@ func GetProjectField(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner_type", "owner", "project_number", "field_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fieldID, err := RequiredBigInt(args, "field_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fieldID, err := RequiredBigInt(args, "field_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectField *github.ProjectV2Field + var resp *github.Response + var projectField *github.ProjectV2Field - if ownerType == "org" { - projectField, resp, err = client.Projects.GetOrganizationProjectField(ctx, owner, projectNumber, fieldID) - } else { - projectField, resp, err = client.Projects.GetUserProjectField(ctx, owner, projectNumber, fieldID) - } + if ownerType == "org" { + projectField, resp, err = client.Projects.GetOrganizationProjectField(ctx, owner, projectNumber, fieldID) + } else { + projectField, resp, err = client.Projects.GetUserProjectField(ctx, owner, projectNumber, fieldID) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project field", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project field", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project field", resp, body), nil, nil - } - r, err := json.Marshal(projectField) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - - return utils.NewToolResultText(string(r)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project field", resp, body), nil, nil + } + r, err := json.Marshal(projectField) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -489,87 +481,85 @@ func ListProjectItems(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner_type", "owner", "project_number"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - queryStr, err := OptionalParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - fields, err := OptionalBigIntArrayParam(args, "fields") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + queryStr, err := OptionalParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + fields, err := OptionalBigIntArrayParam(args, "fields") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectItems []*github.ProjectV2Item - var queryPtr *string + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryStr != "" { - queryPtr = &queryStr - } + var resp *github.Response + var projectItems []*github.ProjectV2Item + var queryPtr *string - opts := &github.ListProjectItemsOptions{ - Fields: fields, - ListProjectsOptions: github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - Query: queryPtr, - }, - } + if queryStr != "" { + queryPtr = &queryStr + } - if ownerType == "org" { - projectItems, resp, err = client.Projects.ListOrganizationProjectItems(ctx, owner, projectNumber, opts) - } else { - projectItems, resp, err = client.Projects.ListUserProjectItems(ctx, owner, projectNumber, opts) - } + opts := &github.ListProjectItemsOptions{ + Fields: fields, + ListProjectsOptions: github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + Query: queryPtr, + }, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectListFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectItems, resp, err = client.Projects.ListOrganizationProjectItems(ctx, owner, projectNumber, opts) + } else { + projectItems, resp, err = client.Projects.ListUserProjectItems(ctx, owner, projectNumber, opts) + } - response := map[string]any{ - "items": projectItems, - "pageInfo": buildPageInfo(resp), - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectListFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "items": projectItems, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -615,69 +605,67 @@ func GetProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fields, err := OptionalBigIntArrayParam(args, "fields") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fields, err := OptionalBigIntArrayParam(args, "fields") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectItem *github.ProjectV2Item - var opts *github.GetProjectItemOptions + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if len(fields) > 0 { - opts = &github.GetProjectItemOptions{ - Fields: fields, - } - } + var resp *github.Response + var projectItem *github.ProjectV2Item + var opts *github.GetProjectItemOptions - if ownerType == "org" { - projectItem, resp, err = client.Projects.GetOrganizationProjectItem(ctx, owner, projectNumber, itemID, opts) - } else { - projectItem, resp, err = client.Projects.GetUserProjectItem(ctx, owner, projectNumber, itemID, opts) + if len(fields) > 0 { + opts = &github.GetProjectItemOptions{ + Fields: fields, } + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project item", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectItem, resp, err = client.Projects.GetOrganizationProjectItem(ctx, owner, projectNumber, itemID, opts) + } else { + projectItem, resp, err = client.Projects.GetUserProjectItem(ctx, owner, projectNumber, itemID, opts) + } - r, err := json.Marshal(projectItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project item", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(projectItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -721,76 +709,74 @@ func AddProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner_type", "owner", "project_number", "item_type", "item_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - itemType, err := RequiredParam[string](args, "item_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if itemType != "issue" && itemType != "pull_request" { - return utils.NewToolResultError("item_type must be either 'issue' or 'pull_request'"), nil, nil - } + itemType, err := RequiredParam[string](args, "item_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if itemType != "issue" && itemType != "pull_request" { + return utils.NewToolResultError("item_type must be either 'issue' or 'pull_request'"), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - newItem := &github.AddProjectItemOptions{ - ID: itemID, - Type: toNewProjectType(itemType), - } + newItem := &github.AddProjectItemOptions{ + ID: itemID, + Type: toNewProjectType(itemType), + } - var resp *github.Response - var addedItem *github.ProjectV2Item + var resp *github.Response + var addedItem *github.ProjectV2Item - if ownerType == "org" { - addedItem, resp, err = client.Projects.AddOrganizationProjectItem(ctx, owner, projectNumber, newItem) - } else { - addedItem, resp, err = client.Projects.AddUserProjectItem(ctx, owner, projectNumber, newItem) - } + if ownerType == "org" { + addedItem, resp, err = client.Projects.AddOrganizationProjectItem(ctx, owner, projectNumber, newItem) + } else { + addedItem, resp, err = client.Projects.AddUserProjectItem(ctx, owner, projectNumber, newItem) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectAddFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectAddFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectAddFailedError, resp, body), nil, nil - } - r, err := json.Marshal(addedItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - - return utils.NewToolResultText(string(r)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectAddFailedError, resp, body), nil, nil } + r, err := json.Marshal(addedItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -833,78 +819,76 @@ func UpdateProjectItem(t translations.TranslationHelperFunc) inventory.ServerToo Required: []string{"owner_type", "owner", "project_number", "item_id", "updated_field"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - rawUpdatedField, exists := args["updated_field"] - if !exists { - return utils.NewToolResultError("missing required parameter: updated_field"), nil, nil - } + rawUpdatedField, exists := args["updated_field"] + if !exists { + return utils.NewToolResultError("missing required parameter: updated_field"), nil, nil + } - fieldValue, ok := rawUpdatedField.(map[string]any) - if !ok || fieldValue == nil { - return utils.NewToolResultError("field_value must be an object"), nil, nil - } + fieldValue, ok := rawUpdatedField.(map[string]any) + if !ok || fieldValue == nil { + return utils.NewToolResultError("field_value must be an object"), nil, nil + } - updatePayload, err := buildUpdateProjectItem(fieldValue) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + updatePayload, err := buildUpdateProjectItem(fieldValue) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var updatedItem *github.ProjectV2Item + var resp *github.Response + var updatedItem *github.ProjectV2Item - if ownerType == "org" { - updatedItem, resp, err = client.Projects.UpdateOrganizationProjectItem(ctx, owner, projectNumber, itemID, updatePayload) - } else { - updatedItem, resp, err = client.Projects.UpdateUserProjectItem(ctx, owner, projectNumber, itemID, updatePayload) - } + if ownerType == "org" { + updatedItem, resp, err = client.Projects.UpdateOrganizationProjectItem(ctx, owner, projectNumber, itemID, updatePayload) + } else { + updatedItem, resp, err = client.Projects.UpdateUserProjectItem(ctx, owner, projectNumber, itemID, updatePayload) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectUpdateFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectUpdateFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectUpdateFailedError, resp, body), nil, nil - } - r, err := json.Marshal(updatedItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - - return utils.NewToolResultText(string(r)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectUpdateFailedError, resp, body), nil, nil } + r, err := json.Marshal(updatedItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -943,55 +927,53 @@ func DeleteProjectItem(t translations.TranslationHelperFunc) inventory.ServerToo Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - if ownerType == "org" { - resp, err = client.Projects.DeleteOrganizationProjectItem(ctx, owner, projectNumber, itemID) - } else { - resp, err = client.Projects.DeleteUserProjectItem(ctx, owner, projectNumber, itemID) - } + var resp *github.Response + if ownerType == "org" { + resp, err = client.Projects.DeleteOrganizationProjectItem(ctx, owner, projectNumber, itemID) + } else { + resp, err = client.Projects.DeleteUserProjectItem(ctx, owner, projectNumber, itemID) + } + + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectDeleteFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectDeleteFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusNoContent { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectDeleteFailedError, resp, body), nil, nil + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultText("project item successfully deleted"), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectDeleteFailedError, resp, body), nil, nil } + return utils.NewToolResultText("project item successfully deleted"), nil, nil }, ) } diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index 67ecd8800..e443b9ecd 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -146,7 +146,7 @@ func Test_ListProjects(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -285,7 +285,7 @@ func Test_GetProject(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -437,7 +437,7 @@ func Test_ListProjectFields(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -597,7 +597,7 @@ func Test_GetProjectField(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -803,7 +803,7 @@ func Test_ListProjectItems(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1000,7 +1000,7 @@ func Test_GetProjectItem(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1230,7 +1230,7 @@ func Test_AddProjectItem(t *testing.T) { handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1514,7 +1514,7 @@ func Test_UpdateProjectItem(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1681,7 +1681,7 @@ func Test_DeleteProjectItem(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 16bb1bafc..d51c14fa4 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -69,68 +69,66 @@ Possible options: }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + switch method { + case "get": + result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + return result, nil, err + case "get_diff": + result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) + return result, nil, err + case "get_status": + result, err := GetPullRequestStatus(ctx, client, owner, repo, pullNumber) + return result, nil, err + case "get_files": + result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) + return result, nil, err + case "get_review_comments": + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } - pagination, err := OptionalPaginationParams(args) + cursorPagination, err := OptionalCursorPaginationParams(args) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - switch method { - case "get": - result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) - return result, nil, err - case "get_diff": - result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) - return result, nil, err - case "get_status": - result, err := GetPullRequestStatus(ctx, client, owner, repo, pullNumber) - return result, nil, err - case "get_files": - result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) - return result, nil, err - case "get_review_comments": - gqlClient, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } - cursorPagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - result, err := GetPullRequestReviewComments(ctx, gqlClient, deps.GetRepoAccessCache(), owner, repo, pullNumber, cursorPagination, deps.GetFlags()) - return result, nil, err - case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) - return result, nil, err - case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil - } + result, err := GetPullRequestReviewComments(ctx, gqlClient, deps.GetRepoAccessCache(), owner, repo, pullNumber, cursorPagination, deps.GetFlags()) + return result, nil, err + case "get_reviews": + result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + return result, nil, err + case "get_comments": + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } }) } @@ -520,92 +518,90 @@ func CreatePullRequest(t translations.TranslationHelperFunc) inventory.ServerToo }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - title, err := RequiredParam[string](args, "title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - head, err := RequiredParam[string](args, "head") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - base, err := RequiredParam[string](args, "base") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - body, err := OptionalParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - draft, err := OptionalParam[bool](args, "draft") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + title, err := RequiredParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + head, err := RequiredParam[string](args, "head") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + base, err := RequiredParam[string](args, "base") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - maintainerCanModify, err := OptionalParam[bool](args, "maintainer_can_modify") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + body, err := OptionalParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - newPR := &github.NewPullRequest{ - Title: github.Ptr(title), - Head: github.Ptr(head), - Base: github.Ptr(base), - } + draft, err := OptionalParam[bool](args, "draft") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if body != "" { - newPR.Body = github.Ptr(body) - } + maintainerCanModify, err := OptionalParam[bool](args, "maintainer_can_modify") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - newPR.Draft = github.Ptr(draft) - newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + newPR := &github.NewPullRequest{ + Title: github.Ptr(title), + Head: github.Ptr(head), + Base: github.Ptr(base), + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create pull request", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if body != "" { + newPR.Body = github.Ptr(body) + } - if resp.StatusCode != http.StatusCreated { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create pull request", resp, bodyBytes), nil, nil - } + newPR.Draft = github.Ptr(draft) + newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", pr.GetID()), - URL: pr.GetHTMLURL(), - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create pull request", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalResponse) + if resp.StatusCode != http.StatusCreated { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create pull request", resp, bodyBytes), nil, nil + } + + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", pr.GetID()), + URL: pr.GetHTMLURL(), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -673,214 +669,188 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) inventory.ServerToo }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + _, draftProvided := args["draft"] + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](args, "draft") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } + } - _, draftProvided := args["draft"] - var draftValue bool - if draftProvided { - draftValue, err = OptionalParam[bool](args, "draft") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - } + update := &github.PullRequest{} + restUpdateNeeded := false - update := &github.PullRequest{} - restUpdateNeeded := false + if title, ok, err := OptionalParamOK[string](args, "title"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Title = github.Ptr(title) + restUpdateNeeded = true + } - if title, ok, err := OptionalParamOK[string](args, "title"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Title = github.Ptr(title) - restUpdateNeeded = true - } + if body, ok, err := OptionalParamOK[string](args, "body"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Body = github.Ptr(body) + restUpdateNeeded = true + } - if body, ok, err := OptionalParamOK[string](args, "body"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Body = github.Ptr(body) - restUpdateNeeded = true - } + if state, ok, err := OptionalParamOK[string](args, "state"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.State = github.Ptr(state) + restUpdateNeeded = true + } - if state, ok, err := OptionalParamOK[string](args, "state"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.State = github.Ptr(state) - restUpdateNeeded = true - } + if base, ok, err := OptionalParamOK[string](args, "base"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} + restUpdateNeeded = true + } - if base, ok, err := OptionalParamOK[string](args, "base"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - restUpdateNeeded = true - } + if maintainerCanModify, ok, err := OptionalParamOK[bool](args, "maintainer_can_modify"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.MaintainerCanModify = github.Ptr(maintainerCanModify) + restUpdateNeeded = true + } - if maintainerCanModify, ok, err := OptionalParamOK[bool](args, "maintainer_can_modify"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.MaintainerCanModify = github.Ptr(maintainerCanModify) - restUpdateNeeded = true - } + // Handle reviewers separately + reviewers, err := OptionalStringArrayParam(args, "reviewers") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + // If no updates, no draft change, and no reviewers, return error early + if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { + return utils.NewToolResultError("No update parameters provided."), nil, nil + } - // Handle reviewers separately - reviewers, err := OptionalStringArrayParam(args, "reviewers") + // Handle REST API updates (title, body, state, base, maintainer_can_modify) + if restUpdateNeeded { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - // If no updates, no draft change, and no reviewers, return error early - if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { - return utils.NewToolResultError("No update parameters provided."), nil, nil + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil, nil } + defer func() { _ = resp.Body.Close() }() - // Handle REST API updates (title, body, state, base, maintainer_can_modify) - if restUpdateNeeded { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request", resp, bodyBytes), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request", resp, bodyBytes), nil, nil } + } - // Handle draft status changes using GraphQL - if draftProvided { - gqlClient, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil - } - - var prQuery struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers - }) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil, nil - } - - currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) - - if currentIsDraft != draftValue { - if draftValue { - // Convert to draft - var mutation struct { - ConvertPullRequestToDraft struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"convertPullRequestToDraft(input: $input)"` - } + // Handle draft status changes using GraphQL + if draftProvided { + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil + } - err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil, nil - } - } else { - // Mark as ready for review - var mutation struct { - MarkPullRequestReadyForReview struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"markPullRequestReadyForReview(input: $input)"` - } + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } - err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil, nil - } - } - } + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers + }) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil, nil } - // Handle reviewer requests - if len(reviewers) > 0 { - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) - reviewersRequest := github.ReviewersRequest{ - Reviewers: reviewers, - } + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } - _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to request reviewers", - resp, - err, - ), nil, nil - } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil, nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` } - }() - if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil, nil } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request reviewers", resp, bodyBytes), nil, nil } } + } - // Get the final state of the PR to return + // Handle reviewer requests + if len(reviewers) > 0 { client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil, nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil, nil } defer func() { if resp != nil && resp.Body != nil { @@ -888,19 +858,43 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) inventory.ServerToo } }() - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", finalPR.GetID()), - URL: finalPR.GetHTMLURL(), + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request reviewers", resp, bodyBytes), nil, nil } + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return utils.NewToolResultErrorFromErr("Failed to marshal response", err), nil, nil + // Get the final state of the PR to return + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil, nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } + }() + + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", finalPR.GetID()), + URL: finalPR.GetHTMLURL(), + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("Failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -956,95 +950,93 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - head, err := OptionalParam[string](args, "head") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - base, err := OptionalParam[string](args, "base") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + head, err := OptionalParam[string](args, "head") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + base, err := OptionalParam[string](args, "base") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.PullRequestListOptions{ - State: state, - Head: head, - Base: base, - Sort: sort, - Direction: direction, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.PullRequestListOptions{ + State: state, + Head: head, + Base: base, + Sort: sort, + Direction: direction, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list pull requests", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list pull requests", - resp, - err, - ), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - defer func() { _ = resp.Body.Close() }() + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list pull requests", resp, bodyBytes), nil, nil + } - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list pull requests", resp, bodyBytes), nil, nil + // sanitize title/body on each PR + for _, pr := range prs { + if pr == nil { + continue } - - // sanitize title/body on each PR - for _, pr := range prs { - if pr == nil { - continue - } - if pr.Title != nil { - pr.Title = github.Ptr(sanitize.Sanitize(*pr.Title)) - } - if pr.Body != nil { - pr.Body = github.Ptr(sanitize.Sanitize(*pr.Body)) - } + if pr.Title != nil { + pr.Title = github.Ptr(sanitize.Sanitize(*pr.Title)) } - - r, err := json.Marshal(prs) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + if pr.Body != nil { + pr.Body = github.Ptr(sanitize.Sanitize(*pr.Body)) } + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(prs) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -1094,67 +1086,65 @@ func MergePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - commitTitle, err := OptionalParam[string](args, "commit_title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - commitMessage, err := OptionalParam[string](args, "commit_message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - mergeMethod, err := OptionalParam[string](args, "merge_method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - options := &github.PullRequestOptions{ - CommitTitle: commitTitle, - MergeMethod: mergeMethod, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + commitTitle, err := OptionalParam[string](args, "commit_title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + commitMessage, err := OptionalParam[string](args, "commit_message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + mergeMethod, err := OptionalParam[string](args, "merge_method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to merge pull request", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + options := &github.PullRequestOptions{ + CommitTitle: commitTitle, + MergeMethod: mergeMethod, + } - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to merge pull request", resp, bodyBytes), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to merge pull request", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to merge pull request", resp, bodyBytes), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -1213,11 +1203,9 @@ func SearchPullRequests(t translations.TranslationHelperFunc) inventory.ServerTo }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, deps.GetClient, args, "pr", "failed to search pull requests") - return result, nil, err - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "pr", "failed to search pull requests") + return result, nil, err }) } @@ -1257,63 +1245,61 @@ func UpdatePullRequestBranch(t translations.TranslationHelperFunc) inventory.Ser }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - expectedHeadSHA, err := OptionalParam[string](args, "expectedHeadSha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.PullRequestBranchUpdateOptions{} - if expectedHeadSHA != "" { - opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) - } - - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return utils.NewToolResultText("Pull request branch update is in progress"), nil, nil - } - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request branch", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + expectedHeadSHA, err := OptionalParam[string](args, "expectedHeadSha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.PullRequestBranchUpdateOptions{} + if expectedHeadSHA != "" { + opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) + } - if resp.StatusCode != http.StatusAccepted { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request branch", resp, bodyBytes), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return utils.NewToolResultText("Pull request branch update is in progress"), nil, nil + } + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request branch", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) + if resp.StatusCode != http.StatusAccepted { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request branch", resp, bodyBytes), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }) } @@ -1385,32 +1371,30 @@ Available methods: }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params PullRequestReviewWriteParams - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params PullRequestReviewWriteParams + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Given our owner, repo and PR number, lookup the GQL ID of the PR. - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + // Given our owner, repo and PR number, lookup the GQL ID of the PR. + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - switch params.Method { - case "create": - result, err := CreatePullRequestReview(ctx, client, params) - return result, nil, err - case "submit_pending": - result, err := SubmitPendingPullRequestReview(ctx, client, params) - return result, nil, err - case "delete_pending": - result, err := DeletePendingPullRequestReview(ctx, client, params) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", params.Method)), nil, nil - } + switch params.Method { + case "create": + result, err := CreatePullRequestReview(ctx, client, params) + return result, nil, err + case "submit_pending": + result, err := SubmitPendingPullRequestReview(ctx, client, params) + return result, nil, err + case "delete_pending": + result, err := DeletePendingPullRequestReview(ctx, client, params) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", params.Method)), nil, nil } }) } @@ -1710,122 +1694,120 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) inventory.S }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params struct { - Owner string - Repo string - PullNumber int32 - Path string - Body string - SubjectType string - Line *int32 - Side *string - StartLine *int32 - StartSide *string - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params struct { + Owner string + Repo string + PullNumber int32 + Path string + Body string + SubjectType string + Line *int32 + Side *string + StartLine *int32 + StartSide *string + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - // First we'll get the current user - var getViewerQuery struct { - Viewer struct { - Login githubv4.String - } + // First we'll get the current user + var getViewerQuery struct { + Viewer struct { + Login githubv4.String } + } - if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, - "failed to get current user", - err, - ), nil, nil - } + if err := client.Query(ctx, &getViewerQuery, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get current user", + err, + ), nil, nil + } - var getLatestReviewForViewerQuery struct { - Repository struct { - PullRequest struct { - Reviews struct { - Nodes []struct { - ID githubv4.ID - State githubv4.PullRequestReviewState - URL githubv4.URI - } - } `graphql:"reviews(first: 1, author: $author)"` - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + var getLatestReviewForViewerQuery struct { + Repository struct { + PullRequest struct { + Reviews struct { + Nodes []struct { + ID githubv4.ID + State githubv4.PullRequestReviewState + URL githubv4.URI + } + } `graphql:"reviews(first: 1, author: $author)"` + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - vars := map[string]any{ - "author": githubv4.String(getViewerQuery.Viewer.Login), - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "prNum": githubv4.Int(params.PullNumber), - } + vars := map[string]any{ + "author": githubv4.String(getViewerQuery.Viewer.Login), + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "prNum": githubv4.Int(params.PullNumber), + } - if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, - "failed to get latest review for current user", - err, - ), nil, nil - } + if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get latest review for current user", + err, + ), nil, nil + } - // Validate there is one review and the state is pending - if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { - return utils.NewToolResultError("No pending review found for the viewer"), nil, nil - } + // Validate there is one review and the state is pending + if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { + return utils.NewToolResultError("No pending review found for the viewer"), nil, nil + } - review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] - if review.State != githubv4.PullRequestReviewStatePending { - errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) - return utils.NewToolResultError(errText), nil, nil - } + review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] + if review.State != githubv4.PullRequestReviewStatePending { + errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) + return utils.NewToolResultError(errText), nil, nil + } - // Then we can create a new review thread comment on the review. - var addPullRequestReviewThreadMutation struct { - AddPullRequestReviewThread struct { - Thread struct { - ID githubv4.ID // We don't need this, but a selector is required or GQL complains. - } - } `graphql:"addPullRequestReviewThread(input: $input)"` - } + // Then we can create a new review thread comment on the review. + var addPullRequestReviewThreadMutation struct { + AddPullRequestReviewThread struct { + Thread struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"addPullRequestReviewThread(input: $input)"` + } - if err := client.Mutate( - ctx, - &addPullRequestReviewThreadMutation, - githubv4.AddPullRequestReviewThreadInput{ - Path: githubv4.String(params.Path), - Body: githubv4.String(params.Body), - SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType), - Line: newGQLIntPtr(params.Line), - Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side), - StartLine: newGQLIntPtr(params.StartLine), - StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide), - PullRequestReviewID: &review.ID, - }, - nil, - ); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + if err := client.Mutate( + ctx, + &addPullRequestReviewThreadMutation, + githubv4.AddPullRequestReviewThreadInput{ + Path: githubv4.String(params.Path), + Body: githubv4.String(params.Body), + SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType), + Line: newGQLIntPtr(params.Line), + Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side), + StartLine: newGQLIntPtr(params.StartLine), + StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide), + PullRequestReviewID: &review.ID, + }, + nil, + ); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if addPullRequestReviewThreadMutation.AddPullRequestReviewThread.Thread.ID == nil { - return utils.NewToolResultError(`Failed to add comment to pending review. Possible reasons: + if addPullRequestReviewThreadMutation.AddPullRequestReviewThread.Thread.ID == nil { + return utils.NewToolResultError(`Failed to add comment to pending review. Possible reasons: - The line number doesn't exist in the pull request diff - The file path is incorrect - The side (LEFT/RIGHT) is invalid for the specified line `), nil, nil - } - - // Return nothing interesting, just indicate success for the time being. - // In future, we may want to return the review ID, but for the moment, we're not leaking - // API implementation details to the LLM. - return utils.NewToolResultText("pull request review comment successfully added to pending review"), nil, nil } + + // Return nothing interesting, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return utils.NewToolResultText("pull request review comment successfully added to pending review"), nil, nil }) } @@ -1864,58 +1846,56 @@ func RequestCopilotReview(t translations.TranslationHelperFunc) inventory.Server }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - _, resp, err := client.PullRequests.RequestReviewers( - ctx, - owner, - repo, - pullNumber, - github.ReviewersRequest{ - // The login name of the copilot reviewer bot - Reviewers: []string{"copilot-pull-request-reviewer[bot]"}, - }, - ) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to request copilot review", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + _, resp, err := client.PullRequests.RequestReviewers( + ctx, + owner, + repo, + pullNumber, + github.ReviewersRequest{ + // The login name of the copilot reviewer bot + Reviewers: []string{"copilot-pull-request-reviewer[bot]"}, + }, + ) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request copilot review", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request copilot review", resp, bodyBytes), nil, nil + if resp.StatusCode != http.StatusCreated { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - - // Return nothing on success, as there's not much value in returning the Pull Request itself - return utils.NewToolResultText(""), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request copilot review", resp, bodyBytes), nil, nil } + + // Return nothing on success, as there's not much value in returning the Pull Request itself + return utils.NewToolResultText(""), nil, nil }) } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 71ccf33bb..3cb41515d 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -118,7 +118,7 @@ func Test_GetPullRequest(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -381,7 +381,7 @@ func Test_UpdatePullRequest(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -570,7 +570,7 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError || tc.expectedErrMsg != "" { require.NoError(t, err) @@ -703,7 +703,7 @@ func Test_ListPullRequests(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -827,7 +827,7 @@ func Test_MergePullRequest(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1288,7 +1288,7 @@ func Test_GetPullRequestFiles(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1463,7 +1463,7 @@ func Test_GetPullRequestStatus(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1599,7 +1599,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1919,7 +1919,7 @@ func Test_GetPullRequestComments(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2097,7 +2097,7 @@ func Test_GetPullRequestReviews(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2259,7 +2259,7 @@ func Test_CreatePullRequest(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2475,7 +2475,7 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2589,7 +2589,7 @@ func Test_RequestCopilotReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) @@ -2785,7 +2785,7 @@ func TestCreatePendingPullRequestReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2968,7 +2968,7 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -3073,7 +3073,7 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -3172,7 +3172,7 @@ func TestDeletePendingPullRequestReview(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -3265,7 +3265,7 @@ index 5d6e7b2..8a4f5c3 100644 request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 27d0d76fd..d8d2b27b3 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -54,66 +54,64 @@ func GetCommit(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "sha"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := RequiredParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - includeDiff, err := OptionalBoolParamWithDefault(args, "include_diff", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get commit: %s", sha), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := RequiredParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + includeDiff, err := OptionalBoolParamWithDefault(args, "include_diff", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil - } + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - // Convert to minimal commit - minimalCommit := convertToMinimalCommit(commit, includeDiff) + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get commit: %s", sha), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalCommit) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Convert to minimal commit + minimalCommit := convertToMinimalCommit(commit, includeDiff) + + r, err := json.Marshal(minimalCommit) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -152,77 +150,75 @@ func ListCommits(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - author, err := OptionalParam[string](args, "author") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // Set default perPage to 30 if not provided - perPage := pagination.PerPage - if perPage == 0 { - perPage = 30 - } - opts := &github.CommitsListOptions{ - SHA: sha, - Author: author, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: perPage, - }, - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list commits: %s", sha), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list commits", resp, body), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + author, err := OptionalParam[string](args, "author") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // Set default perPage to 30 if not provided + perPage := pagination.PerPage + if perPage == 0 { + perPage = 30 + } + opts := &github.CommitsListOptions{ + SHA: sha, + Author: author, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: perPage, + }, + } - // Convert to minimal commits - minimalCommits := make([]MinimalCommit, len(commits)) - for i, commit := range commits { - minimalCommits[i] = convertToMinimalCommit(commit, false) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list commits: %s", sha), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalCommits) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list commits", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Convert to minimal commits + minimalCommits := make([]MinimalCommit, len(commits)) + for i, commit := range commits { + minimalCommits[i] = convertToMinimalCommit(commit, false) + } + + r, err := json.Marshal(minimalCommits) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -253,64 +249,62 @@ func ListBranches(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.BranchListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list branches", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + opts := &github.BranchListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list branches", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - // Convert to minimal branches - minimalBranches := make([]MinimalBranch, 0, len(branches)) - for _, branch := range branches { - minimalBranches = append(minimalBranches, convertToMinimalBranch(branch)) - } + branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list branches", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalBranches) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list branches", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Convert to minimal branches + minimalBranches := make([]MinimalBranch, 0, len(branches)) + for _, branch := range branches { + minimalBranches = append(minimalBranches, convertToMinimalBranch(branch)) + } + + r, err := json.Marshal(minimalBranches) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -368,149 +362,147 @@ If the SHA is not provided, the tool will attempt to acquire it by fetching the Required: []string{"owner", "repo", "path", "content", "message", "branch"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path, err := RequiredParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // json.Marshal encodes byte arrays with base64, which is required for the API. - contentBytes := []byte(content) + // json.Marshal encodes byte arrays with base64, which is required for the API. + contentBytes := []byte(content) - // Create the file options - opts := &github.RepositoryContentFileOptions{ - Message: github.Ptr(message), - Content: contentBytes, - Branch: github.Ptr(branch), - } + // Create the file options + opts := &github.RepositoryContentFileOptions{ + Message: github.Ptr(message), + Content: contentBytes, + Branch: github.Ptr(branch), + } - // If SHA is provided, set it (for updates) - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if sha != "" { - opts.SHA = github.Ptr(sha) - } + // If SHA is provided, set it (for updates) + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if sha != "" { + opts.SHA = github.Ptr(sha) + } - // Create or update the file - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Create or update the file + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - path = strings.TrimPrefix(path, "/") + path = strings.TrimPrefix(path, "/") - // SHA validation using conditional HEAD request (efficient - no body transfer) - var previousSHA string - contentURL := fmt.Sprintf("repos/%s/%s/contents/%s", owner, repo, url.PathEscape(path)) - if branch != "" { - contentURL += "?ref=" + url.QueryEscape(branch) - } + // SHA validation using conditional HEAD request (efficient - no body transfer) + var previousSHA string + contentURL := fmt.Sprintf("repos/%s/%s/contents/%s", owner, repo, url.PathEscape(path)) + if branch != "" { + contentURL += "?ref=" + url.QueryEscape(branch) + } - if sha != "" { - // User provided SHA - validate it's still current - req, err := client.NewRequest("HEAD", contentURL, nil) - if err == nil { - req.Header.Set("If-None-Match", fmt.Sprintf(`"%s"`, sha)) - resp, _ := client.Do(ctx, req, nil) - if resp != nil { - defer resp.Body.Close() - - switch resp.StatusCode { - case http.StatusNotModified: - // SHA matches current - proceed - opts.SHA = github.Ptr(sha) - case http.StatusOK: - // SHA is stale - reject with current SHA so user can check diff - currentSHA := strings.Trim(resp.Header.Get("ETag"), `"`) - return utils.NewToolResultError(fmt.Sprintf( - "SHA mismatch: provided SHA %s is stale. Current file SHA is %s. "+ - "Use get_file_contents or compare commits to review changes before updating.", - sha, currentSHA)), nil, nil - case http.StatusNotFound: - // File doesn't exist - this is a create, ignore provided SHA - } + if sha != "" { + // User provided SHA - validate it's still current + req, err := client.NewRequest("HEAD", contentURL, nil) + if err == nil { + req.Header.Set("If-None-Match", fmt.Sprintf(`"%s"`, sha)) + resp, _ := client.Do(ctx, req, nil) + if resp != nil { + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusNotModified: + // SHA matches current - proceed + opts.SHA = github.Ptr(sha) + case http.StatusOK: + // SHA is stale - reject with current SHA so user can check diff + currentSHA := strings.Trim(resp.Header.Get("ETag"), `"`) + return utils.NewToolResultError(fmt.Sprintf( + "SHA mismatch: provided SHA %s is stale. Current file SHA is %s. "+ + "Use get_file_contents or compare commits to review changes before updating.", + sha, currentSHA)), nil, nil + case http.StatusNotFound: + // File doesn't exist - this is a create, ignore provided SHA } } - } else { - // No SHA provided - check if file exists to warn about blind update - req, err := client.NewRequest("HEAD", contentURL, nil) - if err == nil { - resp, _ := client.Do(ctx, req, nil) - if resp != nil { - defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - previousSHA = strings.Trim(resp.Header.Get("ETag"), `"`) - } - // 404 = new file, no previous SHA needed + } + } else { + // No SHA provided - check if file exists to warn about blind update + req, err := client.NewRequest("HEAD", contentURL, nil) + if err == nil { + resp, _ := client.Do(ctx, req, nil) + if resp != nil { + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + previousSHA = strings.Trim(resp.Header.Get("ETag"), `"`) } + // 404 = new file, no previous SHA needed } } + } - if previousSHA != "" { - opts.SHA = github.Ptr(previousSHA) - } - - fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create/update file", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if previousSHA != "" { + opts.SHA = github.Ptr(previousSHA) + } - if resp.StatusCode != 200 && resp.StatusCode != 201 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create/update file", resp, body), nil, nil - } + fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create/update file", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(fileContent) + if resp.StatusCode != 200 && resp.StatusCode != 201 { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create/update file", resp, body), nil, nil + } - // Warn if file was updated without SHA validation (blind update) - if sha == "" && previousSHA != "" { - return utils.NewToolResultText(fmt.Sprintf( - "Warning: File updated without SHA validation. Previous file SHA was %s. "+ - `Verify no unintended changes were overwritten: + r, err := json.Marshal(fileContent) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + // Warn if file was updated without SHA validation (blind update) + if sha == "" && previousSHA != "" { + return utils.NewToolResultText(fmt.Sprintf( + "Warning: File updated without SHA validation. Previous file SHA was %s. "+ + `Verify no unintended changes were overwritten: 1. Extract the SHA of the local version using git ls-tree HEAD %s. 2. Compare with the previous SHA above. 3. Revert changes if shas do not match. %s`, - previousSHA, path, string(r))), nil, nil - } - - return utils.NewToolResultText(string(r)), nil, nil + previousSHA, path, string(r))), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -553,71 +545,69 @@ func CreateRepository(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"name"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - organization, err := OptionalParam[string](args, "organization") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - private, err := OptionalParam[bool](args, "private") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - autoInit, err := OptionalParam[bool](args, "autoInit") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo := &github.Repository{ - Name: github.Ptr(name), - Description: github.Ptr(description), - Private: github.Ptr(private), - AutoInit: github.Ptr(autoInit), - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - createdRepo, resp, err := client.Repositories.Create(ctx, organization, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create repository", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + organization, err := OptionalParam[string](args, "organization") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + private, err := OptionalParam[bool](args, "private") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + autoInit, err := OptionalParam[bool](args, "autoInit") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create repository", resp, body), nil, nil - } + repo := &github.Repository{ + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), + } - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", createdRepo.GetID()), - URL: createdRepo.GetHTMLURL(), - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + createdRepo, resp, err := client.Repositories.Create(ctx, organization, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create repository", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalResponse) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create repository", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", createdRepo.GetID()), + URL: createdRepo.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -661,148 +651,146 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - path, err := OptionalParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path = strings.TrimPrefix(path, "/") + path, err := OptionalParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path = strings.TrimPrefix(path, "/") - ref, err := OptionalParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ref, err := OptionalParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub client"), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub client"), nil, nil + } - rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil - } + rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil + } - if rawOpts.SHA != "" { - ref = rawOpts.SHA - } + if rawOpts.SHA != "" { + ref = rawOpts.SHA + } - var fileSHA string - opts := &github.RepositoryContentGetOptions{Ref: ref} + var fileSHA string + opts := &github.RepositoryContentGetOptions{Ref: ref} - // Always call GitHub Contents API first to get metadata including SHA and determine if it's a file or directory - fileContent, dirContent, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) - if respContents != nil { - defer func() { _ = respContents.Body.Close() }() - } + // Always call GitHub Contents API first to get metadata including SHA and determine if it's a file or directory + fileContent, dirContent, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) + if respContents != nil { + defer func() { _ = respContents.Body.Close() }() + } - // The path does not point to a file or directory. - // Instead let's try to find it in the Git Tree by matching the end of the path. - if err != nil || (fileContent == nil && dirContent == nil) { - return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, 0) - } + // The path does not point to a file or directory. + // Instead let's try to find it in the Git Tree by matching the end of the path. + if err != nil || (fileContent == nil && dirContent == nil) { + return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, 0) + } - if fileContent != nil && fileContent.SHA != nil { - fileSHA = *fileContent.SHA + if fileContent != nil && fileContent.SHA != nil { + fileSHA = *fileContent.SHA - rawClient, err := deps.GetRawClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub raw content client"), nil, nil - } - resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) + rawClient, err := deps.GetRawClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub raw content client"), nil, nil + } + resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) + if err != nil { + return utils.NewToolResultError("failed to get raw repository content"), nil, nil + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode == http.StatusOK { + // If the raw content is found, return it directly + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultError("failed to get raw repository content"), nil, nil + return utils.NewToolResultError("failed to read response body"), nil, nil } - defer func() { - _ = resp.Body.Close() - }() + contentType := resp.Header.Get("Content-Type") - if resp.StatusCode == http.StatusOK { - // If the raw content is found, return it directly - body, err := io.ReadAll(resp.Body) + var resourceURI string + switch { + case sha != "": + resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) if err != nil { - return utils.NewToolResultError("failed to read response body"), nil, nil + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) } - contentType := resp.Header.Get("Content-Type") - - var resourceURI string - switch { - case sha != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) - if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) - } - case ref != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) - if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) - } - default: - resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) - if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) - } + case ref != "": + resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) + if err != nil { + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) } - - // Determine if content is text or binary - isTextContent := strings.HasPrefix(contentType, "text/") || - contentType == "application/json" || - contentType == "application/xml" || - strings.HasSuffix(contentType, "+json") || - strings.HasSuffix(contentType, "+xml") - - if isTextContent { - result := &mcp.ResourceContents{ - URI: resourceURI, - Text: string(body), - MIMEType: contentType, - } - // Include SHA in the result metadata - if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil - } - return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil + default: + resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) + if err != nil { + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) } + } + // Determine if content is text or binary + isTextContent := strings.HasPrefix(contentType, "text/") || + contentType == "application/json" || + contentType == "application/xml" || + strings.HasSuffix(contentType, "+json") || + strings.HasSuffix(contentType, "+xml") + + if isTextContent { result := &mcp.ResourceContents{ URI: resourceURI, - Blob: body, + Text: string(body), MIMEType: contentType, } // Include SHA in the result metadata if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil } - return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil + return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil } - // Raw API call failed - return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, resp.StatusCode) - } else if dirContent != nil { - // file content or file SHA is nil which means it's a directory - r, err := json.Marshal(dirContent) - if err != nil { - return utils.NewToolResultError("failed to marshal response"), nil, nil + result := &mcp.ResourceContents{ + URI: resourceURI, + Blob: body, + MIMEType: contentType, } - return utils.NewToolResultText(string(r)), nil, nil + // Include SHA in the result metadata + if fileSHA != "" { + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil + } + return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil } - return utils.NewToolResultError("failed to get file contents"), nil, nil + // Raw API call failed + return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, resp.StatusCode) + } else if dirContent != nil { + // file content or file SHA is nil which means it's a directory + r, err := json.Marshal(dirContent) + if err != nil { + return utils.NewToolResultError("failed to marshal response"), nil, nil + } + return utils.NewToolResultText(string(r)), nil, nil } + + return utils.NewToolResultError("failed to get file contents"), nil, nil }, ) } @@ -838,66 +826,64 @@ func ForkRepository(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - org, err := OptionalParam[string](args, "organization") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.RepositoryCreateForkOptions{} - if org != "" { - opts.Organization = org - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return utils.NewToolResultText("Fork is in progress"), nil, nil - } - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to fork repository", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + org, err := OptionalParam[string](args, "organization") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusAccepted { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to fork repository", resp, body), nil, nil - } + opts := &github.RepositoryCreateForkOptions{} + if org != "" { + opts.Organization = org + } - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", forkedRepo.GetID()), - URL: forkedRepo.GetHTMLURL(), - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return utils.NewToolResultText("Fork is in progress"), nil, nil + } + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to fork repository", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalResponse) + if resp.StatusCode != http.StatusAccepted { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to fork repository", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", forkedRepo.GetID()), + URL: forkedRepo.GetHTMLURL(), } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -946,149 +932,147 @@ func DeleteFile(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "path", "message", "branch"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path, err := RequiredParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) - if err != nil { - return nil, nil, fmt.Errorf("failed to get branch reference: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Get the reference for the branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return nil, nil, fmt.Errorf("failed to get branch reference: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Get the commit object that the branch points to + baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get base commit", - resp, - err, - ), nil, nil + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - defer func() { _ = resp.Body.Close() }() + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil - } + // Create a tree entry for the file deletion by setting SHA to nil + treeEntries := []*github.TreeEntry{ + { + Path: github.Ptr(path), + Mode: github.Ptr("100644"), // Regular file mode + Type: github.Ptr("blob"), + SHA: nil, // Setting SHA to nil deletes the file + }, + } - // Create a tree entry for the file deletion by setting SHA to nil - treeEntries := []*github.TreeEntry{ - { - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - SHA: nil, // Setting SHA to nil deletes the file - }, - } + // Create a new tree with the deletion + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Create a new tree with the deletion - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create tree", - resp, - err, - ), nil, nil + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - defer func() { _ = resp.Body.Close() }() + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create tree", resp, body), nil, nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create tree", resp, body), nil, nil - } + // Create a new commit with the new tree + commit := github.Commit{ + Message: github.Ptr(message), + Tree: newTree, + Parents: []*github.Commit{{SHA: baseCommit.SHA}}, + } + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Create a new commit with the new tree - commit := github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create commit", - resp, - err, - ), nil, nil + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - defer func() { _ = resp.Body.Close() }() + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create commit", resp, body), nil, nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create commit", resp, body), nil, nil - } + // Update the branch reference to point to the new commit + ref.Object.SHA = newCommit.SHA + _, resp, err = client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ + SHA: *newCommit.SHA, + Force: github.Ptr(false), + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Update the branch reference to point to the new commit - ref.Object.SHA = newCommit.SHA - _, resp, err = client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ - SHA: *newCommit.SHA, - Force: github.Ptr(false), - }) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update reference", resp, body), nil, nil - } - - // Create a response similar to what the DeleteFile API would return - response := map[string]interface{}{ - "commit": newCommit, - "content": nil, + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update reference", resp, body), nil, nil + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Create a response similar to what the DeleteFile API would return + response := map[string]interface{}{ + "commit": newCommit, + "content": nil, + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1127,82 +1111,80 @@ func CreateBranch(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "branch"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fromBranch, err := OptionalParam[string](args, "from_branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fromBranch, err := OptionalParam[string](args, "from_branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get the source branch SHA - var ref *github.Reference + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if fromBranch == "" { - // Get default branch if from_branch not specified - repository, resp, err := client.Repositories.Get(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Get the source branch SHA + var ref *github.Reference - fromBranch = *repository.DefaultBranch - } - - // Get SHA of source branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) + if fromBranch == "" { + // Get default branch if from_branch not specified + repository, resp, err := client.Repositories.Get(ctx, owner, repo) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get reference", + "failed to get repository", resp, err, ), nil, nil } defer func() { _ = resp.Body.Close() }() - // Create new branch - newRef := github.CreateRef{ - Ref: "refs/heads/" + branch, - SHA: *ref.Object.SHA, - } + fromBranch = *repository.DefaultBranch + } - createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create branch", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Get SHA of source branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(createdRef) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Create new branch + newRef := github.CreateRef{ + Ref: "refs/heads/" + branch, + SHA: *ref.Object.SHA, + } - return utils.NewToolResultText(string(r)), nil, nil + createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create branch", + resp, + err, + ), nil, nil } + defer func() { _ = resp.Body.Close() }() + + r, err := json.Marshal(createdRef) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1259,135 +1241,133 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "branch", "files", "message"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Parse files parameter - this should be an array of objects with path and content - filesObj, ok := args["files"].([]interface{}) - if !ok { - return utils.NewToolResultError("files parameter must be an array of objects with path and content"), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get branch reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get base commit", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Create tree entries for all files - var entries []*github.TreeEntry + // Parse files parameter - this should be an array of objects with path and content + filesObj, ok := args["files"].([]interface{}) + if !ok { + return utils.NewToolResultError("files parameter must be an array of objects with path and content"), nil, nil + } - for _, file := range filesObj { - fileMap, ok := file.(map[string]interface{}) - if !ok { - return utils.NewToolResultError("each file must be an object with path and content"), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - path, ok := fileMap["path"].(string) - if !ok || path == "" { - return utils.NewToolResultError("each file must have a path"), nil, nil - } + // Get the reference for the branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get branch reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Get the commit object that the branch points to + baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - content, ok := fileMap["content"].(string) - if !ok { - return utils.NewToolResultError("each file must have content"), nil, nil - } + // Create tree entries for all files + var entries []*github.TreeEntry - // Create a tree entry for the file - entries = append(entries, &github.TreeEntry{ - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - Content: github.Ptr(content), - }) + for _, file := range filesObj { + fileMap, ok := file.(map[string]interface{}) + if !ok { + return utils.NewToolResultError("each file must be an object with path and content"), nil, nil } - // Create a new tree with the file entries - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create tree", - resp, - err, - ), nil, nil + path, ok := fileMap["path"].(string) + if !ok || path == "" { + return utils.NewToolResultError("each file must have a path"), nil, nil } - defer func() { _ = resp.Body.Close() }() - // Create a new commit - commit := github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create commit", - resp, - err, - ), nil, nil + content, ok := fileMap["content"].(string) + if !ok { + return utils.NewToolResultError("each file must have content"), nil, nil } - defer func() { _ = resp.Body.Close() }() - // Update the reference to point to the new commit - ref.Object.SHA = newCommit.SHA - updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ - SHA: *newCommit.SHA, - Force: github.Ptr(false), + // Create a tree entry for the file + entries = append(entries, &github.TreeEntry{ + Path: github.Ptr(path), + Mode: github.Ptr("100644"), // Regular file mode + Type: github.Ptr("blob"), + Content: github.Ptr(content), }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + } - r, err := json.Marshal(updatedRef) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Create a new tree with the file entries + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil + // Create a new commit + commit := github.Commit{ + Message: github.Ptr(message), + Tree: newTree, + Parents: []*github.Commit{{SHA: baseCommit.SHA}}, + } + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Update the reference to point to the new commit + ref.Object.SHA = newCommit.SHA + updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ + SHA: *newCommit.SHA, + Force: github.Ptr(false), + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + r, err := json.Marshal(updatedRef) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1418,56 +1398,54 @@ func ListTags(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list tags", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list tags", resp, body), nil, nil - } + tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list tags", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(tags) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list tags", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(tags) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1502,71 +1480,69 @@ func GetTag(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo", "tag"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tag, err := RequiredParam[string](args, "tag") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tag, err := RequiredParam[string](args, "tag") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // First get the tag reference - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get tag reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag reference", resp, body), nil, nil - } + // First get the tag reference + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Then get the tag object - tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get tag object", - resp, - err, - ), nil, nil + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - defer func() { _ = resp.Body.Close() }() + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag reference", resp, body), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag object", resp, body), nil, nil - } + // Then get the tag object + tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag object", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(tagObj) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag object", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(tagObj) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1597,52 +1573,50 @@ func ListReleases(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - releases, resp, err := client.Repositories.ListReleases(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list releases: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list releases", resp, body), nil, nil - } + releases, resp, err := client.Repositories.ListReleases(ctx, owner, repo, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list releases: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(releases) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list releases", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(releases) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1673,43 +1647,41 @@ func GetLatestRelease(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - release, resp, err := client.Repositories.GetLatestRelease(ctx, owner, repo) - if err != nil { - return nil, nil, fmt.Errorf("failed to get latest release: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get latest release", resp, body), nil, nil - } + release, resp, err := client.Repositories.GetLatestRelease(ctx, owner, repo) + if err != nil { + return nil, nil, fmt.Errorf("failed to get latest release: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(release) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get latest release", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(release) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1743,51 +1715,49 @@ func GetReleaseByTag(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo", "tag"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tag, err := RequiredParam[string](args, "tag") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tag, err := RequiredParam[string](args, "tag") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - release, resp, err := client.Repositories.GetReleaseByTag(ctx, owner, repo, tag) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get release by tag: %s", tag), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get release by tag", resp, body), nil, nil - } + release, resp, err := client.Repositories.GetReleaseByTag(ctx, owner, repo, tag) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get release by tag: %s", tag), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(release) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get release by tag", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(release) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -1992,104 +1962,102 @@ func ListStarredRepositories(t translations.TranslationHelperFunc) inventory.Ser }, }), }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ActivityListStarredOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - if sort != "" { - opts.Sort = sort - } - if direction != "" { - opts.Direction = direction - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - var repos []*github.StarredRepository - var resp *github.Response - if username == "" { - // List starred repositories for the authenticated user - repos, resp, err = client.Activity.ListStarred(ctx, "", opts) - } else { - // List starred repositories for a specific user - repos, resp, err = client.Activity.ListStarred(ctx, username, opts) - } - - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list starred repositories for user '%s'", username), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list starred repositories", resp, body), nil, nil - } + opts := &github.ActivityListStarredOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + if sort != "" { + opts.Sort = sort + } + if direction != "" { + opts.Direction = direction + } - // Convert to minimal format - minimalRepos := make([]MinimalRepository, 0, len(repos)) - for _, starredRepo := range repos { - repo := starredRepo.Repository - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } + var repos []*github.StarredRepository + var resp *github.Response + if username == "" { + // List starred repositories for the authenticated user + repos, resp, err = client.Activity.ListStarred(ctx, "", opts) + } else { + // List starred repositories for a specific user + repos, resp, err = client.Activity.ListStarred(ctx, username, opts) + } - minimalRepos = append(minimalRepos, minimalRepo) - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list starred repositories for user '%s'", username), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(minimalRepos) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal starred repositories: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list starred repositories", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + // Convert to minimal format + minimalRepos := make([]MinimalRepository, 0, len(repos)) + for _, starredRepo := range repos { + repo := starredRepo.Repository + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), + } + + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + + minimalRepos = append(minimalRepos, minimalRepo) + } + + r, err := json.Marshal(minimalRepos) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal starred repositories: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -2121,42 +2089,40 @@ func StarRepository(t translations.TranslationHelperFunc) inventory.ServerTool { Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - resp, err := client.Activity.Star(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to star repository %s/%s", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + resp, err := client.Activity.Star(ctx, owner, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to star repository %s/%s", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 204 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to star repository", resp, body), nil, nil + if resp.StatusCode != 204 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - - return utils.NewToolResultText(fmt.Sprintf("Successfully starred repository %s/%s", owner, repo)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to star repository", resp, body), nil, nil } + + return utils.NewToolResultText(fmt.Sprintf("Successfully starred repository %s/%s", owner, repo)), nil, nil }, ) } @@ -2187,42 +2153,40 @@ func UnstarRepository(t translations.TranslationHelperFunc) inventory.ServerTool Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - resp, err := client.Activity.Unstar(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to unstar repository %s/%s", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + resp, err := client.Activity.Unstar(ctx, owner, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to unstar repository %s/%s", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 204 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to unstar repository", resp, body), nil, nil + if resp.StatusCode != 204 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - - return utils.NewToolResultText(fmt.Sprintf("Successfully unstarred repository %s/%s", owner, repo)), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to unstar repository", resp, body), nil, nil } + + return utils.NewToolResultText(fmt.Sprintf("Successfully unstarred repository %s/%s", owner, repo)), nil, nil }, ) } diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 5e338c7e7..9d7501f35 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -341,7 +341,7 @@ func Test_GetFileContents(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -464,7 +464,7 @@ func Test_ForkRepository(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -660,7 +660,7 @@ func Test_CreateBranch(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -792,7 +792,7 @@ func Test_GetCommit(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1018,7 +1018,7 @@ func Test_ListCommits(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1373,7 +1373,7 @@ func Test_CreateOrUpdateFile(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1565,7 +1565,7 @@ func Test_CreateRepository(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1904,7 +1904,7 @@ func Test_PushFiles(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2025,7 +2025,7 @@ func Test_ListBranches(t *testing.T) { request := createMCPRequest(tt.args) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tt.wantErr { require.NoError(t, err) if tt.errContains != "" { @@ -2213,7 +2213,7 @@ func Test_DeleteFile(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2340,7 +2340,7 @@ func Test_ListTags(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2500,7 +2500,7 @@ func Test_GetTag(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2609,7 +2609,7 @@ func Test_ListReleases(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -2700,7 +2700,7 @@ func Test_GetLatestRelease(t *testing.T) { } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -2850,7 +2850,7 @@ func Test_GetReleaseByTag(t *testing.T) { request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -3364,7 +3364,7 @@ func Test_ListStarredRepositories(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3465,7 +3465,7 @@ func Test_StarRepository(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3556,7 +3556,7 @@ func Test_UnstarRepository(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/search.go b/pkg/github/search.go index ae4d69ae5..9a8b971e2 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -56,112 +56,110 @@ func SearchRepositories(t translations.TranslationHelperFunc) inventory.ServerTo }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + result, resp, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search repositories", resp, body), nil, nil + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search repositories", resp, body), nil, nil + } - // Return either minimal or full response based on parameter - var r []byte - if minimalOutput { - minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) - for _, repo := range result.Repositories { - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), - } - - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.CreatedAt != nil { - minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.Topics != nil { - minimalRepo.Topics = repo.Topics - } - - minimalRepos = append(minimalRepos, minimalRepo) + // Return either minimal or full response based on parameter + var r []byte + if minimalOutput { + minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) + for _, repo := range result.Repositories { + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), } - minimalResult := &MinimalSearchRepositoriesResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalRepos, + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") } - - r, err = json.Marshal(minimalResult) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil + if repo.CreatedAt != nil { + minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") } - } else { - r, err = json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil + if repo.Topics != nil { + minimalRepo.Topics = repo.Topics } + + minimalRepos = append(minimalRepos, minimalRepo) } - return utils.NewToolResultText(string(r)), nil, nil + minimalResult := &MinimalSearchRepositoriesResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalRepos, + } + + r, err = json.Marshal(minimalResult) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil + } + } else { + r, err = json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil + } } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -200,154 +198,150 @@ func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool { }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search code", resp, body), nil, nil - } + result, resp, err := client.Search.Code(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(result) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search code", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } -func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } +func userOrOrgHandler(ctx context.Context, accountType string, deps ToolDependencies, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := deps.GetClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - searchQuery := query - if !hasTypeFilter(query) { - searchQuery = "type:" + accountType + " " + query - } - result, resp, err := client.Search.Users(ctx, searchQuery, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search %ss with query '%s'", accountType, query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + searchQuery := query + if !hasTypeFilter(query) { + searchQuery = "type:" + accountType + " " + query + } + result, resp, err := client.Search.Users(ctx, searchQuery, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search %ss with query '%s'", accountType, query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to search %ss", accountType), resp, body), nil, nil + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to search %ss", accountType), resp, body), nil, nil + } - minimalUsers := make([]MinimalUser, 0, len(result.Users)) + minimalUsers := make([]MinimalUser, 0, len(result.Users)) - for _, user := range result.Users { - if user.Login != nil { - mu := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - } - minimalUsers = append(minimalUsers, mu) + for _, user := range result.Users { + if user.Login != nil { + mu := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), } + minimalUsers = append(minimalUsers, mu) } - minimalResp := &MinimalSearchUsersResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalUsers, - } - if result.Total != nil { - minimalResp.TotalCount = *result.Total - } - if result.IncompleteResults != nil { - minimalResp.IncompleteResults = *result.IncompleteResults - } + } + minimalResp := &MinimalSearchUsersResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalUsers, + } + if result.Total != nil { + minimalResp.TotalCount = *result.Total + } + if result.IncompleteResults != nil { + minimalResp.IncompleteResults = *result.IncompleteResults + } - r, err := json.Marshal(minimalResp) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil - } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(minimalResp) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + return utils.NewToolResultText(string(r)), nil, nil } // SearchUsers creates a tool to search for GitHub users. @@ -385,8 +379,8 @@ func SearchUsers(t translations.TranslationHelperFunc) inventory.ServerTool { }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return userOrOrgHandler("user", deps) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + return userOrOrgHandler(ctx, "user", deps, args) }, ) } @@ -426,8 +420,8 @@ func SearchOrgs(t translations.TranslationHelperFunc) inventory.ServerTool { }, InputSchema: schema, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return userOrOrgHandler("org", deps) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + return userOrOrgHandler(ctx, "org", deps, args) }, ) } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 707b55349..be1b26714 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -143,7 +143,7 @@ func Test_SearchRepositories(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -221,7 +221,7 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { request := createMCPRequest(args) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -367,7 +367,7 @@ func Test_SearchCode(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -567,7 +567,7 @@ func Test_SearchUsers(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -742,7 +742,7 @@ func Test_SearchOrgs(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 4e3ced7e2..0de5166ba 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -45,51 +45,49 @@ func GetSecretScanningAlert(t translations.TranslationHelperFunc) inventory.Serv Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil - } + alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alert) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal alert: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alert) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal alert: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -133,58 +131,56 @@ func ListSecretScanningAlerts(t translations.TranslationHelperFunc) inventory.Se Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - secretType, err := OptionalParam[string](args, "secret_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - resolution, err := OptionalParam[string](args, "resolution") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + secretType, err := OptionalParam[string](args, "secret_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + resolution, err := OptionalParam[string](args, "resolution") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(alerts) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal alerts: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(alerts) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal alerts: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index b63617a46..23ac868c7 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -96,7 +96,7 @@ func Test_GetSecretScanningAlert(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -235,7 +235,7 @@ func Test_ListSecretScanningAlerts(t *testing.T) { request := createMCPRequest(tc.requestArgs) - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index b35fb5a1c..f898de61d 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -83,127 +83,125 @@ func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) inventor }, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - ghsaID, err := OptionalParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } + ghsaID, err := OptionalParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } - typ, err := OptionalParam[string](args, "type") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil - } + typ, err := OptionalParam[string](args, "type") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil + } - cveID, err := OptionalParam[string](args, "cveId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil - } + cveID, err := OptionalParam[string](args, "cveId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil + } - eco, err := OptionalParam[string](args, "ecosystem") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil - } + eco, err := OptionalParam[string](args, "ecosystem") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil + } - sev, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil - } + sev, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil + } - cwes, err := OptionalStringArrayParam(args, "cwes") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil - } + cwes, err := OptionalStringArrayParam(args, "cwes") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil + } - isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil - } + isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil + } - affects, err := OptionalParam[string](args, "affects") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil - } + affects, err := OptionalParam[string](args, "affects") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil + } - published, err := OptionalParam[string](args, "published") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil - } + published, err := OptionalParam[string](args, "published") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil + } - updated, err := OptionalParam[string](args, "updated") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil - } + updated, err := OptionalParam[string](args, "updated") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil + } - modified, err := OptionalParam[string](args, "modified") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil - } + modified, err := OptionalParam[string](args, "modified") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil + } - opts := &github.ListGlobalSecurityAdvisoriesOptions{} + opts := &github.ListGlobalSecurityAdvisoriesOptions{} - if ghsaID != "" { - opts.GHSAID = &ghsaID - } - if typ != "" { - opts.Type = &typ - } - if cveID != "" { - opts.CVEID = &cveID - } - if eco != "" { - opts.Ecosystem = &eco - } - if sev != "" { - opts.Severity = &sev - } - if len(cwes) > 0 { - opts.CWEs = cwes - } + if ghsaID != "" { + opts.GHSAID = &ghsaID + } + if typ != "" { + opts.Type = &typ + } + if cveID != "" { + opts.CVEID = &cveID + } + if eco != "" { + opts.Ecosystem = &eco + } + if sev != "" { + opts.Severity = &sev + } + if len(cwes) > 0 { + opts.CWEs = cwes + } - if isWithdrawn { - opts.IsWithdrawn = &isWithdrawn - } + if isWithdrawn { + opts.IsWithdrawn = &isWithdrawn + } - if affects != "" { - opts.Affects = &affects - } - if published != "" { - opts.Published = &published - } - if updated != "" { - opts.Updated = &updated - } - if modified != "" { - opts.Modified = &modified - } + if affects != "" { + opts.Affects = &affects + } + if published != "" { + opts.Published = &published + } + if updated != "" { + opts.Updated = &updated + } + if modified != "" { + opts.Modified = &modified + } - advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list advisories", resp, body), nil, nil - } + advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(advisories) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list advisories", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -248,67 +246,65 @@ func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) inve Required: []string{"owner", "repo"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list repository advisories", resp, body), nil, nil - } + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(advisories) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list repository advisories", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -334,39 +330,37 @@ func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) inventory.S Required: []string{"ghsaId"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - ghsaID, err := RequiredParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } + ghsaID, err := RequiredParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } - advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get advisory: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get advisory", resp, body), nil, nil - } + advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get advisory: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(advisory) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get advisory", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(advisory) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } @@ -407,62 +401,60 @@ func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) i Required: []string{"org"}, }, }, - func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := deps.GetClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list organization repository advisories", resp, body), nil, nil - } + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(advisories) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + return nil, nil, fmt.Errorf("failed to read response body: %w", err) } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list organization repository advisories", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) } + + return utils.NewToolResultText(string(r)), nil, nil }, ) } diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index 3970949ec..d1e943bd7 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -110,7 +110,7 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -231,7 +231,7 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -378,7 +378,7 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -523,7 +523,7 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), &request) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) diff --git a/pkg/inventory/server_tool.go b/pkg/inventory/server_tool.go index b21df2520..362ee2643 100644 --- a/pkg/inventory/server_tool.go +++ b/pkg/inventory/server_tool.go @@ -109,6 +109,9 @@ func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { // NewServerTool creates a ServerTool from a tool definition, toolset metadata, and a typed handler function. // The handler function takes dependencies (as any) and returns a typed handler. // Callers should type-assert deps to their typed dependencies struct. +// +// Deprecated: This creates closures at registration time. For better performance in +// per-request server scenarios, use NewServerToolWithContextHandler instead. func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { return ServerTool{ Tool: tool, @@ -127,8 +130,52 @@ func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, hand } } +// NewServerToolWithContextHandler creates a ServerTool with a handler that receives deps via context. +// This is the preferred approach for tools because it doesn't create closures at registration time, +// which is critical for performance in servers that create a new instance per request. +// +// The handler function is stored directly without wrapping in a deps closure. +// Dependencies should be injected into context before calling tool handlers. +func NewServerToolWithContextHandler[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + // HandlerFunc ignores deps - deps are retrieved from context at call time + HandlerFunc: func(_ any) mcp.ToolHandler { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := handler(ctx, req, arguments) + return resp, err + } + }, + } +} + // NewServerToolFromHandler creates a ServerTool from a tool definition, toolset metadata, and a raw handler function. // Use this when you have a handler that already conforms to mcp.ToolHandler. +// +// Deprecated: This creates closures at registration time. For better performance in +// per-request server scenarios, use NewServerToolWithRawContextHandler instead. func NewServerToolFromHandler(tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandler) ServerTool { return ServerTool{Tool: tool, Toolset: toolset, HandlerFunc: handlerFn} } + +// NewServerToolWithRawContextHandler creates a ServerTool with a raw handler that receives deps via context. +// This is the preferred approach for tools that use mcp.ToolHandler directly because it doesn't +// create closures at registration time. +// +// The handler function is stored directly without wrapping in a deps closure. +// Dependencies should be injected into context before calling tool handlers. +func NewServerToolWithRawContextHandler(tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandler) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + // HandlerFunc ignores deps - deps are retrieved from context at call time + HandlerFunc: func(_ any) mcp.ToolHandler { + return handler + }, + } +}