Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 135 additions & 27 deletions go/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package agentframework

import (
"context"
"fmt"
"maps"

"github.com/google/uuid"
Expand Down Expand Up @@ -32,6 +33,8 @@ type BaseAgent struct {
agentMiddleware []AgentMiddleware
chatMiddleware []ChatMiddleware
functionMiddleware []FunctionMiddleware
tools map[string]*FunctionTool
maxToolRounds int
}

// AgentOption configures a BaseAgent.
Expand Down Expand Up @@ -194,45 +197,131 @@ func (a *BaseAgent) runCore(ctx context.Context, ac *AgentContext) (*AgentRespon
}
})

resolvedOpts := NewChatOptions(chatOpts...)
if len(a.tools) > 0 {
toolDefs := make([]FunctionTool, 0, len(a.tools))
for _, t := range a.tools {
toolDefs = append(toolDefs, *t)
}
chatOpts = append(chatOpts, WithToolDefs(toolDefs...))
}

cc := &ChatContext{
Client: a.client,
Messages: fullMessages,
Options: &resolvedOpts,
Metadata: make(map[string]any),
maxRounds := a.maxToolRounds
if maxRounds <= 0 {
maxRounds = 10
}

chatTerminal := func(ctx context.Context, cc *ChatContext) error {
var opts []ChatOption
if cc.Options.Temperature != nil {
t := *cc.Options.Temperature
opts = append(opts, WithTemperature(t))
for range maxRounds {
resolvedOpts := NewChatOptions(chatOpts...)

cc := &ChatContext{
Client: a.client,
Messages: fullMessages,
Options: &resolvedOpts,
Metadata: make(map[string]any),
}

chatTerminal := func(ctx context.Context, cc *ChatContext) error {
var opts []ChatOption
if cc.Options.Temperature != nil {
t := *cc.Options.Temperature
opts = append(opts, WithTemperature(t))
}
if cc.Options.MaxTokens != nil {
n := *cc.Options.MaxTokens
opts = append(opts, WithMaxTokens(n))
}
if cc.Options.Model != "" {
opts = append(opts, WithModel(cc.Options.Model))
}
if len(cc.Options.Tools) > 0 {
opts = append(opts, WithToolDefs(cc.Options.Tools...))
}
resp, err := cc.Client.GetResponse(ctx, cc.Messages, opts...)
if err != nil {
return err
}
cc.Response = resp
return nil
}

chatHandler := buildChatChain(a.chatMiddleware, chatTerminal)
if err := chatHandler(ctx, cc); err != nil {
return nil, err
}
if cc.Options.MaxTokens != nil {
n := *cc.Options.MaxTokens
opts = append(opts, WithMaxTokens(n))

toolCalls := extractToolCalls(cc.Response)
if len(toolCalls) == 0 {
return &AgentResponse{
ChatResponse: *cc.Response,
AgentID: a.id,
}, nil
}
if cc.Options.Model != "" {
opts = append(opts, WithModel(cc.Options.Model))

fullMessages = append(fullMessages, cc.Response.Messages...)

for _, tc := range toolCalls {
result, err := a.invokeTool(ctx, tc)
var resultContent Content
if err != nil {
resultContent = Content{
Type: ContentTypeFunctionResult,
CallID: tc.CallID,
Result: "Error: " + err.Error(),
}
} else {
resultContent = Content{
Type: ContentTypeFunctionResult,
CallID: tc.CallID,
Result: result,
}
}
fullMessages = append(fullMessages, Message{
Role: RoleTool,
Contents: []Content{resultContent},
})
}
resp, err := cc.Client.GetResponse(ctx, cc.Messages, opts...)
if err != nil {
return err
}

return nil, ErrMaxToolRounds
}

func extractToolCalls(resp *ChatResponse) []Content {
var calls []Content
for _, msg := range resp.Messages {
for _, c := range msg.Contents {
if c.Type == ContentTypeFunctionCall {
calls = append(calls, c)
}
}
cc.Response = resp
}
return calls
}

func (a *BaseAgent) invokeTool(ctx context.Context, tc Content) (string, error) {
tool, ok := a.tools[tc.Name]
if !ok {
return "", fmt.Errorf("agentframework: unknown tool %q", tc.Name)
}

fc := &FunctionContext{
ToolName: tc.Name,
Arguments: tc.Arguments,
Metadata: make(map[string]any),
}

terminal := func(ctx context.Context, fc *FunctionContext) error {
result, err := tool.Invoke(ctx, fc.Arguments)
fc.Result = result
fc.Err = err
return nil
}

chatHandler := buildChatChain(a.chatMiddleware, chatTerminal)
if err := chatHandler(ctx, cc); err != nil {
return nil, err
handler := buildFunctionChain(a.functionMiddleware, terminal)
if err := handler(ctx, fc); err != nil {
return "", err
}

return &AgentResponse{
ChatResponse: *cc.Response,
AgentID: a.id,
}, nil
return fc.Result, fc.Err
}

// WithID sets the agent ID.
Expand Down Expand Up @@ -290,3 +379,22 @@ func WithFunctionMiddleware(mw ...FunctionMiddleware) AgentOption {
a.functionMiddleware = append(a.functionMiddleware, mw...)
}
}

// WithTools registers one or more FunctionTools on the agent.
func WithTools(tools ...*FunctionTool) AgentOption {
return func(a *BaseAgent) {
if a.tools == nil {
a.tools = make(map[string]*FunctionTool)
}
for _, t := range tools {
a.tools[t.Name] = t
}
}
}

// WithMaxToolRounds sets the maximum number of tool invocation rounds.
func WithMaxToolRounds(n int) AgentOption {
return func(a *BaseAgent) {
a.maxToolRounds = n
}
}
156 changes: 156 additions & 0 deletions go/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agentframework_test
import (
"context"
"errors"
"fmt"
"testing"

af "github.com/microsoft/agent-framework/go"
Expand Down Expand Up @@ -258,3 +259,158 @@ func TestBaseAgentRun(t *testing.T) {
assert.ErrorIs(t, err, af.ErrEmptyMessages)
})
}

type dynamicChatClient struct {
fn func(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error)
}

func (d *dynamicChatClient) GetResponse(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) {
return d.fn(ctx, messages, opts...)
}

func TestToolLoop(t *testing.T) {
t.Run("agent invokes tool and sends result back to LLM", func(t *testing.T) {
callCount := 0

type callCapture struct {
messages []af.Message
}
var calls []callCapture

dynamicMock := &dynamicChatClient{
fn: func(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) {
calls = append(calls, callCapture{messages: messages})
callCount++
if callCount == 1 {
return &af.ChatResponse{
Messages: []af.Message{
{
Role: af.RoleAssistant,
Contents: []af.Content{
{
Type: af.ContentTypeFunctionCall,
Name: "get_weather",
CallID: "call-1",
Arguments: map[string]any{"location": "Seattle"},
},
},
},
},
}, nil
}
return &af.ChatResponse{
Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "It's 72°F in Seattle")},
}, nil
},
}

weatherTool := &af.FunctionTool{
Name: "get_weather",
Description: "Get weather",
InputSchema: map[string]any{"type": "object"},
Invoke: func(ctx context.Context, args map[string]any) (string, error) {
return "72°F in " + args["location"].(string), nil
},
}

agent := af.NewAgent(dynamicMock, af.WithTools(weatherTool))

resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("weather in Seattle?")})
require.NoError(t, err)
assert.Equal(t, "It's 72°F in Seattle", resp.Messages[0].Text())
assert.Equal(t, 2, callCount, "LLM should be called twice")

secondCallMsgs := calls[1].messages
var foundToolResult bool
for _, msg := range secondCallMsgs {
if msg.Role == af.RoleTool {
for _, c := range msg.Contents {
if c.Type == af.ContentTypeFunctionResult {
foundToolResult = true
assert.Equal(t, "72°F in Seattle", c.Result)
assert.Equal(t, "call-1", c.CallID)
}
}
}
}
assert.True(t, foundToolResult, "expected tool result in second LLM call")
})

t.Run("max tool rounds exceeded returns error", func(t *testing.T) {
dynamicMock := &dynamicChatClient{
fn: func(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) {
return &af.ChatResponse{
Messages: []af.Message{
{
Role: af.RoleAssistant,
Contents: []af.Content{
{
Type: af.ContentTypeFunctionCall,
Name: "loop_tool",
CallID: "call-loop",
Arguments: map[string]any{},
},
},
},
},
}, nil
},
}

loopTool := &af.FunctionTool{
Name: "loop_tool",
Invoke: func(ctx context.Context, args map[string]any) (string, error) { return "again", nil },
}

agent := af.NewAgent(dynamicMock,
af.WithTools(loopTool),
af.WithMaxToolRounds(3),
)

_, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("go")})
assert.ErrorIs(t, err, af.ErrMaxToolRounds)
})

