Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ test-race:
CGO_ENABLED=1 go test -count=1 -race ./...

coverage:
go test -coverprofile=coverage.out ./...
go test -coverprofile=coverage.out -coverpkg=./... ./...
go tool cover -func=coverage.out | tail -n 1

coverage-html:
Expand Down
27 changes: 23 additions & 4 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -1792,16 +1793,31 @@ const mockToolName = "coder_list_workspaces"

// callAccumulator tracks all tool invocations by name and each instance's arguments.
type callAccumulator struct {
calls map[string][]any
callsMu sync.Mutex
calls map[string][]any
callsMu sync.Mutex
toolErrors map[string]string
}

func newCallAccumulator() *callAccumulator {
return &callAccumulator{
calls: make(map[string][]any),
calls: make(map[string][]any),
toolErrors: make(map[string]string),
}
}

func (a *callAccumulator) setToolError(tool string, errMsg string) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
a.toolErrors[tool] = errMsg
}

func (a *callAccumulator) getToolError(tool string) (string, bool) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
errMsg, ok := a.toolErrors[tool]
return errMsg, ok
}

func (a *callAccumulator) addCall(tool string, args any) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
Expand Down Expand Up @@ -1831,12 +1847,15 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
// Accumulate tool calls & their arguments.
acc := newCallAccumulator()

for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
return nil, errors.New(errMsg)
}
return mcplib.NewToolResultText("mock"), nil
})
}
Expand Down
10 changes: 8 additions & 2 deletions fixtures/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ var (
//go:embed openai/responses/blocking/simple.txtar
OaiResponsesBlockingSimple []byte

//go:embed openai/responses/blocking/builtin_tool.txtar
OaiResponsesBlockingBuiltinTool []byte
//go:embed openai/responses/blocking/single_builtin_tool.txtar
OaiResponsesBlockingSingleBuiltinTool []byte

//go:embed openai/responses/blocking/cached_input_tokens.txtar
OaiResponsesBlockingCachedInputTokens []byte
Expand All @@ -68,6 +68,12 @@ var (

//go:embed openai/responses/blocking/wrong_response_format.txtar
OaiResponsesBlockingWrongResponseFormat []byte

//go:embed openai/responses/blocking/single_injected_tool.txtar
OaiResponsesSingleInjectedTool []byte

//go:embed openai/responses/blocking/single_injected_tool_error.txtar
OaiResponsesSingleInjectedToolError []byte
)

var (
Expand Down
1,522 changes: 1,522 additions & 0 deletions fixtures/openai/responses/blocking/single_injected_tool.txtar

Large diffs are not rendered by default.

1,522 changes: 1,522 additions & 0 deletions fixtures/openai/responses/blocking/single_injected_tool_error.txtar

Large diffs are not rendered by default.

30 changes: 10 additions & 20 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import (
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
oaiconst "github.com/openai/openai-go/v3/shared/constant"
"github.com/openai/openai-go/v3/shared/constant"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

const (
Expand All @@ -42,6 +42,7 @@ type responsesInterceptionBase struct {
reqPayload []byte
cfg config.OpenAI
model string
tracer trace.Tracer
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
logger slog.Logger
Expand Down Expand Up @@ -97,18 +98,6 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
return err
}

// keeping the same logic for 'parallel_tool_calls' as in chat-completions
// https://github.com/coder/aibridge/blob/7535a71e91a1d214a31a9b59bb810befb26141bc/intercept/chatcompletions/streaming.go#L99
if len(i.req.Tools) > 0 {
var err error
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false)
if err != nil {
err = fmt.Errorf("failed set parallel_tool_calls parameter: %w", err)
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
return err
}
}

return nil
}

Expand Down Expand Up @@ -174,7 +163,7 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
inputItems := gjson.GetBytes(i.reqPayload, "input").Array()
for i := len(inputItems) - 1; i >= 0; i-- {
item := inputItems[i]
if item.Get("role").Str == "user" {
if item.Get("role").Str == string(constant.ValueOf[constant.User]()) {
var sb strings.Builder

// content can be a string or array of objects:
Expand All @@ -194,7 +183,8 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
}
}

return "", errors.New("failed to find last user prompt")
// Request was likely not human-initiated.
return "", nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) {
Expand All @@ -204,8 +194,8 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon
return
}

// No prompt found: last request was not human-initiated.
if prompt == "" {
i.logger.Warn(ctx, "got empty last prompt, skipping prompt recording")
return
}

Expand All @@ -224,7 +214,7 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon
}
}

func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, response *responses.Response) {
func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Context, response *responses.Response) {
if response == nil {
i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
return
Expand All @@ -235,9 +225,9 @@ func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, respons

// recording other function types to be considered: https://github.com/coder/aibridge/issues/121
switch item.Type {
case string(oaiconst.ValueOf[oaiconst.FunctionCall]()):
case string(constant.ValueOf[constant.FunctionCall]()):
args = i.parseFunctionCallJSONArgs(ctx, item.Arguments)
case string(oaiconst.ValueOf[oaiconst.CustomToolCall]()):
case string(constant.ValueOf[constant.CustomToolCall]()):
args = item.Input
default:
continue
Expand Down
26 changes: 5 additions & 21 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestLastUserPrompt(t *testing.T) {
},
{
name: "array_single_input_string",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingBuiltinTool),
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
expected: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
},
{
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestLastUserPrompt(t *testing.T) {
}
}

func TestLastUserPromptErr(t *testing.T) {
func TestLastUserPromptEmptyPrompt(t *testing.T) {
t.Parallel()

t.Run("nil_struct", func(t *testing.T) {
Expand All @@ -71,45 +71,30 @@ func TestLastUserPromptErr(t *testing.T) {
require.Contains(t, "cannot get last user prompt: nil struct", err.Error())
})

t.Run("nil_struct", func(t *testing.T) {
t.Parallel()

base := responsesInterceptionBase{}
prompt, err := base.lastUserPrompt()
require.Error(t, err)
require.Empty(t, prompt)
require.Contains(t, "cannot get last user prompt: nil req struct", err.Error())
})

// Other cases where the user prompt might be empty.
tests := []struct {
name string
reqPayload []byte
wantErrMsg string
}{
{
name: "empty_input",
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "no_user_role",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_empty_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_empty_content_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
wantErrMsg: "failed to find last user prompt",
},
{
name: "user_with_non_input_text_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
wantErrMsg: "failed to find last user prompt",
},
}

Expand All @@ -127,9 +112,8 @@ func TestLastUserPromptErr(t *testing.T) {
}

prompt, err := base.lastUserPrompt()
require.Error(t, err)
require.NoError(t, err)
require.Empty(t, prompt)
require.Contains(t, tc.wantErrMsg, err.Error())
})
}
}
Expand Down Expand Up @@ -318,7 +302,7 @@ func TestRecordToolUsage(t *testing.T) {
logger: slog.Make(),
}

base.recordToolUsage(t.Context(), tc.response)
base.recordNonInjectedToolUsage(t.Context(), tc.response)

tools := rec.RecordedToolUsages()
require.Len(t, tools, len(tc.expected))
Expand Down
Loading