diff --git a/go/agent.go b/go/agent.go new file mode 100644 index 0000000000..b7b6c6ef64 --- /dev/null +++ b/go/agent.go @@ -0,0 +1,205 @@ +package agentframework + +import ( + "context" + "maps" + + "github.com/google/uuid" +) + +// Agent runs a conversation with a language model. +type Agent interface { + ID() string + Name() string + Description() string + Run(ctx context.Context, messages []Message, opts ...RunOption) (*AgentResponse, error) +} + +// BaseAgent is the standard Agent implementation backed by a ChatClient. +type BaseAgent struct { + id string + name string + description string + client ChatClient + instructions []string + defaultChatOpts []ChatOption + agentMiddleware []AgentMiddleware + chatMiddleware []ChatMiddleware + functionMiddleware []FunctionMiddleware +} + +// AgentOption configures a BaseAgent. +type AgentOption func(*BaseAgent) + +// NewAgent creates a new BaseAgent with the given ChatClient and options. +func NewAgent(client ChatClient, opts ...AgentOption) *BaseAgent { + a := &BaseAgent{ + id: uuid.New().String(), + client: client, + } + for _, opt := range opts { + opt(a) + } + return a +} + +func (a *BaseAgent) ID() string { return a.id } +func (a *BaseAgent) Name() string { return a.name } +func (a *BaseAgent) Description() string { return a.description } + +// Run executes the agent with the given messages. +func (a *BaseAgent) Run(ctx context.Context, messages []Message, opts ...RunOption) (*AgentResponse, error) { + if len(messages) == 0 { + return nil, ErrEmptyMessages + } + + runOpts := NewRunOptions(opts...) + + ac := &AgentContext{ + Agent: a, + Messages: messages, + Options: &runOpts, + Metadata: make(map[string]any), + } + + terminal := func(ctx context.Context, ac *AgentContext) error { + resp, err := a.runCore(ctx, ac) + if err != nil { + return err + } + ac.Response = resp + return nil + } + + handler := buildAgentChain(a.agentMiddleware, terminal) + if err := handler(ctx, ac); err != nil { + return nil, err + } + + return ac.Response, nil +} + +func (a *BaseAgent) runCore(ctx context.Context, ac *AgentContext) (*AgentResponse, error) { + var fullMessages []Message + for _, instr := range a.instructions { + fullMessages = append(fullMessages, NewSystemMessage(instr)) + } + fullMessages = append(fullMessages, ac.Messages...) + + chatOpts := make([]ChatOption, 0, len(a.defaultChatOpts)+1) + chatOpts = append(chatOpts, a.defaultChatOpts...) + chatOpts = append(chatOpts, func(o *ChatOptions) { + merged := ac.Options.ChatOptions + if merged.Temperature != nil { + o.Temperature = merged.Temperature + } + if merged.MaxTokens != nil { + o.MaxTokens = merged.MaxTokens + } + if merged.Model != "" { + o.Model = merged.Model + } + if merged.Metadata != nil { + if o.Metadata == nil { + o.Metadata = make(map[string]any) + } + maps.Copy(o.Metadata, merged.Metadata) + } + }) + + resolvedOpts := NewChatOptions(chatOpts...) + + cc := &ChatContext{ + Client: a.client, + Messages: fullMessages, + Options: &resolvedOpts, + Metadata: make(map[string]any), + } + + chatTerminal := func(ctx context.Context, cc *ChatContext) error { + var opts []ChatOption + if cc.Options.Temperature != nil { + t := *cc.Options.Temperature + opts = append(opts, WithTemperature(t)) + } + if cc.Options.MaxTokens != nil { + n := *cc.Options.MaxTokens + opts = append(opts, WithMaxTokens(n)) + } + if cc.Options.Model != "" { + opts = append(opts, WithModel(cc.Options.Model)) + } + resp, err := cc.Client.GetResponse(ctx, cc.Messages, opts...) + if err != nil { + return err + } + cc.Response = resp + return nil + } + + chatHandler := buildChatChain(a.chatMiddleware, chatTerminal) + if err := chatHandler(ctx, cc); err != nil { + return nil, err + } + + return &AgentResponse{ + ChatResponse: *cc.Response, + AgentID: a.id, + }, nil +} + +// WithID sets the agent ID. +func WithID(id string) AgentOption { + return func(a *BaseAgent) { + a.id = id + } +} + +// WithName sets the agent name. +func WithName(name string) AgentOption { + return func(a *BaseAgent) { + a.name = name + } +} + +// WithDescription sets the agent description. +func WithDescription(desc string) AgentOption { + return func(a *BaseAgent) { + a.description = desc + } +} + +// WithInstructions sets the system instructions prepended to every request. +func WithInstructions(instructions ...string) AgentOption { + return func(a *BaseAgent) { + a.instructions = instructions + } +} + +// WithDefaultChatOptions sets default ChatOptions applied to every request. +func WithDefaultChatOptions(opts ...ChatOption) AgentOption { + return func(a *BaseAgent) { + a.defaultChatOpts = opts + } +} + +// WithAgentMiddleware appends agent-level middleware to the pipeline. +func WithAgentMiddleware(mw ...AgentMiddleware) AgentOption { + return func(a *BaseAgent) { + a.agentMiddleware = append(a.agentMiddleware, mw...) + } +} + +// WithChatMiddleware appends chat-level middleware to the pipeline. +func WithChatMiddleware(mw ...ChatMiddleware) AgentOption { + return func(a *BaseAgent) { + a.chatMiddleware = append(a.chatMiddleware, mw...) + } +} + +// WithFunctionMiddleware appends function-level middleware to the pipeline. +func WithFunctionMiddleware(mw ...FunctionMiddleware) AgentOption { + return func(a *BaseAgent) { + a.functionMiddleware = append(a.functionMiddleware, mw...) + } +} diff --git a/go/agent_test.go b/go/agent_test.go new file mode 100644 index 0000000000..8d20331748 --- /dev/null +++ b/go/agent_test.go @@ -0,0 +1,165 @@ +package agentframework_test + +import ( + "context" + "errors" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockChatClient is a test double for ChatClient. +type mockChatClient struct { + response *af.ChatResponse + err error + captured struct { + messages []af.Message + opts af.ChatOptions + } +} + +func (m *mockChatClient) GetResponse(_ context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) { + m.captured.messages = messages + m.captured.opts = af.NewChatOptions(opts...) + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestNewAgent(t *testing.T) { + t.Run("generates a UUID id by default", func(t *testing.T) { + agent := af.NewAgent(&mockChatClient{}) + assert.NotEmpty(t, agent.ID()) + }) + + t.Run("uses custom id when provided", func(t *testing.T) { + agent := af.NewAgent(&mockChatClient{}, af.WithID("custom-id")) + assert.Equal(t, "custom-id", agent.ID()) + }) + + t.Run("sets name and description", func(t *testing.T) { + agent := af.NewAgent(&mockChatClient{}, + af.WithName("TestBot"), + af.WithDescription("A test bot"), + ) + assert.Equal(t, "TestBot", agent.Name()) + assert.Equal(t, "A test bot", agent.Description()) + }) +} + +func TestBaseAgentRun(t *testing.T) { + t.Run("delegates to chat client and returns AgentResponse", func(t *testing.T) { + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "Paris")}, + ResponseID: "resp-1", + Usage: &af.UsageDetails{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, + } + agent := af.NewAgent(mock, af.WithName("Geo")) + + resp, err := agent.Run(context.Background(), []af.Message{ + af.NewUserMessage("What is the capital of France?"), + }) + + require.NoError(t, err) + assert.Equal(t, "resp-1", resp.ResponseID) + assert.Equal(t, agent.ID(), resp.AgentID) + assert.Len(t, resp.Messages, 1) + assert.Equal(t, "Paris", resp.Messages[0].Text()) + assert.Equal(t, 15, resp.Usage.TotalTokens) + }) + + t.Run("prepends system message from instructions", func(t *testing.T) { + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, + af.WithInstructions("You are helpful.", "Be concise."), + ) + + _, err := agent.Run(context.Background(), []af.Message{ + af.NewUserMessage("hi"), + }) + + require.NoError(t, err) + msgs := mock.captured.messages + require.Len(t, msgs, 3) + assert.Equal(t, af.RoleSystem, msgs[0].Role) + assert.Equal(t, "You are helpful.", msgs[0].Text()) + assert.Equal(t, af.RoleSystem, msgs[1].Role) + assert.Equal(t, "Be concise.", msgs[1].Text()) + assert.Equal(t, af.RoleUser, msgs[2].Role) + assert.Equal(t, "hi", msgs[2].Text()) + }) + + t.Run("applies default chat options", func(t *testing.T) { + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, + af.WithDefaultChatOptions(af.WithTemperature(0.3), af.WithModel("gpt-4o")), + ) + + _, err := agent.Run(context.Background(), []af.Message{ + af.NewUserMessage("hi"), + }) + + require.NoError(t, err) + assert.InDelta(t, 0.3, *mock.captured.opts.Temperature, 0.001) + assert.Equal(t, "gpt-4o", mock.captured.opts.Model) + }) + + t.Run("run options override default chat options", func(t *testing.T) { + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, + af.WithDefaultChatOptions(af.WithTemperature(0.3)), + ) + + _, err := agent.Run(context.Background(), []af.Message{ + af.NewUserMessage("hi"), + }, af.WithChatOption(af.WithTemperature(0.9))) + + require.NoError(t, err) + assert.InDelta(t, 0.9, *mock.captured.opts.Temperature, 0.001) + }) + + t.Run("propagates client error", func(t *testing.T) { + clientErr := errors.New("api error") + mock := &mockChatClient{err: clientErr} + agent := af.NewAgent(mock) + + _, err := agent.Run(context.Background(), []af.Message{ + af.NewUserMessage("hi"), + }) + + assert.ErrorIs(t, err, clientErr) + }) + + t.Run("returns error for empty messages", func(t *testing.T) { + agent := af.NewAgent(&mockChatClient{}) + + _, err := agent.Run(context.Background(), nil) + + assert.ErrorIs(t, err, af.ErrEmptyMessages) + }) + + t.Run("returns error for empty messages slice", func(t *testing.T) { + agent := af.NewAgent(&mockChatClient{}) + + _, err := agent.Run(context.Background(), []af.Message{}) + + assert.ErrorIs(t, err, af.ErrEmptyMessages) + }) +} diff --git a/go/client.go b/go/client.go new file mode 100644 index 0000000000..8e99c5a5cf --- /dev/null +++ b/go/client.go @@ -0,0 +1,8 @@ +package agentframework + +import "context" + +// ChatClient sends messages to a language model and returns a response. +type ChatClient interface { + GetResponse(ctx context.Context, messages []Message, opts ...ChatOption) (*ChatResponse, error) +} diff --git a/go/errors.go b/go/errors.go new file mode 100644 index 0000000000..65f88808da --- /dev/null +++ b/go/errors.go @@ -0,0 +1,11 @@ +package agentframework + +import "errors" + +var ( + // ErrEmptyMessages is returned when an agent is called with no messages. + ErrEmptyMessages = errors.New("agentframework: messages must not be empty") + + // ErrNilClient is returned when an agent is created with a nil ChatClient. + ErrNilClient = errors.New("agentframework: chat client must not be nil") +) diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 0000000000..3d9b5c53e2 --- /dev/null +++ b/go/go.mod @@ -0,0 +1,25 @@ +module github.com/microsoft/agent-framework/go + +go 1.25.0 + +require ( + github.com/google/uuid v1.6.0 + github.com/sashabaranov/go-openai v1.41.2 + github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 + go.opentelemetry.io/otel/trace v1.43.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + golang.org/x/sys v0.42.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go/go.sum b/go/go.sum new file mode 100644 index 0000000000..6f26433ec7 --- /dev/null +++ b/go/go.sum @@ -0,0 +1,46 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= +github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/message.go b/go/message.go new file mode 100644 index 0000000000..8bfaa90a87 --- /dev/null +++ b/go/message.go @@ -0,0 +1,80 @@ +package agentframework + +import ( + "strings" + + "github.com/google/uuid" +) + +// Role represents the role of a message sender. +type Role string + +const ( + RoleSystem Role = "system" + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleTool Role = "tool" +) + +// ContentType identifies the kind of content in a message. +type ContentType string + +const ( + ContentTypeText ContentType = "text" + ContentTypeFunctionCall ContentType = "function_call" + ContentTypeFunctionResult ContentType = "function_result" + ContentTypeError ContentType = "error" + ContentTypeUsage ContentType = "usage" +) + +// Content is a single piece of content within a message. +type Content struct { + Type ContentType + Text string + Name string + Arguments map[string]any + CallID string + Result string + Message string + Code string +} + +// Message is a single message in a conversation. +type Message struct { + Role Role + Contents []Content + ID string + Metadata map[string]any +} + +// Text returns the concatenation of all text contents in the message. +func (m Message) Text() string { + var b strings.Builder + for _, c := range m.Contents { + if c.Type == ContentTypeText { + b.WriteString(c.Text) + } + } + return b.String() +} + +// NewTextMessage creates a message with a single text content. +func NewTextMessage(role Role, text string) Message { + return Message{ + Role: role, + Contents: []Content{ + {Type: ContentTypeText, Text: text}, + }, + ID: uuid.New().String(), + } +} + +// NewUserMessage creates a user message with text content. +func NewUserMessage(text string) Message { + return NewTextMessage(RoleUser, text) +} + +// NewSystemMessage creates a system message with text content. +func NewSystemMessage(text string) Message { + return NewTextMessage(RoleSystem, text) +} diff --git a/go/message_test.go b/go/message_test.go new file mode 100644 index 0000000000..2e859185d3 --- /dev/null +++ b/go/message_test.go @@ -0,0 +1,83 @@ +package agentframework_test + +import ( + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/stretchr/testify/assert" +) + +func TestNewTextMessage(t *testing.T) { + msg := af.NewTextMessage(af.RoleAssistant, "hello") + + assert.Equal(t, af.RoleAssistant, msg.Role) + assert.Len(t, msg.Contents, 1) + assert.Equal(t, af.ContentTypeText, msg.Contents[0].Type) + assert.Equal(t, "hello", msg.Contents[0].Text) + assert.NotEmpty(t, msg.ID, "message should have an auto-generated ID") +} + +func TestNewUserMessage(t *testing.T) { + msg := af.NewUserMessage("how are you?") + + assert.Equal(t, af.RoleUser, msg.Role) + assert.Len(t, msg.Contents, 1) + assert.Equal(t, af.ContentTypeText, msg.Contents[0].Type) + assert.Equal(t, "how are you?", msg.Contents[0].Text) +} + +func TestNewSystemMessage(t *testing.T) { + msg := af.NewSystemMessage("you are helpful") + + assert.Equal(t, af.RoleSystem, msg.Role) + assert.Len(t, msg.Contents, 1) + assert.Equal(t, af.ContentTypeText, msg.Contents[0].Type) + assert.Equal(t, "you are helpful", msg.Contents[0].Text) +} + +func TestMessageText(t *testing.T) { + tests := []struct { + name string + msg af.Message + expected string + }{ + { + name: "single text content", + msg: af.NewUserMessage("hello"), + expected: "hello", + }, + { + name: "multiple text contents concatenated", + msg: af.Message{ + Role: af.RoleAssistant, + Contents: []af.Content{ + {Type: af.ContentTypeText, Text: "hello "}, + {Type: af.ContentTypeText, Text: "world"}, + }, + }, + expected: "hello world", + }, + { + name: "non-text contents skipped", + msg: af.Message{ + Role: af.RoleAssistant, + Contents: []af.Content{ + {Type: af.ContentTypeText, Text: "hello"}, + {Type: af.ContentTypeFunctionCall, Name: "foo"}, + }, + }, + expected: "hello", + }, + { + name: "empty contents", + msg: af.Message{Role: af.RoleUser}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.msg.Text()) + }) + } +} diff --git a/go/middleware.go b/go/middleware.go new file mode 100644 index 0000000000..06dd0abc2b --- /dev/null +++ b/go/middleware.go @@ -0,0 +1,107 @@ +package agentframework + +import "context" + +// --- Agent Middleware --- + +type AgentHandler func(ctx context.Context, ac *AgentContext) error + +type AgentMiddleware interface { + HandleAgent(ctx context.Context, ac *AgentContext, next AgentHandler) error +} + +type AgentMiddlewareFunc func(ctx context.Context, ac *AgentContext, next AgentHandler) error + +func (f AgentMiddlewareFunc) HandleAgent(ctx context.Context, ac *AgentContext, next AgentHandler) error { + return f(ctx, ac, next) +} + +type AgentContext struct { + Agent Agent + Messages []Message + Options *RunOptions + Response *AgentResponse + Metadata map[string]any +} + +// --- Chat Middleware --- + +type ChatHandler func(ctx context.Context, cc *ChatContext) error + +type ChatMiddleware interface { + HandleChat(ctx context.Context, cc *ChatContext, next ChatHandler) error +} + +type ChatMiddlewareFunc func(ctx context.Context, cc *ChatContext, next ChatHandler) error + +func (f ChatMiddlewareFunc) HandleChat(ctx context.Context, cc *ChatContext, next ChatHandler) error { + return f(ctx, cc, next) +} + +type ChatContext struct { + Client ChatClient + Messages []Message + Options *ChatOptions + Response *ChatResponse + Metadata map[string]any +} + +// --- Function Middleware --- + +type FunctionHandler func(ctx context.Context, fc *FunctionContext) error + +type FunctionMiddleware interface { + HandleFunction(ctx context.Context, fc *FunctionContext, next FunctionHandler) error +} + +type FunctionMiddlewareFunc func(ctx context.Context, fc *FunctionContext, next FunctionHandler) error + +func (f FunctionMiddlewareFunc) HandleFunction(ctx context.Context, fc *FunctionContext, next FunctionHandler) error { + return f(ctx, fc, next) +} + +type FunctionContext struct { + ToolName string + Arguments map[string]any + Result string + Err error + Metadata map[string]any +} + +// --- Chain Builders --- + +func buildAgentChain(middlewares []AgentMiddleware, terminal AgentHandler) AgentHandler { + handler := terminal + for i := len(middlewares) - 1; i >= 0; i-- { + mw := middlewares[i] + next := handler + handler = func(ctx context.Context, ac *AgentContext) error { + return mw.HandleAgent(ctx, ac, next) + } + } + return handler +} + +func buildChatChain(middlewares []ChatMiddleware, terminal ChatHandler) ChatHandler { + handler := terminal + for i := len(middlewares) - 1; i >= 0; i-- { + mw := middlewares[i] + next := handler + handler = func(ctx context.Context, cc *ChatContext) error { + return mw.HandleChat(ctx, cc, next) + } + } + return handler +} + +func buildFunctionChain(middlewares []FunctionMiddleware, terminal FunctionHandler) FunctionHandler { + handler := terminal + for i := len(middlewares) - 1; i >= 0; i-- { + mw := middlewares[i] + next := handler + handler = func(ctx context.Context, fc *FunctionContext) error { + return mw.HandleFunction(ctx, fc, next) + } + } + return handler +} diff --git a/go/middleware_test.go b/go/middleware_test.go new file mode 100644 index 0000000000..1ba0f4d874 --- /dev/null +++ b/go/middleware_test.go @@ -0,0 +1,206 @@ +package agentframework_test + +import ( + "context" + "errors" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAgentMiddlewareChain(t *testing.T) { + t.Run("single middleware wraps agent run", func(t *testing.T) { + var order []string + mw := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + order = append(order, "before") + err := next(ctx, ac) + order = append(order, "after") + return err + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, af.WithAgentMiddleware(mw)) + resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Messages[0].Text()) + assert.Equal(t, []string{"before", "after"}, order) + }) + + t.Run("middleware chain executes in order", func(t *testing.T) { + var order []string + mw1 := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + order = append(order, "mw1-before") + err := next(ctx, ac) + order = append(order, "mw1-after") + return err + }) + mw2 := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + order = append(order, "mw2-before") + err := next(ctx, ac) + order = append(order, "mw2-after") + return err + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, af.WithAgentMiddleware(mw1, mw2)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}, order) + }) + + t.Run("middleware can short-circuit by not calling next", func(t *testing.T) { + mw := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + ac.Response = &af.AgentResponse{ + ChatResponse: af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "cached")}, + }, + AgentID: ac.Agent.ID(), + } + return nil + }) + mock := &mockChatClient{} + agent := af.NewAgent(mock, af.WithAgentMiddleware(mw)) + resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "cached", resp.Messages[0].Text()) + }) + + t.Run("middleware can return error to stop chain", func(t *testing.T) { + expectedErr := errors.New("middleware error") + mw := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + return expectedErr + }) + mock := &mockChatClient{} + agent := af.NewAgent(mock, af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + assert.ErrorIs(t, err, expectedErr) + }) + + t.Run("middleware reads agent context", func(t *testing.T) { + var capturedName string + var capturedMsgCount int + mw := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + capturedName = ac.Agent.Name() + capturedMsgCount = len(ac.Messages) + return next(ctx, ac) + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, af.WithName("TestBot"), af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "TestBot", capturedName) + assert.Equal(t, 1, capturedMsgCount) + }) + + t.Run("no middleware — agent works normally", func(t *testing.T) { + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock) + resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Messages[0].Text()) + }) +} + +func TestChatMiddlewareChain(t *testing.T) { + t.Run("chat middleware wraps client call", func(t *testing.T) { + var order []string + chatMw := af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + order = append(order, "chat-before") + err := next(ctx, cc) + order = append(order, "chat-after") + return err + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, af.WithChatMiddleware(chatMw)) + resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Messages[0].Text()) + assert.Equal(t, []string{"chat-before", "chat-after"}, order) + }) + + t.Run("chat middleware can modify options before call", func(t *testing.T) { + chatMw := af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + temp := 0.42 + cc.Options.Temperature = &temp + return next(ctx, cc) + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, af.WithChatMiddleware(chatMw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.InDelta(t, 0.42, *mock.captured.opts.Temperature, 0.001) + }) + + t.Run("chat middleware can read response after call", func(t *testing.T) { + var capturedResponseID string + chatMw := af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + err := next(ctx, cc) + if cc.Response != nil { + capturedResponseID = cc.Response.ResponseID + } + return err + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + ResponseID: "resp-42", + }, + } + agent := af.NewAgent(mock, af.WithChatMiddleware(chatMw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, "resp-42", capturedResponseID) + }) + + t.Run("agent and chat middleware compose", func(t *testing.T) { + var order []string + agentMw := af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + order = append(order, "agent-before") + err := next(ctx, ac) + order = append(order, "agent-after") + return err + }) + chatMw := af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + order = append(order, "chat-before") + err := next(ctx, cc) + order = append(order, "chat-after") + return err + }) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + }, + } + agent := af.NewAgent(mock, + af.WithAgentMiddleware(agentMw), + af.WithChatMiddleware(chatMw), + ) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + assert.Equal(t, []string{"agent-before", "chat-before", "chat-after", "agent-after"}, order) + }) +} diff --git a/go/observability/logging.go b/go/observability/logging.go new file mode 100644 index 0000000000..7905673265 --- /dev/null +++ b/go/observability/logging.go @@ -0,0 +1,48 @@ +package observability + +import ( + "context" + "log/slog" + "time" + + af "github.com/microsoft/agent-framework/go" +) + +func NewLoggingAgentMiddleware(logger *slog.Logger) af.AgentMiddleware { + return af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + start := time.Now() + agentName := ac.Agent.Name() + msgCount := len(ac.Messages) + + logger.InfoContext(ctx, "agent run started", + slog.String("agent.name", agentName), + slog.Int("message.count", msgCount), + ) + + err := next(ctx, ac) + elapsed := time.Since(start) + + if err != nil { + logger.ErrorContext(ctx, "agent run failed", + slog.String("agent.name", agentName), + slog.Duration("duration", elapsed), + slog.String("error", err.Error()), + ) + return err + } + + attrs := []slog.Attr{ + slog.String("agent.name", agentName), + slog.Duration("duration", elapsed), + } + if ac.Response != nil && ac.Response.Usage != nil { + attrs = append(attrs, + slog.Int("usage.input_tokens", ac.Response.Usage.InputTokens), + slog.Int("usage.output_tokens", ac.Response.Usage.OutputTokens), + ) + } + logger.LogAttrs(ctx, slog.LevelInfo, "agent run completed", attrs...) + + return nil + }) +} diff --git a/go/observability/logging_test.go b/go/observability/logging_test.go new file mode 100644 index 0000000000..531a9df1d2 --- /dev/null +++ b/go/observability/logging_test.go @@ -0,0 +1,83 @@ +package observability_test + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/microsoft/agent-framework/go/observability" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func parseLogLines(t *testing.T, data []byte) []map[string]any { + t.Helper() + var lines []map[string]any + for _, line := range bytes.Split(data, []byte("\n")) { + if len(line) == 0 { + continue + } + var m map[string]any + require.NoError(t, json.Unmarshal(line, &m)) + lines = append(lines, m) + } + return lines +} + +func TestLoggingAgentMiddleware(t *testing.T) { + t.Run("logs agent run at info level", func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})) + + mw := observability.NewLoggingAgentMiddleware(logger) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + Usage: &af.UsageDetails{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, + } + agent := af.NewAgent(mock, af.WithName("TestBot"), af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + + lines := parseLogLines(t, buf.Bytes()) + require.GreaterOrEqual(t, len(lines), 1) + + found := false + for _, line := range lines { + if msg, ok := line["msg"].(string); ok && msg == "agent run completed" { + found = true + assert.Equal(t, "TestBot", line["agent.name"]) + assert.Equal(t, "INFO", line["level"]) + break + } + } + assert.True(t, found, "expected 'agent run completed' log line") + }) + + t.Run("logs error at error level", func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})) + + mw := observability.NewLoggingAgentMiddleware(logger) + mock := &mockChatClient{err: assert.AnError} + agent := af.NewAgent(mock, af.WithName("FailBot"), af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + assert.Error(t, err) + + lines := parseLogLines(t, buf.Bytes()) + found := false + for _, line := range lines { + if msg, ok := line["msg"].(string); ok && msg == "agent run failed" { + found = true + assert.Equal(t, "ERROR", line["level"]) + assert.Equal(t, "FailBot", line["agent.name"]) + break + } + } + assert.True(t, found, "expected 'agent run failed' log line") + }) +} diff --git a/go/observability/metrics.go b/go/observability/metrics.go new file mode 100644 index 0000000000..493ce0bbe7 --- /dev/null +++ b/go/observability/metrics.go @@ -0,0 +1,51 @@ +package observability + +import ( + "context" + "time" + + af "github.com/microsoft/agent-framework/go" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +func NewMetricsChatMiddleware(mp metric.MeterProvider) af.ChatMiddleware { + meter := mp.Meter("agentframework") + duration, _ := meter.Float64Histogram("gen_ai.client.operation.duration", + metric.WithUnit("s"), + metric.WithDescription("Duration of chat client operations"), + ) + tokenUsage, _ := meter.Int64Counter("gen_ai.client.token.usage", + metric.WithDescription("Token usage by type"), + ) + errorCount, _ := meter.Int64Counter("gen_ai.client.error.count", + metric.WithDescription("Count of chat client errors"), + ) + + return af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + start := time.Now() + var attrs []attribute.KeyValue + if cc.Options != nil && cc.Options.Model != "" { + attrs = append(attrs, attribute.String("gen_ai.request.model", cc.Options.Model)) + } + + err := next(ctx, cc) + elapsed := time.Since(start).Seconds() + duration.Record(ctx, elapsed, metric.WithAttributes(attrs...)) + + if err != nil { + errorCount.Add(ctx, 1, metric.WithAttributes(attrs...)) + return err + } + + if cc.Response != nil && cc.Response.Usage != nil { + tokenUsage.Add(ctx, int64(cc.Response.Usage.InputTokens), + metric.WithAttributes(append(attrs, attribute.String("gen_ai.token.type", "input"))...), + ) + tokenUsage.Add(ctx, int64(cc.Response.Usage.OutputTokens), + metric.WithAttributes(append(attrs, attribute.String("gen_ai.token.type", "output"))...), + ) + } + return nil + }) +} diff --git a/go/observability/metrics_test.go b/go/observability/metrics_test.go new file mode 100644 index 0000000000..2eac8f787e --- /dev/null +++ b/go/observability/metrics_test.go @@ -0,0 +1,74 @@ +package observability_test + +import ( + "context" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/microsoft/agent-framework/go/observability" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func setupMeter() (*sdkmetric.MeterProvider, *sdkmetric.ManualReader) { + reader := sdkmetric.NewManualReader() + mp := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + return mp, reader +} + +func collectMetrics(t *testing.T, reader *sdkmetric.ManualReader) metricdata.ResourceMetrics { + t.Helper() + var rm metricdata.ResourceMetrics + err := reader.Collect(context.Background(), &rm) + require.NoError(t, err) + return rm +} + +func findMetric(rm metricdata.ResourceMetrics, name string) *metricdata.Metrics { + for _, sm := range rm.ScopeMetrics { + for i := range sm.Metrics { + if sm.Metrics[i].Name == name { + return &sm.Metrics[i] + } + } + } + return nil +} + +func TestMetricsChatMiddleware(t *testing.T) { + t.Run("records duration and token usage", func(t *testing.T) { + mp, reader := setupMeter() + defer mp.Shutdown(context.Background()) + + mw := observability.NewMetricsChatMiddleware(mp) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + Usage: &af.UsageDetails{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, + } + agent := af.NewAgent(mock, af.WithChatMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + + rm := collectMetrics(t, reader) + assert.NotNil(t, findMetric(rm, "gen_ai.client.operation.duration"), "duration metric should be recorded") + assert.NotNil(t, findMetric(rm, "gen_ai.client.token.usage"), "token usage metric should be recorded") + }) + + t.Run("records error count on failure", func(t *testing.T) { + mp, reader := setupMeter() + defer mp.Shutdown(context.Background()) + + mw := observability.NewMetricsChatMiddleware(mp) + mock := &mockChatClient{err: assert.AnError} + agent := af.NewAgent(mock, af.WithChatMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + assert.Error(t, err) + + rm := collectMetrics(t, reader) + assert.NotNil(t, findMetric(rm, "gen_ai.client.error.count"), "error count metric should be recorded") + }) +} diff --git a/go/observability/tracing.go b/go/observability/tracing.go new file mode 100644 index 0000000000..f7eaf6cb50 --- /dev/null +++ b/go/observability/tracing.go @@ -0,0 +1,76 @@ +package observability + +import ( + "context" + + af "github.com/microsoft/agent-framework/go" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +func NewTracingAgentMiddleware(tp trace.TracerProvider) af.AgentMiddleware { + tracer := tp.Tracer("agentframework") + return af.AgentMiddlewareFunc(func(ctx context.Context, ac *af.AgentContext, next af.AgentHandler) error { + ctx, span := tracer.Start(ctx, "agent.run", + trace.WithAttributes( + attribute.String("gen_ai.agent.name", ac.Agent.Name()), + attribute.String("gen_ai.agent.id", ac.Agent.ID()), + attribute.Int("gen_ai.request.message_count", len(ac.Messages)), + ), + ) + defer span.End() + + err := next(ctx, ac) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + + if ac.Response != nil && ac.Response.Usage != nil { + span.SetAttributes( + attribute.Int("gen_ai.usage.input_tokens", ac.Response.Usage.InputTokens), + attribute.Int("gen_ai.usage.output_tokens", ac.Response.Usage.OutputTokens), + ) + } + return nil + }) +} + +func NewTracingChatMiddleware(tp trace.TracerProvider) af.ChatMiddleware { + tracer := tp.Tracer("agentframework") + return af.ChatMiddlewareFunc(func(ctx context.Context, cc *af.ChatContext, next af.ChatHandler) error { + ctx, span := tracer.Start(ctx, "chat.get_response", + trace.WithAttributes( + attribute.Int("gen_ai.request.message_count", len(cc.Messages)), + ), + ) + defer span.End() + + if cc.Options != nil && cc.Options.Model != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", cc.Options.Model)) + } + + err := next(ctx, cc) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + + if cc.Response != nil && cc.Response.Usage != nil { + span.SetAttributes( + attribute.Int("gen_ai.usage.input_tokens", cc.Response.Usage.InputTokens), + attribute.Int("gen_ai.usage.output_tokens", cc.Response.Usage.OutputTokens), + ) + } + if cc.Response != nil { + span.SetAttributes( + attribute.String("gen_ai.response.id", cc.Response.ResponseID), + attribute.Int("gen_ai.response.message_count", len(cc.Response.Messages)), + ) + } + return nil + }) +} diff --git a/go/observability/tracing_test.go b/go/observability/tracing_test.go new file mode 100644 index 0000000000..c320f5db5c --- /dev/null +++ b/go/observability/tracing_test.go @@ -0,0 +1,103 @@ +package observability_test + +import ( + "context" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/microsoft/agent-framework/go/observability" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +type mockChatClient struct { + response *af.ChatResponse + err error +} + +func (m *mockChatClient) GetResponse(_ context.Context, _ []af.Message, _ ...af.ChatOption) (*af.ChatResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func setupTracer() (*sdktrace.TracerProvider, *tracetest.InMemoryExporter) { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + return tp, exporter +} + +func spanAttrMap(span tracetest.SpanStub) map[string]string { + m := make(map[string]string) + for _, attr := range span.Attributes { + m[string(attr.Key)] = attr.Value.Emit() + } + return m +} + +func TestTracingAgentMiddleware(t *testing.T) { + t.Run("creates a span for agent run", func(t *testing.T) { + tp, exporter := setupTracer() + defer tp.Shutdown(context.Background()) + + mw := observability.NewTracingAgentMiddleware(tp) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + Usage: &af.UsageDetails{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, + } + agent := af.NewAgent(mock, af.WithName("TestBot"), af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + assert.Equal(t, "agent.run", spans[0].Name) + attrs := spanAttrMap(spans[0]) + assert.Equal(t, "TestBot", attrs["gen_ai.agent.name"]) + }) + + t.Run("records error on span when agent fails", func(t *testing.T) { + tp, exporter := setupTracer() + defer tp.Shutdown(context.Background()) + + mw := observability.NewTracingAgentMiddleware(tp) + mock := &mockChatClient{err: assert.AnError} + agent := af.NewAgent(mock, af.WithAgentMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + assert.Error(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + assert.NotEmpty(t, spans[0].Events) + }) +} + +func TestTracingChatMiddleware(t *testing.T) { + t.Run("creates a span for chat client call", func(t *testing.T) { + tp, exporter := setupTracer() + defer tp.Shutdown(context.Background()) + + mw := observability.NewTracingChatMiddleware(tp) + mock := &mockChatClient{ + response: &af.ChatResponse{ + Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "ok")}, + Usage: &af.UsageDetails{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, + } + agent := af.NewAgent(mock, af.WithChatMiddleware(mw)) + _, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("hi")}) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + assert.Equal(t, "chat.get_response", spans[0].Name) + attrs := spanAttrMap(spans[0]) + assert.Equal(t, "10", attrs["gen_ai.usage.input_tokens"]) + assert.Equal(t, "5", attrs["gen_ai.usage.output_tokens"]) + }) +} diff --git a/go/openai/client.go b/go/openai/client.go new file mode 100644 index 0000000000..df77302b26 --- /dev/null +++ b/go/openai/client.go @@ -0,0 +1,93 @@ +package openai + +import ( + "context" + "fmt" + + af "github.com/microsoft/agent-framework/go" + gogpt "github.com/sashabaranov/go-openai" +) + +// Client is an OpenAI-backed ChatClient. +type Client struct { + inner *gogpt.Client + model string +} + +// NewClient creates a new OpenAI ChatClient. +func NewClient(apiKey string, model string, opts ...ClientOption) *Client { + var cfg clientConfig + for _, opt := range opts { + opt(&cfg) + } + + clientCfg := gogpt.DefaultConfig(apiKey) + if cfg.baseURL != "" { + clientCfg.BaseURL = cfg.baseURL + } + + return &Client{ + inner: gogpt.NewClientWithConfig(clientCfg), + model: model, + } +} + +// GetResponse sends messages to OpenAI and returns the response. +func (c *Client) GetResponse(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) { + chatOpts := af.NewChatOptions(opts...) + + model := c.model + if chatOpts.Model != "" { + model = chatOpts.Model + } + + req := gogpt.ChatCompletionRequest{ + Model: model, + Messages: toOpenAIMessages(messages), + } + + if chatOpts.Temperature != nil { + req.Temperature = float32(*chatOpts.Temperature) + } + if chatOpts.MaxTokens != nil { + req.MaxTokens = *chatOpts.MaxTokens + } + + resp, err := c.inner.CreateChatCompletion(ctx, req) + if err != nil { + return nil, fmt.Errorf("openai: %w", err) + } + + return fromOpenAIResponse(resp), nil +} + +func toOpenAIMessages(messages []af.Message) []gogpt.ChatCompletionMessage { + out := make([]gogpt.ChatCompletionMessage, 0, len(messages)) + for _, m := range messages { + out = append(out, gogpt.ChatCompletionMessage{ + Role: string(m.Role), + Content: m.Text(), + }) + } + return out +} + +func fromOpenAIResponse(resp gogpt.ChatCompletionResponse) *af.ChatResponse { + messages := make([]af.Message, 0, len(resp.Choices)) + for _, choice := range resp.Choices { + messages = append(messages, af.NewTextMessage( + af.Role(choice.Message.Role), + choice.Message.Content, + )) + } + + return &af.ChatResponse{ + Messages: messages, + ResponseID: resp.ID, + Usage: &af.UsageDetails{ + InputTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + }, + } +} diff --git a/go/openai/client_test.go b/go/openai/client_test.go new file mode 100644 index 0000000000..c31059065f --- /dev/null +++ b/go/openai/client_test.go @@ -0,0 +1,159 @@ +package openai_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/microsoft/agent-framework/go/openai" + gogpt "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClientGetResponse(t *testing.T) { + t.Run("sends messages and returns response", func(t *testing.T) { + var capturedReq gogpt.ChatCompletionRequest + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedReq) + + resp := gogpt.ChatCompletionResponse{ + ID: "chatcmpl-123", + Choices: []gogpt.ChatCompletionChoice{ + { + Message: gogpt.ChatCompletionMessage{ + Role: "assistant", + Content: "Paris is the capital of France.", + }, + FinishReason: gogpt.FinishReasonStop, + }, + }, + Usage: gogpt.Usage{ + PromptTokens: 10, + CompletionTokens: 8, + TotalTokens: 18, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := openai.NewClient("test-key", "gpt-4o", + openai.WithBaseURL(server.URL+"/v1"), + ) + + messages := []af.Message{ + af.NewSystemMessage("You are helpful."), + af.NewUserMessage("What is the capital of France?"), + } + + resp, err := client.GetResponse(context.Background(), messages) + + require.NoError(t, err) + + // Verify request was mapped correctly + require.Len(t, capturedReq.Messages, 2) + assert.Equal(t, "system", capturedReq.Messages[0].Role) + assert.Equal(t, "You are helpful.", capturedReq.Messages[0].Content) + assert.Equal(t, "user", capturedReq.Messages[1].Role) + assert.Equal(t, "What is the capital of France?", capturedReq.Messages[1].Content) + assert.Equal(t, "gpt-4o", capturedReq.Model) + + // Verify response was mapped correctly + require.Len(t, resp.Messages, 1) + assert.Equal(t, af.RoleAssistant, resp.Messages[0].Role) + assert.Equal(t, "Paris is the capital of France.", resp.Messages[0].Text()) + assert.Equal(t, "chatcmpl-123", resp.ResponseID) + require.NotNil(t, resp.Usage) + assert.Equal(t, 10, resp.Usage.InputTokens) + assert.Equal(t, 8, resp.Usage.OutputTokens) + assert.Equal(t, 18, resp.Usage.TotalTokens) + }) + + t.Run("applies chat options", func(t *testing.T) { + var capturedReq gogpt.ChatCompletionRequest + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedReq) + + resp := gogpt.ChatCompletionResponse{ + Choices: []gogpt.ChatCompletionChoice{ + {Message: gogpt.ChatCompletionMessage{Role: "assistant", Content: "ok"}}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := openai.NewClient("test-key", "gpt-4o", + openai.WithBaseURL(server.URL+"/v1"), + ) + + _, err := client.GetResponse(context.Background(), + []af.Message{af.NewUserMessage("hi")}, + af.WithTemperature(0.5), + af.WithMaxTokens(100), + ) + + require.NoError(t, err) + assert.InDelta(t, 0.5, capturedReq.Temperature, 0.001) + assert.Equal(t, 100, capturedReq.MaxTokens) + }) + + t.Run("returns error on API failure", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":{"message":"server error"}}`)) + })) + defer server.Close() + + client := openai.NewClient("test-key", "gpt-4o", + openai.WithBaseURL(server.URL+"/v1"), + ) + + _, err := client.GetResponse(context.Background(), + []af.Message{af.NewUserMessage("hi")}, + ) + + assert.Error(t, err) + }) + + t.Run("model option overrides default model", func(t *testing.T) { + var capturedReq gogpt.ChatCompletionRequest + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedReq) + + resp := gogpt.ChatCompletionResponse{ + Choices: []gogpt.ChatCompletionChoice{ + {Message: gogpt.ChatCompletionMessage{Role: "assistant", Content: "ok"}}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := openai.NewClient("test-key", "gpt-4o", + openai.WithBaseURL(server.URL+"/v1"), + ) + + _, err := client.GetResponse(context.Background(), + []af.Message{af.NewUserMessage("hi")}, + af.WithModel("gpt-3.5-turbo"), + ) + + require.NoError(t, err) + assert.Equal(t, "gpt-3.5-turbo", capturedReq.Model) + }) +} diff --git a/go/openai/options.go b/go/openai/options.go new file mode 100644 index 0000000000..addf0455ab --- /dev/null +++ b/go/openai/options.go @@ -0,0 +1,15 @@ +package openai + +// ClientOption configures the OpenAI Client. +type ClientOption func(*clientConfig) + +type clientConfig struct { + baseURL string +} + +// WithBaseURL sets the base URL for the OpenAI API (useful for testing or proxies). +func WithBaseURL(url string) ClientOption { + return func(c *clientConfig) { + c.baseURL = url + } +} diff --git a/go/options.go b/go/options.go new file mode 100644 index 0000000000..dffbc93f8e --- /dev/null +++ b/go/options.go @@ -0,0 +1,66 @@ +package agentframework + +// ChatOptions holds configuration for a ChatClient.GetResponse call. +type ChatOptions struct { + Temperature *float64 + MaxTokens *int + Model string + Metadata map[string]any +} + +// ChatOption is a functional option for ChatOptions. +type ChatOption func(*ChatOptions) + +// NewChatOptions creates ChatOptions by applying the given options. +func NewChatOptions(opts ...ChatOption) ChatOptions { + var o ChatOptions + for _, opt := range opts { + opt(&o) + } + return o +} + +// WithTemperature sets the sampling temperature. +func WithTemperature(t float64) ChatOption { + return func(o *ChatOptions) { + o.Temperature = &t + } +} + +// WithMaxTokens sets the maximum number of tokens to generate. +func WithMaxTokens(n int) ChatOption { + return func(o *ChatOptions) { + o.MaxTokens = &n + } +} + +// WithModel sets the model name. +func WithModel(m string) ChatOption { + return func(o *ChatOptions) { + o.Model = m + } +} + +// RunOptions holds configuration for an Agent.Run call. +type RunOptions struct { + ChatOptions ChatOptions +} + +// RunOption is a functional option for RunOptions. +type RunOption func(*RunOptions) + +// NewRunOptions creates RunOptions by applying the given options. +func NewRunOptions(opts ...RunOption) RunOptions { + var o RunOptions + for _, opt := range opts { + opt(&o) + } + return o +} + +// WithChatOption wraps a ChatOption into a RunOption. +func WithChatOption(co ChatOption) RunOption { + return func(o *RunOptions) { + co(&o.ChatOptions) + } +} diff --git a/go/options_test.go b/go/options_test.go new file mode 100644 index 0000000000..b6473219f0 --- /dev/null +++ b/go/options_test.go @@ -0,0 +1,59 @@ +package agentframework_test + +import ( + "testing" + + af "github.com/microsoft/agent-framework/go" + "github.com/stretchr/testify/assert" +) + +func TestChatOptions(t *testing.T) { + t.Run("WithTemperature sets temperature", func(t *testing.T) { + opts := af.NewChatOptions(af.WithTemperature(0.7)) + assert.NotNil(t, opts.Temperature) + assert.InDelta(t, 0.7, *opts.Temperature, 0.001) + }) + + t.Run("WithMaxTokens sets max tokens", func(t *testing.T) { + opts := af.NewChatOptions(af.WithMaxTokens(100)) + assert.NotNil(t, opts.MaxTokens) + assert.Equal(t, 100, *opts.MaxTokens) + }) + + t.Run("WithModel sets model", func(t *testing.T) { + opts := af.NewChatOptions(af.WithModel("gpt-4o")) + assert.Equal(t, "gpt-4o", opts.Model) + }) + + t.Run("zero value has nil pointers", func(t *testing.T) { + opts := af.NewChatOptions() + assert.Nil(t, opts.Temperature) + assert.Nil(t, opts.MaxTokens) + assert.Empty(t, opts.Model) + }) + + t.Run("multiple options compose", func(t *testing.T) { + opts := af.NewChatOptions( + af.WithTemperature(0.5), + af.WithMaxTokens(200), + af.WithModel("gpt-4o"), + ) + assert.InDelta(t, 0.5, *opts.Temperature, 0.001) + assert.Equal(t, 200, *opts.MaxTokens) + assert.Equal(t, "gpt-4o", opts.Model) + }) +} + +func TestRunOptions(t *testing.T) { + t.Run("WithChatOption applies to inner ChatOptions", func(t *testing.T) { + opts := af.NewRunOptions(af.WithChatOption(af.WithTemperature(0.9))) + assert.NotNil(t, opts.ChatOptions.Temperature) + assert.InDelta(t, 0.9, *opts.ChatOptions.Temperature, 0.001) + }) + + t.Run("zero value has empty ChatOptions", func(t *testing.T) { + opts := af.NewRunOptions() + assert.Nil(t, opts.ChatOptions.Temperature) + assert.Nil(t, opts.ChatOptions.MaxTokens) + }) +} diff --git a/go/response.go b/go/response.go new file mode 100644 index 0000000000..8420fc1650 --- /dev/null +++ b/go/response.go @@ -0,0 +1,21 @@ +package agentframework + +// UsageDetails contains token usage information from a chat response. +type UsageDetails struct { + InputTokens int + OutputTokens int + TotalTokens int +} + +// ChatResponse is the result of a ChatClient.GetResponse call. +type ChatResponse struct { + Messages []Message + ResponseID string + Usage *UsageDetails +} + +// AgentResponse is the result of an Agent.Run call. +type AgentResponse struct { + ChatResponse + AgentID string +}