t.Run("tool invocation error sent as error content to LLM", func(t *testing.T) {
callCount := 0
dynamicMock := &dynamicChatClient{
fn: func(ctx context.Context, messages []af.Message, opts ...af.ChatOption) (*af.ChatResponse, error) {
callCount++
if callCount == 1 {
return &af.ChatResponse{
Messages: []af.Message{
{
Role: af.RoleAssistant,
Contents: []af.Content{
{
Type: af.ContentTypeFunctionCall,
Name: "failing_tool",
CallID: "call-fail",
Arguments: map[string]any{},
},
},
},
},
}, nil
}
return &af.ChatResponse{
Messages: []af.Message{af.NewTextMessage(af.RoleAssistant, "handled error")},
}, nil
},
}

failTool := &af.FunctionTool{
Name: "failing_tool",
Invoke: func(ctx context.Context, args map[string]any) (string, error) {
return "", fmt.Errorf("tool failed")
},
}

agent := af.NewAgent(dynamicMock, af.WithTools(failTool))

resp, err := agent.Run(context.Background(), []af.Message{af.NewUserMessage("go")})
require.NoError(t, err)
assert.Equal(t, "handled error", resp.Messages[0].Text())
assert.Equal(t, 2, callCount)
})
}
3 changes: 3 additions & 0 deletions go/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ var (
// ErrStreamingNotSupported is returned when RunStreaming is called but the
// underlying ChatClient does not implement StreamingChatClient.
ErrStreamingNotSupported = errors.New("agentframework: streaming not supported by chat client")

// ErrMaxToolRounds is returned when the tool invocation loop exceeds the maximum number of rounds.
ErrMaxToolRounds = errors.New("agentframework: maximum tool invocation rounds exceeded")
)
Loading