Skip to content

Commit 2ca985c

Browse files
committed
feat: add MCP injection support to responses streaming interceptor
1 parent e164b50 commit 2ca985c

File tree

7 files changed

+1231
-202
lines changed

7 files changed

+1231
-202
lines changed

fixtures/fixtures.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ var (
6969
//go:embed openai/responses/blocking/prev_response_id.txtar
7070
OaiResponsesBlockingPrevResponseID []byte
7171

72-
//go:embed openai/responses/blocking/wrong_response_format.txtar
73-
OaiResponsesBlockingWrongResponseFormat []byte
74-
7572
//go:embed openai/responses/blocking/single_injected_tool.txtar
76-
OaiResponsesSingleInjectedTool []byte
73+
OaiResponsesBlockingSingleInjectedTool []byte
7774

7875
//go:embed openai/responses/blocking/single_injected_tool_error.txtar
79-
OaiResponsesSingleInjectedToolError []byte
76+
OaiResponsesBlockingSingleInjectedToolError []byte
77+
78+
//go:embed openai/responses/blocking/wrong_response_format.txtar
79+
OaiResponsesBlockingWrongResponseFormat []byte
8080
)
8181

8282
var (
@@ -104,6 +104,12 @@ var (
104104
//go:embed openai/responses/streaming/prev_response_id.txtar
105105
OaiResponsesStreamingPrevResponseID []byte
106106

107+
//go:embed openai/responses/streaming/single_injected_tool.txtar
108+
OaiResponsesStreamingSingleInjectedTool []byte
109+
110+
//go:embed openai/responses/streaming/single_injected_tool_error.txtar
111+
OaiResponsesStreamingSingleInjectedToolError []byte
112+
107113
//go:embed openai/responses/streaming/stream_error.txtar
108114
OaiResponsesStreamingStreamError []byte
109115

fixtures/openai/responses/streaming/single_injected_tool.txtar

Lines changed: 595 additions & 0 deletions
Large diffs are not rendered by default.

fixtures/openai/responses/streaming/single_injected_tool_error.txtar

Lines changed: 250 additions & 0 deletions
Large diffs are not rendered by default.

intercept/responses/blocking.go

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package responses
33
import (
44
"context"
55
"errors"
6-
"fmt"
76
"net/http"
87
"time"
98

@@ -15,7 +14,6 @@ import (
1514
"github.com/google/uuid"
1615
"github.com/openai/openai-go/v3/option"
1716
"github.com/openai/openai-go/v3/responses"
18-
"github.com/tidwall/sjson"
1917
"go.opentelemetry.io/otel/attribute"
2018
"go.opentelemetry.io/otel/trace"
2119
)
@@ -62,98 +60,54 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
6260

6361
var (
6462
response *responses.Response
63+
err error
6564
upstreamErr error
6665
respCopy responseCopier
6766
)
6867

69-
for {
68+
shouldLoop := true
69+
recordPromptOnce := true
70+
for shouldLoop {
7071
srv := i.newResponsesService()
7172
respCopy = responseCopier{}
7273

7374
opts := i.requestOptions(&respCopy)
7475
opts = append(opts, option.WithRequestTimeout(time.Second*600))
7576
response, upstreamErr = i.newResponse(ctx, srv, opts)
7677

77-
if upstreamErr != nil {
78+
if upstreamErr != nil || response == nil {
7879
break
7980
}
8081

81-
// response could be nil eg. fixtures/openai/responses/blocking/wrong_response_format.txtar
82-
if response == nil {
83-
break
82+
// Record prompt usage on first successful response.
83+
if recordPromptOnce {
84+
recordPromptOnce = false
85+
i.recordUserPrompt(ctx, response.ID)
8486
}
8587

86-
// Record prompt usage on first successful response.
87-
i.recordUserPrompt(ctx, response.ID)
88+
// Record token usage for each inner loop iteration
8889
i.recordTokenUsage(ctx, response)
8990

9091
// Check if there any injected tools to invoke.
91-
pending := i.getPendingInjectedToolCalls(ctx, response)
92-
if len(pending) == 0 {
93-
// No injected tools, record non-injected tool usage.
94-
i.recordNonInjectedToolUsage(ctx, response)
95-
96-
// No injected function calls need to be invoked, flow is complete.
97-
break
98-
}
99-
100-
shouldLoop, err := i.handleInnerAgenticLoop(ctx, pending, response)
92+
pending := i.getPendingInjectedToolCalls(response)
93+
shouldLoop, err = i.handleInnerAgenticLoop(ctx, pending, response)
10194
if err != nil {
10295
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
10396
shouldLoop = false
10497
}
105-
106-
if !shouldLoop {
107-
break
108-
}
10998
}
11099

100+
i.recordNonInjectedToolUsage(ctx, response)
101+
111102
if upstreamErr != nil && !respCopy.responseReceived.Load() {
112103
// no response received from upstream, return custom error
113104
i.sendCustomErr(ctx, w, http.StatusInternalServerError, upstreamErr)
114105
}
115106

116-
err := respCopy.forwardResp(w)
117-
107+
err = respCopy.forwardResp(w)
118108
return errors.Join(upstreamErr, err)
119109
}
120110

121-
// handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools
122-
// are invoked and their results are sent back to the model.
123-
// This is in contrast to regular tool calls which will be handled by the client
124-
// in its own agentic loop.
125-
func (i *BlockingResponsesInterceptor) handleInnerAgenticLoop(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) (bool, error) {
126-
// Invoke any injected function calls.
127-
// The Responses API refers to what we call "tools" as "functions", so we keep the terminology
128-
// consistent in this package.
129-
// See https://platform.openai.com/docs/guides/function-calling
130-
results, err := i.handleInjectedToolCalls(ctx, pending, response)
131-
if err != nil {
132-
return false, fmt.Errorf("failed to handle injected tool calls: %w", err)
133-
}
134-
135-
// No tool results means no tools were invocable, so the flow is complete.
136-
if len(results) == 0 {
137-
return false, nil
138-
}
139-
140-
// We'll use the tool results to issue another request to provide the model with.
141-
i.prepareRequestForAgenticLoop(response)
142-
i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, results...)
143-
144-
// TODO: we should avoid re-marshaling Input, but since it changes from a string to
145-
// a list in this loop, we have to.
146-
// See responsesInterceptionBase.requestOptions for more details about marshaling issues.
147-
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input)
148-
if err != nil {
149-
i.logger.Error(ctx, "failure to marshal new input in inner agentic loop", slog.Error(err))
150-
// TODO: what should be returned under this condition?
151-
return false, nil
152-
}
153-
154-
return true, nil
155-
}
156-
157111
func (i *BlockingResponsesInterceptor) newResponse(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (_ *responses.Response, outErr error) {
158112
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
159113
defer tracing.EndSpanErr(span, &outErr)

intercept/responses/injected_tools.go

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,31 @@ func (i *responsesInterceptionBase) disableParallelToolCalls() {
7676
}
7777
}
7878

79+
// handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools
80+
// are invoked and their results are sent back to the model.
81+
// This is in contrast to regular tool calls which will be handled by the client
82+
// in its own agentic loop.
83+
func (i *responsesInterceptionBase) handleInnerAgenticLoop(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) (bool, error) {
84+
// Invoke any injected function calls.
85+
// The Responses API refers to what we call "tools" as "functions", so we keep the terminology
86+
// consistent in this package.
87+
// See https://platform.openai.com/docs/guides/function-calling
88+
results, err := i.handleInjectedToolCalls(ctx, pending, response)
89+
if err != nil {
90+
return false, fmt.Errorf("failed to handle injected tool calls: %w", err)
91+
}
92+
93+
// No tool results means no tools were invocable, so the flow is complete.
94+
if len(results) == 0 {
95+
return false, nil
96+
}
97+
98+
// We'll use the tool results to issue another request to provide the model with.
99+
i.prepareRequestForAgenticLoop(ctx, response, results)
100+
101+
return true, nil
102+
}
103+
79104
// handleInjectedToolCalls checks for function calls that we need to handle in our inner agentic loop.
80105
// These are functions injected by the MCP proxy.
81106
// Returns a list of tool call results.
@@ -99,19 +124,55 @@ func (i *responsesInterceptionBase) handleInjectedToolCalls(ctx context.Context,
99124

100125
// prepareRequestForAgenticLoop prepares the request by setting the output of the given
101126
// response as input to the next request, in order for the tool call result(s) to make function correctly.
102-
func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(response *responses.Response) {
127+
func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(ctx context.Context, response *responses.Response, toolResults []responses.ResponseInputItemUnionParam) error {
128+
var err error
129+
103130
// Unset the string input; we need a list now.
104-
i.req.Input.OfString = param.Opt[string]{}
131+
if i.req.Input.OfString.Valid() {
132+
// convert old string value to list item
133+
i.req.Input.OfInputItemList = responses.ResponseInputParam{
134+
responses.ResponseInputItemParamOfMessage(
135+
i.req.Input.OfString.Value,
136+
responses.EasyInputMessageRoleUser,
137+
),
138+
}
139+
if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input.OfInputItemList); err != nil {
140+
i.logger.Error(ctx, "failure to marshal str output to new input in inner agentic loop", slog.Error(err))
141+
return fmt.Errorf("failed to marshal input: %v", err)
142+
}
143+
144+
// clear old value
145+
i.req.Input.OfString = param.Opt[string]{}
146+
}
147+
inputSize := len(i.req.Input.OfInputItemList)
105148

106149
// OutputText is also available, but by definition the trigger for a function call is not a simple
107150
// text response from the model.
108151
for _, output := range response.Output {
109-
i.appendOutputToInput(i.req, output)
152+
if inputItem := i.convertOutputToInput(output); inputItem != nil {
153+
i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, *inputItem)
154+
}
155+
}
156+
157+
for _, result := range toolResults {
158+
i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, result)
110159
}
160+
161+
// Append newly added items to reqPayload field
162+
// New items are appended to limit Input re-marshaling.
163+
// See responsesInterceptionBase.requestOptions for more details about marshaling issues.
164+
for j := inputSize; j < len(i.req.Input.OfInputItemList); j++ {
165+
if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input.-1", i.req.Input.OfInputItemList[j]); err != nil {
166+
i.logger.Error(ctx, "failure to marshal output to new input in inner agentic loop", slog.Error(err))
167+
return fmt.Errorf("failed to marshal input: %v", err)
168+
}
169+
}
170+
171+
return nil
111172
}
112173

113174
// getPendingInjectedToolCalls extracts function calls from the response that are managed by MCP proxy
114-
func (i *responsesInterceptionBase) getPendingInjectedToolCalls(ctx context.Context, response *responses.Response) []responses.ResponseFunctionToolCall {
175+
func (i *responsesInterceptionBase) getPendingInjectedToolCalls(response *responses.Response) []responses.ResponseFunctionToolCall {
115176
var calls []responses.ResponseFunctionToolCall
116177

117178
for _, item := range response.Output {
@@ -171,14 +232,14 @@ func (i *responsesInterceptionBase) invokeInjectedTool(ctx context.Context, resp
171232
return responses.ResponseInputItemParamOfFunctionCallOutput(fc.CallID, output)
172233
}
173234

174-
// appendOutputToInput converts a response output item to an input item and appends it to the
235+
// convertOutputToInput converts a response output item to an input item and appends it to the
175236
// request's input list. This is used in agentic loops where we need to feed the model's output
176237
// back as input for the next iteration (e.g., when processing tool call results).
177238
//
178239
// The conversion uses the openai-go library's ToParam() methods where available, which leverage
179240
// param.Override() with raw JSON to preserve all fields. For types without ToParam(), we use
180241
// the ResponseInputItemParamOf* helper functions.
181-
func (i *responsesInterceptionBase) appendOutputToInput(req *ResponsesNewParamsWrapper, item responses.ResponseOutputItemUnion) {
242+
func (i *responsesInterceptionBase) convertOutputToInput(item responses.ResponseOutputItemUnion) *responses.ResponseInputItemUnionParam {
182243
var inputItem responses.ResponseInputItemUnionParam
183244

184245
switch item.Type {
@@ -228,8 +289,8 @@ func (i *responsesInterceptionBase) appendOutputToInput(req *ResponsesNewParamsW
228289
// - mcp_call, mcp_list_tools, mcp_approval_request: MCP-specific outputs
229290
default:
230291
i.logger.Debug(context.Background(), "skipping output item type for input", slog.F("type", item.Type))
231-
return
292+
return nil
232293
}
233294

234-
req.Input.OfInputItemList = append(req.Input.OfInputItemList, inputItem)
295+
return &inputItem
235296
}

0 commit comments

Comments
 (0)