diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 3116a44a..aeb81d6f 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.3.6" + ".": "0.3.7" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a8b65ea7..bc87b568 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,20 @@ # Changelog -<<<<<<< HEAD -======= +## [0.3.7](https://github.com/a2aproject/a2a-go/compare/v0.3.6...v0.3.7) (2026-02-20) + + +### Features + +* implement tasks/list RPC ([#210](https://github.com/a2aproject/a2a-go/issues/210)) ([6e04698](https://github.com/a2aproject/a2a-go/commit/6e04698e63d4cb7d67a2aed700babd9bd6664b51)) +* retry cancelations ([#222](https://github.com/a2aproject/a2a-go/issues/222)) ([3057474](https://github.com/a2aproject/a2a-go/commit/30574743207ae68a96588fc81926ad7de3d9887e)) + + +### Bug Fixes + +* handle internal error: tck failure ([#186](https://github.com/a2aproject/a2a-go/issues/186)) ([b55fbfd](https://github.com/a2aproject/a2a-go/commit/b55fbfd6417fb48a89dd838611f35a6899d17e12)) +* **sse:** support data: prefix without space ([#188](https://github.com/a2aproject/a2a-go/issues/188)) ([6657a6d](https://github.com/a2aproject/a2a-go/commit/6657a6dc3b6872d425f03d6f340b7c1a82c55810)), closes [#162](https://github.com/a2aproject/a2a-go/issues/162) + + ## [0.3.6](https://github.com/a2aproject/a2a-go/compare/v0.3.5...v0.3.6) (2026-01-30) diff --git a/a2a/errors.go b/a2a/errors.go index d094fd84..be7a0f18 100644 --- a/a2a/errors.go +++ b/a2a/errors.go @@ -73,6 +73,9 @@ var ( // ErrUnauthorized indicates that the caller does not have permission to execute the specified operation. ErrUnauthorized = errors.New("permission denied") + + // ErrConcurrentTaskModification indicates that optimistic concurrency control failed during task update attempt. + ErrConcurrentTaskModification = errors.New("concurrent task modification") ) // Error provides control over the message and details returned to clients. diff --git a/a2aclient/agentcard/doc.go b/a2aclient/agentcard/doc.go index 7cf7a53d..49ec8d78 100644 --- a/a2aclient/agentcard/doc.go +++ b/a2aclient/agentcard/doc.go @@ -23,8 +23,7 @@ A [Resolver] can be created with a custom [http.Client] or a package-level Defau resolver := agentcard.NewResolver(customClient) card, err := resolver.Resolve(ctx, baseURL) -By default the request is sent for a well-known card location, but custom -this can be configured by providing [ResolveOption]s. +By default the request is sent for a well-known card location, but this can be customized by providing [ResolveOption]s. card, err := resolver.Resolve( ctx, diff --git a/a2aclient/doc.go b/a2aclient/doc.go index c3f64ef5..2fac2b4d 100644 --- a/a2aclient/doc.go +++ b/a2aclient/doc.go @@ -32,7 +32,7 @@ using either package-level functions or [Factory] methods. // or - card, err := agentcard.DefaultResolved.Resolve(ctx, url) + card, err := agentcard.DefaultResolver.Resolve(ctx, url) if err != nil { log.Fatalf("Failed to resolve an AgentCard: %v", err) } diff --git a/a2aclient/factory.go b/a2aclient/factory.go index a88e1ec6..5cf75de1 100644 --- a/a2aclient/factory.go +++ b/a2aclient/factory.go @@ -163,6 +163,11 @@ func createTransport(ctx context.Context, candidates []transportCandidate, card if len(failures) > 0 { log.Info(ctx, "some transports failed to connect", "failures", failures) } + + if selected.endpoint.Tenant != "" { + transport = &tenantTransportDecorator{base: transport, tenant: selected.endpoint.Tenant} + } + return transport, selected, nil } diff --git a/a2aclient/factory_test.go b/a2aclient/factory_test.go index 8eeab477..c41ab11b 100644 --- a/a2aclient/factory_test.go +++ b/a2aclient/factory_test.go @@ -248,3 +248,27 @@ func TestFactory_TransportSelection(t *testing.T) { }) } } + +func TestFactory_Tenant(t *testing.T) { + ctx := t.Context() + factory := NewFactory(WithTransport(a2a.TransportProtocolJSONRPC, TransportFactoryFn(func(ctx context.Context, card *a2a.AgentCard, iface *a2a.AgentInterface) (Transport, error) { + return unimplementedTransport{}, nil + }))) + iface := a2a.NewAgentInterface("https://agent.com", a2a.TransportProtocolJSONRPC) + iface.Tenant = "my-tenant" + + client, err := factory.CreateFromEndpoints(ctx, []*a2a.AgentInterface{iface}) + if err != nil { + t.Fatalf("CreateFromEndpoints() error = %v, want nil", err) + } + decorator, ok := client.transport.(*tenantTransportDecorator) + if !ok { + t.Fatalf("client.transport type = %T, want *tenantTransportDecorator", client.transport) + } + if decorator.tenant != "my-tenant" { + t.Errorf("decorator.tenant = %q, want %q", decorator.tenant, "my-tenant") + } + if _, ok := decorator.base.(unimplementedTransport); !ok { + t.Errorf("decorator.base type = %T, want unimplementedTransport", decorator.base) + } +} diff --git a/a2aclient/jsonrpc_test.go b/a2aclient/jsonrpc_test.go index 3b139f38..832d1932 100644 --- a/a2aclient/jsonrpc_test.go +++ b/a2aclient/jsonrpc_test.go @@ -644,3 +644,31 @@ func TestJSONRPCTransport_ErrorDetails(t *testing.T) { t.Errorf("got wrong details (+got,-want) diff = %s", diff) } } + +func TestJSONRPCTransport_Tenant(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req := mustDecodeJSONRPC(t, r, "ListTasks") + params, ok := req.Params.(map[string]any) + if !ok { + t.Fatalf("expected map[string]any params, got %T", req.Params) + } + if params["tenant"] != "my-tenant" { + t.Errorf("expected tenant my-tenant, got %v", params["tenant"]) + } + + resp := newResponse(req, json.RawMessage(`{"tasks":[]}`)) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + iface := a2a.NewAgentInterface(server.URL, a2a.TransportProtocolJSONRPC) + iface.Tenant = "my-tenant" + transport := NewJSONRPCTransport(iface.URL, nil) + // Apply decorator manually as we are bypassing the factory + transport = &tenantTransportDecorator{base: transport, tenant: iface.Tenant} + + _, err := transport.ListTasks(t.Context(), ServiceParams{}, &a2a.ListTasksRequest{}) + if err != nil { + t.Fatalf("ListTasks failed: %v", err) + } +} diff --git a/a2aclient/rest.go b/a2aclient/rest.go index 6272dc71..e791561d 100644 --- a/a2aclient/rest.go +++ b/a2aclient/rest.go @@ -33,7 +33,7 @@ import ( // RESTTransport implemetns Transport using RESTful HTTP API. type RESTTransport struct { - url string + url *url.URL httpClient *http.Client } @@ -41,9 +41,9 @@ type RESTTransport struct { // By default, an HTTP client with 5-second timeout is used. // For production deployments, provide a client with appropriate timeout, retry policy, // and connection pooling configured for your requirements. -func NewRESTTransport(tURL string, client *http.Client) Transport { +func NewRESTTransport(u *url.URL, client *http.Client) Transport { t := &RESTTransport{ - url: tURL, + url: u, httpClient: client, } @@ -60,29 +60,58 @@ func WithRESTTransport(client *http.Client) FactoryOption { return WithTransport( a2a.TransportProtocolHTTPJSON, TransportFactoryFn(func(ctx context.Context, card *a2a.AgentCard, iface *a2a.AgentInterface) (Transport, error) { - return NewRESTTransport(iface.URL, client), nil + u, err := url.Parse(iface.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse endpoint URL: %w", err) + } + return NewRESTTransport(u, client), nil }), ) } +type restRequest struct { + method string + params ServiceParams + path string + payload any + streaming bool + tenant string +} + // sendRequest prepares the HTTP request and sends it to the server. // It returns the HTTP response with the Body OPEN. // The caller is responsible for closing the response body. -func (t *RESTTransport) sendRequest(ctx context.Context, method string, params ServiceParams, path string, payload any, acceptHeader string) (*http.Response, error) { - reqBody, err := json.Marshal(payload) +func (t *RESTTransport) sendRequest(ctx context.Context, req *restRequest) (*http.Response, error) { + reqBody, err := json.Marshal(req.payload) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w: %w", err, a2a.ErrInvalidRequest) } - fullURL := t.url + path - httpReq, err := http.NewRequestWithContext(ctx, method, fullURL, bytes.NewBuffer(reqBody)) + rel, err := url.Parse(req.path) + if err != nil { + return nil, fmt.Errorf("failed to parse path: %w", err) + } + + u := t.url + if req.tenant != "" { + u = u.JoinPath(req.tenant, rel.Path) + } else { + u = u.JoinPath(rel.Path) + } + u.RawQuery = rel.RawQuery + fullURL := u.String() + httpReq, err := http.NewRequestWithContext(ctx, req.method, fullURL, bytes.NewBuffer(reqBody)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", acceptHeader) + if req.streaming { + httpReq.Header.Set("Accept", sse.ContentEventStream) + } else { + httpReq.Header.Set("Accept", "application/json") + } - for k, vals := range params { + for k, vals := range req.params { for _, v := range vals { httpReq.Header.Add(k, v) } @@ -106,11 +135,12 @@ func (t *RESTTransport) sendRequest(ctx context.Context, method string, params S } // doRequest is an adapter for Single Response calls -func (t *RESTTransport) doRequest(ctx context.Context, method string, params ServiceParams, path string, payload any, result any) error { - resp, err := t.sendRequest(ctx, method, params, path, payload, "application/json") +func (t *RESTTransport) doRequest(ctx context.Context, req *restRequest, result any) error { + resp, err := t.sendRequest(ctx, req) if err != nil { return err } + defer func() { if err := resp.Body.Close(); err != nil { log.Error(ctx, "failed to close http response body", err) @@ -126,9 +156,10 @@ func (t *RESTTransport) doRequest(ctx context.Context, method string, params Ser } // doStreamingRequest is an adapter for Streaming Response calls -func (t *RESTTransport) doStreamingRequest(ctx context.Context, method string, params ServiceParams, path string, payload any) iter.Seq2[a2a.Event, error] { +func (t *RESTTransport) doStreamingRequest(ctx context.Context, req *restRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - resp, err := t.sendRequest(ctx, method, params, path, payload, sse.ContentEventStream) + req.streaming = true + resp, err := t.sendRequest(ctx, req) if err != nil { yield(nil, err) return @@ -171,7 +202,13 @@ func (t *RESTTransport) GetTask(ctx context.Context, params ServiceParams, req * } var task a2a.Task - if err := t.doRequest(ctx, "GET", params, path, nil, &task); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "GET", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &task); err != nil { return nil, err } return &task, nil @@ -210,7 +247,13 @@ func (t *RESTTransport) ListTasks(ctx context.Context, params ServiceParams, req var result a2a.ListTasksResponse - if err := t.doRequest(ctx, "GET", params, path, nil, &result); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "GET", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &result); err != nil { return nil, err } return &result, nil @@ -221,7 +264,13 @@ func (t *RESTTransport) CancelTask(ctx context.Context, params ServiceParams, re path := rest.MakeCancelTaskPath(string(req.ID)) var result a2a.Task - if err := t.doRequest(ctx, "POST", params, path, nil, &result); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "POST", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &result); err != nil { return nil, err } return &result, nil @@ -232,7 +281,13 @@ func (t *RESTTransport) SendMessage(ctx context.Context, params ServiceParams, r path := rest.MakeSendMessagePath() var result json.RawMessage - if err := t.doRequest(ctx, "POST", params, path, req, &result); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "POST", + params: params, + path: path, + tenant: req.Tenant, + payload: req, + }, &result); err != nil { return nil, err } @@ -256,13 +311,25 @@ func (t *RESTTransport) SendMessage(ctx context.Context, params ServiceParams, r // SubscribeToTask implements [a2a.Transport]. func (t *RESTTransport) SubscribeToTask(ctx context.Context, params ServiceParams, req *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { path := rest.MakeSubscribeTaskPath(string(req.ID)) - return t.doStreamingRequest(ctx, "POST", params, path, nil) + return t.doStreamingRequest(ctx, &restRequest{ + method: "POST", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }) } // SendStreamingMessage implements [a2a.Transport]. func (t *RESTTransport) SendStreamingMessage(ctx context.Context, params ServiceParams, req *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { path := rest.MakeStreamMessagePath() - return t.doStreamingRequest(ctx, "POST", params, path, req) + return t.doStreamingRequest(ctx, &restRequest{ + method: "POST", + params: params, + path: path, + tenant: req.Tenant, + payload: req, + }) } // GetTaskPushConfig implements [a2a.Transport]. @@ -270,7 +337,13 @@ func (t *RESTTransport) GetTaskPushConfig(ctx context.Context, params ServicePar path := rest.MakeGetPushConfigPath(string(req.TaskID), string(req.ID)) var config a2a.TaskPushConfig - if err := t.doRequest(ctx, "GET", params, path, nil, &config); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "GET", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &config); err != nil { return nil, err } return &config, nil @@ -281,7 +354,13 @@ func (t *RESTTransport) ListTaskPushConfigs(ctx context.Context, params ServiceP path := rest.MakeListPushConfigsPath(string(req.TaskID)) var configs []*a2a.TaskPushConfig - if err := t.doRequest(ctx, "GET", params, path, nil, &configs); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "GET", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &configs); err != nil { return nil, err } return configs, nil @@ -292,7 +371,13 @@ func (t *RESTTransport) CreateTaskPushConfig(ctx context.Context, params Service path := rest.MakeCreatePushConfigPath(string(req.TaskID)) var config a2a.TaskPushConfig - if err := t.doRequest(ctx, "POST", params, path, req, &config); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "POST", + params: params, + path: path, + tenant: req.Tenant, + payload: req, + }, &config); err != nil { return nil, err } return &config, nil @@ -301,7 +386,13 @@ func (t *RESTTransport) CreateTaskPushConfig(ctx context.Context, params Service // DeleteTaskPushConfig implements [a2a.Transport]. func (t *RESTTransport) DeleteTaskPushConfig(ctx context.Context, params ServiceParams, req *a2a.DeleteTaskPushConfigRequest) error { path := rest.MakeDeletePushConfigPath(string(req.TaskID), string(req.ID)) - return t.doRequest(ctx, "DELETE", params, path, nil, nil) + return t.doRequest(ctx, &restRequest{ + method: "DELETE", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, nil) } // GetExtendedAgentCard implements [a2a.Transport]. @@ -309,7 +400,13 @@ func (t *RESTTransport) GetExtendedAgentCard(ctx context.Context, params Service path := rest.MakeGetExtendedAgentCardPath() var card a2a.AgentCard - if err := t.doRequest(ctx, "GET", params, path, nil, &card); err != nil { + if err := t.doRequest(ctx, &restRequest{ + method: "GET", + params: params, + path: path, + tenant: req.Tenant, + payload: nil, + }, &card); err != nil { return nil, err } return &card, nil diff --git a/a2aclient/rest_test.go b/a2aclient/rest_test.go index eb042cac..6125c2d6 100644 --- a/a2aclient/rest_test.go +++ b/a2aclient/rest_test.go @@ -17,6 +17,7 @@ package a2aclient import ( "net/http" "net/http/httptest" + "net/url" "testing" "github.com/a2aproject/a2a-go/v1/a2a" @@ -36,7 +37,7 @@ func TestRESTTransport_GetTask(t *testing.T) { _, _ = w.Write([]byte(`{"kind":"task","id":"task-123","contextId":"ctx-123","status":{"state":"COMPLETED"},"history":[{"state":"COMPLETED"},{"state":"WORKING"}]}`)) })) defer server.Close() - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) historyLength := 2 task, err := transport.GetTask(t.Context(), ServiceParams{}, &a2a.GetTaskRequest{ ID: "task-123", @@ -83,7 +84,7 @@ func TestRESTTransport_ListTasks(t *testing.T) { })) defer server.Close() - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) wantResult := &a2a.ListTasksResponse{ Tasks: []*a2a.Task{ { @@ -128,7 +129,7 @@ func TestRESTTransport_CancelTask(t *testing.T) { _, _ = w.Write([]byte(`{"kind":"task","id":"task-123","contextId":"ctx-123","status":{"state":"CANCELED"}}`)) })) defer server.Close() - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) task, err := transport.CancelTask(t.Context(), ServiceParams{}, &a2a.CancelTaskRequest{ ID: "task-123", @@ -157,7 +158,7 @@ func TestRESTTransport_SendMessage(t *testing.T) { _, _ = w.Write([]byte(`{"task":{"id":"task-123","contextId":"ctx-123","status":{"state":"SUBMITTED"}}}`)) })) defer server.Close() - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) result, err := transport.SendMessage(t.Context(), ServiceParams{}, &a2a.SendMessageRequest{ Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("test message")), @@ -204,8 +205,7 @@ func TestRESTTransport_ResubscribeToTask(t *testing.T) { } })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) events := []a2a.Event{} for event, err := range transport.SubscribeToTask(t.Context(), ServiceParams{}, &a2a.SubscribeToTaskRequest{ @@ -258,8 +258,7 @@ func TestRESTTransport_SendStreamingMessage(t *testing.T) { } })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) events := []a2a.Event{} for event, err := range transport.SendStreamingMessage(t.Context(), ServiceParams{}, &a2a.SendMessageRequest{ @@ -299,8 +298,7 @@ func TestRESTTransport_GetTaskPushConfig(t *testing.T) { _, _ = w.Write([]byte(`{"taskId":"task-123","config":{"id":"config-123","url":"https://webhook.example.com"}}`)) })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) config, err := transport.GetTaskPushConfig(t.Context(), ServiceParams{}, &a2a.GetTaskPushConfigRequest{ TaskID: a2a.TaskID("task-123"), @@ -338,8 +336,7 @@ func TestRESTTransport_ListTaskPushConfigs(t *testing.T) { ]`)) })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) configs, err := transport.ListTaskPushConfigs(t.Context(), ServiceParams{}, &a2a.ListTaskPushConfigRequest{ TaskID: a2a.TaskID("task-123"), @@ -372,8 +369,7 @@ func TestRESTTransport_SetTaskPushConfig(t *testing.T) { _, _ = w.Write([]byte(`{"taskId":"task-123","config":{"id":"config-123","url":"https://webhook.example.com"}}`)) })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) config, err := transport.CreateTaskPushConfig(t.Context(), ServiceParams{}, &a2a.CreateTaskPushConfigRequest{ TaskID: "task-123", @@ -407,8 +403,7 @@ func TestRESTTransport_DeleteTaskPushConfig(t *testing.T) { } })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) err := transport.DeleteTaskPushConfig(t.Context(), ServiceParams{}, &a2a.DeleteTaskPushConfigRequest{ TaskID: a2a.TaskID("task-123"), @@ -433,8 +428,7 @@ func TestRESTTransport_GetAgentCard(t *testing.T) { _, _ = w.Write([]byte(`{"supportedInterfaces":[{"url":"http://example.com"}], "name": "Test agent", "description":"test"}`)) })) defer server.Close() - - transport := NewRESTTransport(server.URL, server.Client()) + transport := newRESTTransport(t, server) card, err := transport.GetExtendedAgentCard(t.Context(), ServiceParams{}, &a2a.GetExtendedAgentCardRequest{}) if err != nil { @@ -448,3 +442,33 @@ func TestRESTTransport_GetAgentCard(t *testing.T) { t.Errorf("got card Description %s, want test", card.Description) } } + +func TestRESTTransport_Tenant(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/my-tenant/tasks" { + t.Errorf("expected path /my-tenant/tasks, got %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"tasks":[]}`)) + })) + defer server.Close() + transport := newRESTTransport(t, server) + transport = &tenantTransportDecorator{ + base: transport, + tenant: "my-tenant", + } + + _, err := transport.ListTasks(t.Context(), ServiceParams{}, &a2a.ListTasksRequest{}) + if err != nil { + t.Fatalf("ListTasks failed: %v", err) + } +} + +func newRESTTransport(t *testing.T, server *httptest.Server) Transport { + t.Helper() + u, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("url.Parse(%q) error = %v", server.URL, err) + } + return NewRESTTransport(u, server.Client()) +} diff --git a/a2aclient/transport.go b/a2aclient/transport.go index 2b06c8f3..2f640f15 100644 --- a/a2aclient/transport.go +++ b/a2aclient/transport.go @@ -133,3 +133,76 @@ func (unimplementedTransport) GetExtendedAgentCard(ctx context.Context, params S func (unimplementedTransport) Destroy() error { return nil } + +type tenantTransportDecorator struct { + base Transport + tenant string +} + +var _ Transport = (*tenantTransportDecorator)(nil) + +func (d *tenantTransportDecorator) updateTenant(current string) string { + if current != "" { + return current + } + return d.tenant +} + +func (d *tenantTransportDecorator) GetTask(ctx context.Context, params ServiceParams, req *a2a.GetTaskRequest) (*a2a.Task, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.GetTask(ctx, params, req) +} + +func (d *tenantTransportDecorator) ListTasks(ctx context.Context, params ServiceParams, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.ListTasks(ctx, params, req) +} + +func (d *tenantTransportDecorator) CancelTask(ctx context.Context, params ServiceParams, req *a2a.CancelTaskRequest) (*a2a.Task, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.CancelTask(ctx, params, req) +} + +func (d *tenantTransportDecorator) SendMessage(ctx context.Context, params ServiceParams, req *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.SendMessage(ctx, params, req) +} + +func (d *tenantTransportDecorator) SubscribeToTask(ctx context.Context, params ServiceParams, req *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.SubscribeToTask(ctx, params, req) +} + +func (d *tenantTransportDecorator) SendStreamingMessage(ctx context.Context, params ServiceParams, req *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.SendStreamingMessage(ctx, params, req) +} + +func (d *tenantTransportDecorator) GetTaskPushConfig(ctx context.Context, params ServiceParams, req *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.GetTaskPushConfig(ctx, params, req) +} + +func (d *tenantTransportDecorator) ListTaskPushConfigs(ctx context.Context, params ServiceParams, req *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.ListTaskPushConfigs(ctx, params, req) +} + +func (d *tenantTransportDecorator) CreateTaskPushConfig(ctx context.Context, params ServiceParams, req *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.CreateTaskPushConfig(ctx, params, req) +} + +func (d *tenantTransportDecorator) DeleteTaskPushConfig(ctx context.Context, params ServiceParams, req *a2a.DeleteTaskPushConfigRequest) error { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.DeleteTaskPushConfig(ctx, params, req) +} + +func (d *tenantTransportDecorator) GetExtendedAgentCard(ctx context.Context, params ServiceParams, req *a2a.GetExtendedAgentCardRequest) (*a2a.AgentCard, error) { + req.Tenant = d.updateTenant(req.Tenant) + return d.base.GetExtendedAgentCard(ctx, params, req) +} + +func (d *tenantTransportDecorator) Destroy() error { + return d.base.Destroy() +} diff --git a/a2asrv/agentexec.go b/a2asrv/agentexec.go index e414c27d..926dc87b 100644 --- a/a2asrv/agentexec.go +++ b/a2asrv/agentexec.go @@ -138,6 +138,7 @@ func (f *factory) CreateExecutor(ctx context.Context, tid a2a.TaskID, params *a2 if callCtx, ok := CallContextFrom(ctx); ok { execCtx.ctx.User = callCtx.User execCtx.ctx.ServiceParams = callCtx.ServiceParams() + execCtx.ctx.Tenant = callCtx.Tenant() } if params.Config != nil && params.Config.PushConfig != nil { @@ -215,6 +216,7 @@ func (f *factory) loadExecutionContext(ctx context.Context, tid a2a.TaskID, para TaskID: storedTask.ID, ContextID: storedTask.ContextID, Metadata: params.Metadata, + Tenant: params.Tenant, }, task: &taskstore.StoredTask{ Task: storedTask, @@ -259,6 +261,7 @@ func (f *factory) CreateCanceler(ctx context.Context, params *a2a.CancelTaskRequ if callCtx, ok := CallContextFrom(ctx); ok { execCtx.User = callCtx.User execCtx.ServiceParams = callCtx.ServiceParams() + execCtx.Tenant = callCtx.Tenant() } canceler := &canceler{agent: f.agent, execCtx: execCtx, task: task, interceptors: f.interceptors} diff --git a/a2asrv/handler.go b/a2asrv/handler.go index 8aaaeb88..c11403f2 100644 --- a/a2asrv/handler.go +++ b/a2asrv/handler.go @@ -249,7 +249,7 @@ func (h *defaultRequestHandler) ListTasks(ctx context.Context, req *a2a.ListTask // CancelTask implements RequestHandler. func (h *defaultRequestHandler) CancelTask(ctx context.Context, req *a2a.CancelTaskRequest) (*a2a.Task, error) { - if req == nil { + if req == nil || req.ID == "" { return nil, a2a.ErrInvalidParams } diff --git a/a2asrv/handler_test.go b/a2asrv/handler_test.go index 635a79ef..88d734f6 100644 --- a/a2asrv/handler_test.go +++ b/a2asrv/handler_test.go @@ -1347,7 +1347,7 @@ func TestRequestHandler_CancelTask(t *testing.T) { }, { name: "nil params", - params: nil, + params: &a2a.CancelTaskRequest{}, wantErr: a2a.ErrInvalidParams, }, { diff --git a/a2asrv/intercepted_handler.go b/a2asrv/intercepted_handler.go index 5dd2f76f..45af9e35 100644 --- a/a2asrv/intercepted_handler.go +++ b/a2asrv/intercepted_handler.go @@ -51,36 +51,29 @@ var _ RequestHandler = (*InterceptedHandler)(nil) // GetTask implements RequestHandler. func (h *InterceptedHandler) GetTask(ctx context.Context, req *a2a.GetTaskRequest) (*a2a.Task, error) { - ctx, callCtx := attachMethodCallContext(ctx, "GetTask") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "GetTask", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) return doCall(ctx, callCtx, h, req, h.Handler.GetTask) } // ListTasks implements RequestHandler. func (h *InterceptedHandler) ListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { - ctx, callCtx := attachMethodCallContext(ctx, "ListTasks") - if req != nil { - ctx = h.withLoggerContext(ctx) - } + ctx, callCtx := attachMethodCallContext(ctx, "ListTasks", req.Tenant) + ctx = h.withLoggerContext(ctx) return doCall(ctx, callCtx, h, req, h.Handler.ListTasks) } // CancelTask implements RequestHandler. func (h *InterceptedHandler) CancelTask(ctx context.Context, req *a2a.CancelTaskRequest) (*a2a.Task, error) { - ctx, callCtx := attachMethodCallContext(ctx, "CancelTask") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "CancelTask", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) return doCall(ctx, callCtx, h, req, h.Handler.CancelTask) } // SendMessage implements RequestHandler. func (h *InterceptedHandler) SendMessage(ctx context.Context, req *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { - ctx, callCtx := attachMethodCallContext(ctx, "SendMessage") - if req != nil && req.Message != nil { - msg := req.Message + ctx, callCtx := attachMethodCallContext(ctx, "SendMessage", req.Tenant) + if msg := req.Message; msg != nil { ctx = h.withLoggerContext( ctx, slog.String("message_id", msg.ID), @@ -96,12 +89,12 @@ func (h *InterceptedHandler) SendMessage(ctx context.Context, req *a2a.SendMessa // SendStreamingMessage implements RequestHandler. func (h *InterceptedHandler) SendStreamingMessage(ctx context.Context, req *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - ctx, callCtx := attachMethodCallContext(ctx, "SendStreamingMessage") + ctx, callCtx := attachMethodCallContext(ctx, "SendStreamingMessage", req.Tenant) if err := checkRequiredExtensions(h, callCtx); err != nil { yield(nil, err) return } - if req != nil && req.Message != nil { + if req.Message != nil { msg := req.Message ctx = h.withLoggerContext( ctx, @@ -137,14 +130,12 @@ func (h *InterceptedHandler) SendStreamingMessage(ctx context.Context, req *a2a. // SubscribeToTask implements RequestHandler. func (h *InterceptedHandler) SubscribeToTask(ctx context.Context, req *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - ctx, callCtx := attachMethodCallContext(ctx, "SubscribeToTask") + ctx, callCtx := attachMethodCallContext(ctx, "SubscribeToTask", req.Tenant) if err := checkRequiredExtensions(h, callCtx); err != nil { yield(nil, err) return } - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) - } + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) ctx, res := interceptBefore[*a2a.SubscribeToTaskRequest, a2a.SendMessageResult](ctx, h, callCtx, req) if res.earlyErr != nil { yield(nil, res.earlyErr) @@ -169,37 +160,29 @@ func (h *InterceptedHandler) SubscribeToTask(ctx context.Context, req *a2a.Subsc // GetTaskPushConfig implements RequestHandler. func (h *InterceptedHandler) GetTaskPushConfig(ctx context.Context, req *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { - ctx, callCtx := attachMethodCallContext(ctx, "GetTaskPushConfig") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "GetTaskPushConfig", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) return doCall(ctx, callCtx, h, req, h.Handler.GetTaskPushConfig) } // ListTaskPushConfigs implements RequestHandler. func (h *InterceptedHandler) ListTaskPushConfigs(ctx context.Context, req *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) { - ctx, callCtx := attachMethodCallContext(ctx, "ListTaskPushConfigs") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "ListTaskPushConfigs", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) return doCall(ctx, callCtx, h, req, h.Handler.ListTaskPushConfigs) } // CreateTaskPushConfig implements RequestHandler. func (h *InterceptedHandler) CreateTaskPushConfig(ctx context.Context, req *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { - ctx, callCtx := attachMethodCallContext(ctx, "CreateTaskPushConfig") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "CreateTaskPushConfig", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) return doCall(ctx, callCtx, h, req, h.Handler.CreateTaskPushConfig) } // DeleteTaskPushConfig implements RequestHandler. func (h *InterceptedHandler) DeleteTaskPushConfig(ctx context.Context, req *a2a.DeleteTaskPushConfigRequest) error { - ctx, callCtx := attachMethodCallContext(ctx, "DeleteTaskPushConfig") - if req != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) - } + ctx, callCtx := attachMethodCallContext(ctx, "DeleteTaskPushConfig", req.Tenant) + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) ctx, res := interceptBefore[*a2a.DeleteTaskPushConfigRequest, struct{}](ctx, h, callCtx, req) if res.earlyErr != nil { return res.earlyErr @@ -218,7 +201,7 @@ func (h *InterceptedHandler) DeleteTaskPushConfig(ctx context.Context, req *a2a. // GetExtendedAgentCard implements RequestHandler. func (h *InterceptedHandler) GetExtendedAgentCard(ctx context.Context, req *a2a.GetExtendedAgentCardRequest) (*a2a.AgentCard, error) { - ctx, callCtx := attachMethodCallContext(ctx, "GetExtendedAgentCard") + ctx, callCtx := attachMethodCallContext(ctx, "GetExtendedAgentCard", req.Tenant) ctx = h.withLoggerContext(ctx) return doCall(ctx, callCtx, h, req, h.Handler.GetExtendedAgentCard) } @@ -304,12 +287,16 @@ func (h *InterceptedHandler) withLoggerContext(ctx context.Context, attrs ...any // attachMethodCallContext is a private utility function which modifies CallContext.method if a CallContext // was passed by a transport implementation or initializes a new CallContext with the provided method. -func attachMethodCallContext(ctx context.Context, method string) (context.Context, *CallContext) { +func attachMethodCallContext(ctx context.Context, method string, tenant string) (context.Context, *CallContext) { callCtx, ok := CallContextFrom(ctx) if !ok { ctx, callCtx = NewCallContext(ctx, nil) } + callCtx.method = method + if tenant != "" { + callCtx.tenant = tenant + } return ctx, callCtx } diff --git a/a2asrv/middleware.go b/a2asrv/middleware.go index 7edb4e19..ffc9bffa 100644 --- a/a2asrv/middleware.go +++ b/a2asrv/middleware.go @@ -47,6 +47,9 @@ type CallContext struct { // User can be set by authentication middleware to provide information about // the user who initiated the request. User *User + + // tenant is an optional ID of the agent owner. + tenant string } // Method returns the name of the RequestHandler method which is being executed. @@ -59,6 +62,11 @@ func (cc *CallContext) ServiceParams() *ServiceParams { return cc.svcParams } +// Tenant returns the tenant ID of the current call context. +func (cc *CallContext) Tenant() string { + return cc.tenant +} + // Extensions returns a struct which provides an API for working with extensions in the current call context. func (cc *CallContext) Extensions() *Extensions { return &Extensions{callCtx: cc} diff --git a/a2asrv/reqctx.go b/a2asrv/reqctx.go index 3ffe4bc7..22974883 100644 --- a/a2asrv/reqctx.go +++ b/a2asrv/reqctx.go @@ -54,6 +54,8 @@ type ExecutorContext struct { User *User // ServiceParams of the request which triggered the execution. ServiceParams *ServiceParams + // Tenant is an optional ID of the agent owner. + Tenant string } var _ a2a.TaskInfoProvider = (*ExecutorContext)(nil) diff --git a/a2asrv/rest.go b/a2asrv/rest.go index 9018acee..c4f21e5d 100644 --- a/a2asrv/rest.go +++ b/a2asrv/rest.go @@ -17,13 +17,16 @@ package a2asrv import ( "context" "encoding/json" + "fmt" "iter" "net/http" + "net/url" "strconv" "strings" "time" "github.com/a2aproject/a2a-go/v1/a2a" + "github.com/a2aproject/a2a-go/v1/internal/pathtemplate" "github.com/a2aproject/a2a-go/v1/internal/rest" "github.com/a2aproject/a2a-go/v1/internal/sse" "github.com/a2aproject/a2a-go/v1/log" @@ -48,6 +51,37 @@ func NewRESTHandler(handler RequestHandler) http.Handler { return mux } +// NewTenantRESTHandler creates an [http.Handler] which implements the HTTP+JSON A2A protocol binding. +// It extracts tenant information from the URL path based on the provided template, strips the prefix, +// and attaches the tenant ID (part inside {}) to the request context. +// Examples of templates: +// - "/{*}" +// - "/locations/*/projects/{*}" +// - "/{locations/*/projects/*}" +func NewTenantRESTHandler(tenantTemplate string, handler RequestHandler) http.Handler { + compiledTemplate, err := pathtemplate.New(tenantTemplate) + if err != nil { + panic(fmt.Errorf("invalid template: %w", err)) + } + restHandler := NewRESTHandler(handler) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + matchResult, ok := compiledTemplate.Match(r.URL.Path) + if !ok { + http.NotFound(w, r) + return + } + + r2 := new(http.Request) + *r2 = *r + r2 = r2.WithContext(attachTenant(r.Context(), matchResult.Captured)) + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = matchResult.Rest + r2.URL.RawPath = "" + restHandler.ServeHTTP(w, r2) + }) +} + func handleSendMessage(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() @@ -56,6 +90,7 @@ func handleSendMessage(handler RequestHandler) http.HandlerFunc { writeRESTError(ctx, rw, a2a.ErrParseError, a2a.TaskID("")) return } + fillTenant(ctx, &message.Tenant) result, err := handler.SendMessage(ctx, &message) @@ -78,6 +113,7 @@ func handleStreamMessage(handler RequestHandler) http.HandlerFunc { writeRESTError(ctx, rw, a2a.ErrParseError, a2a.TaskID("")) return } + fillTenant(ctx, &message.Tenant) handleStreamingRequest(handler.SendStreamingMessage(ctx, &message), rw, req) } } @@ -104,6 +140,7 @@ func handleGetTask(handler RequestHandler) http.HandlerFunc { ID: a2a.TaskID(taskID), HistoryLength: historyLength, } + fillTenant(ctx, ¶ms.Tenant) result, err := handler.GetTask(ctx, params) if err != nil { @@ -150,6 +187,7 @@ func handleListTasks(handler RequestHandler) http.HandlerFunc { parse("historyLength", &request.HistoryLength) parse("statusTimestampAfter", &request.StatusTimestampAfter) parse("includeArtifacts", &request.IncludeArtifacts) + fillTenant(ctx, &request.Tenant) if err != nil { writeRESTError(ctx, rw, a2a.ErrInvalidRequest, a2a.TaskID("")) return @@ -179,7 +217,9 @@ func handlePOSTTasks(handler RequestHandler) http.HandlerFunc { handleCancelTask(handler, taskID, rw, req) } else if before, ok := strings.CutSuffix(idAndAction, ":subscribe"); ok { taskID := before - handleStreamingRequest(handler.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: a2a.TaskID(taskID)}), rw, req) + req2 := &a2a.SubscribeToTaskRequest{ID: a2a.TaskID(taskID)} + fillTenant(ctx, &req2.Tenant) + handleStreamingRequest(handler.SubscribeToTask(ctx, req2), rw, req) } else { writeRESTError(ctx, rw, a2a.ErrInvalidRequest, a2a.TaskID("")) return @@ -193,6 +233,7 @@ func handleCancelTask(handler RequestHandler, taskID string, rw http.ResponseWri id := &a2a.CancelTaskRequest{ ID: a2a.TaskID(taskID), } + fillTenant(ctx, &id.Tenant) result, err := handler.CancelTask(ctx, id) @@ -290,6 +331,7 @@ func handleCreateTaskPushConfig(handler RequestHandler) http.HandlerFunc { TaskID: a2a.TaskID(taskID), Config: *config, } + fillTenant(ctx, ¶ms.Tenant) result, err := handler.CreateTaskPushConfig(ctx, params) @@ -319,6 +361,7 @@ func handleGetTaskPushConfig(handler RequestHandler) http.HandlerFunc { TaskID: a2a.TaskID(taskID), ID: configID, } + fillTenant(ctx, ¶ms.Tenant) result, err := handler.GetTaskPushConfig(ctx, params) @@ -346,6 +389,7 @@ func handleListTaskPushConfigs(handler RequestHandler) http.HandlerFunc { params := &a2a.ListTaskPushConfigRequest{ TaskID: a2a.TaskID(taskID), } + fillTenant(ctx, ¶ms.Tenant) result, err := handler.ListTaskPushConfigs(ctx, params) @@ -374,6 +418,7 @@ func handleDeleteTaskPushConfig(handler RequestHandler) http.HandlerFunc { TaskID: a2a.TaskID(taskID), ID: configID, } + fillTenant(ctx, ¶ms.Tenant) err := handler.DeleteTaskPushConfig(ctx, params) @@ -388,7 +433,9 @@ func handleGetExtendedAgentCard(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() // TODO: extract tenant from path - result, err := handler.GetExtendedAgentCard(ctx, &a2a.GetExtendedAgentCardRequest{}) + req2 := &a2a.GetExtendedAgentCardRequest{} + fillTenant(ctx, &req2.Tenant) + result, err := handler.GetExtendedAgentCard(ctx, req2) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID("")) @@ -410,3 +457,22 @@ func writeRESTError(ctx context.Context, rw http.ResponseWriter, err error, task log.Error(ctx, "failed to write error response", err) } } + +type tenantKeyType struct{} + +func fillTenant(ctx context.Context, tenant *string) { + if t := tenantFromContext(ctx); t != "" { + *tenant = t + } +} + +func attachTenant(parent context.Context, tenant string) context.Context { + return context.WithValue(parent, tenantKeyType{}, tenant) +} + +func tenantFromContext(ctx context.Context) string { + if tenant, ok := ctx.Value(tenantKeyType{}).(string); ok { + return tenant + } + return "" +} diff --git a/a2asrv/rest_test.go b/a2asrv/rest_test.go index 7dde6518..0cef7cd6 100644 --- a/a2asrv/rest_test.go +++ b/a2asrv/rest_test.go @@ -108,40 +108,58 @@ func TestREST_RequestRouting(t *testing.T) { } ctx := t.Context() - lastCalledMethod := make(chan string, 1) + lastCallCtx := make(chan *CallContext, 1) interceptor := &mockInterceptor{ beforeFn: func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { - lastCalledMethod <- callCtx.Method() + lastCallCtx <- callCtx return ctx, nil, nil }, } + reqHandler := NewHandler( &mockAgentExecutor{}, WithCallInterceptors(interceptor), WithExtendedAgentCard(&a2a.AgentCard{}), ) - server := httptest.NewServer(NewRESTHandler(reqHandler)) - - client, err := a2aclient.NewFromEndpoints(ctx, []*a2a.AgentInterface{ - a2a.NewAgentInterface(server.URL, a2a.TransportProtocolHTTPJSON), - }) - if err != nil { - t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) - } + for _, tenant := range []string{"", "my-tenant"} { + var transport http.Handler + if tenant == "" { + transport = NewRESTHandler(reqHandler) + } else { + transport = NewTenantRESTHandler("/{*}", reqHandler) + } + server := httptest.NewServer(transport) + t.Cleanup(server.Close) - for _, tc := range testCases { - t.Run(tc.method, func(t *testing.T) { - _, _ = tc.call(ctx, client) - select { - case calledMethod := <-lastCalledMethod: - if calledMethod != tc.method { - t.Fatalf("wrong method called: got %q, want %q", calledMethod, tc.method) - } - case <-time.After(5 * time.Second): - t.Fatalf("Routing failed") + for _, tc := range testCases { + name := tc.method + if tenant != "" { + name += " (with tenant)" } - }) + t.Run(name, func(t *testing.T) { + iface := a2a.NewAgentInterface(server.URL, a2a.TransportProtocolHTTPJSON) + if tenant != "" { + iface.Tenant = tenant + } + client, err := a2aclient.NewFromEndpoints(ctx, []*a2a.AgentInterface{iface}) + if err != nil { + t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) + } + _, _ = tc.call(ctx, client) + select { + case callCtx := <-lastCallCtx: + if callCtx.Tenant() != tenant { + t.Fatalf("callCtx.Tenant() = %q, want %q", callCtx.Tenant(), tenant) + } + if callCtx.Method() != tc.method { + t.Fatalf("callCtx.Method() = %q, want %q", callCtx.Method(), tc.method) + } + case <-time.After(1 * time.Second): + t.Fatalf("Routing failed") + } + }) + } } } @@ -163,7 +181,7 @@ func TestREST_Validations(t *testing.T) { name string methods []string path string - body interface{} + body any }{ { name: "SendMessage", @@ -349,3 +367,104 @@ func TestREST_InvalidPayloads(t *testing.T) { }) } } + +func TestRESTTenant(t *testing.T) { + tid := a2a.NewTaskID() + tests := []struct { + name string + template string + path string + wantTenant string + wantErr bool + }{ + { + name: "simple", + template: "/{*}", + path: "/my-tenant/tasks/" + string(tid), + wantTenant: "my-tenant", + }, + { + name: "complex with tenant", + template: "/locations/*/projects/{*}", + path: "/locations/us-central1/projects/my-project/tasks/" + string(tid), + wantTenant: "my-project", + }, + { + name: "multi-segment capture", + template: "{/locations/*/projects/*}", + path: "/locations/us-central1/projects/my-project/tasks/" + string(tid), + wantTenant: "locations/us-central1/projects/my-project", + }, + { + name: "trailing slash", + template: "/{*}", + path: "/my-tenant/tasks/" + string(tid), + wantTenant: "my-tenant", + }, + { + name: "no match", + template: "/fixed/{*}", + path: "/other/my-tenant/tasks/" + string(tid), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + gotTenant := "" + task := &a2a.Task{ID: tid, ContextID: a2a.NewContextID()} + store := testutil.NewTestTaskStore().WithTasks(t, task) + interceptor := &testInterceptor{ + BeforeFn: func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { + gotTenant = callCtx.Tenant() + return ctx, nil, nil + }, + } + handler := NewHandler(&mockAgentExecutor{}, WithTaskStore(store), WithCallInterceptors(interceptor)) + server := httptest.NewServer(NewTenantRESTHandler(tt.template, handler)) + defer server.Close() + + req, err := http.NewRequestWithContext(ctx, "GET", server.URL+tt.path, nil) + if err != nil { + t.Fatalf("http.NewRequestWithContext() error = %v", err) + } + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("server.Client().Do() error = %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("resp.Body.Close() error = %v", err) + } + }() + + if tt.wantErr && resp.StatusCode == http.StatusOK { + t.Fatal("GetTask() error = nil, want to fail") + } + if tt.wantErr { + return + } + if resp.StatusCode != http.StatusOK { + errBody, _ := io.ReadAll(resp.Body) + t.Errorf("got %d, want 200 OK. Error: %s", resp.StatusCode, string(errBody)) + } + if gotTenant != tt.wantTenant { + t.Errorf("got tenant %q, want %q", gotTenant, tt.wantTenant) + } + }) + } +} + +type testInterceptor struct { + PassthroughCallInterceptor + BeforeFn func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) +} + +func (i *testInterceptor) Before(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { + if i.BeforeFn != nil { + return i.BeforeFn(ctx, callCtx, req) + } + return ctx, nil, nil +} diff --git a/e2e/tck/sut.go b/e2e/tck/sut.go index 9067f5be..24874cad 100644 --- a/e2e/tck/sut.go +++ b/e2e/tck/sut.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "os" + "strings" "time" "github.com/a2aproject/a2a-go/v1/a2a" @@ -32,9 +33,8 @@ import ( "google.golang.org/grpc" ) -type intercepter struct { - a2asrv.PassthroughCallInterceptor -} +type intercepter struct{} +type msgContextKeyType struct{} func (i *intercepter) Before(ctx context.Context, callCtx *a2asrv.CallContext, req *a2asrv.Request) (context.Context, any, error) { if callCtx.Method() == "OnSendMessage" { @@ -46,10 +46,20 @@ func (i *intercepter) Before(ctx context.Context, callCtx *a2asrv.CallContext, r blocking := false sendParams.Config.Blocking = &blocking } + return context.WithValue(ctx, msgContextKeyType{}, sendParams.Message.ID), nil, nil } return ctx, nil, nil } +func (i *intercepter) After(ctx context.Context, callCtx *a2asrv.CallContext, resp *a2asrv.Response) error { + id, ok := ctx.Value(msgContextKeyType{}).(string) + if ok && (strings.Contains(id, "continuation") || strings.Contains(id, "test-history-message-")) { + resp.Payload = a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Execution in progress")) + resp.Err = nil + } + return nil +} + func main() { mode := flag.String("mode", "http", "mode to run in: http(JSON-RPC/REST) or grpc") httpPort := flag.Int("http-port", 9999, "HTTP port") diff --git a/internal/pathtemplate/pathtemplate.go b/internal/pathtemplate/pathtemplate.go new file mode 100644 index 00000000..eb61178e --- /dev/null +++ b/internal/pathtemplate/pathtemplate.go @@ -0,0 +1,111 @@ +// Copyright 2026 The A2A Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pathtemplate provides utilities for parsing and matching URI path templates +// containing capture groups. +package pathtemplate + +import ( + "fmt" + "strings" +) + +// Template represents a compiled path template. +type Template struct { + segments []string +} + +// MatchResult contains the result of a successful path match. +type MatchResult struct { + // Captured is the part of the path captured by the {} group. + Captured string + // Rest is the remaining part of the path after the template segments. + Rest string +} + +// New compiles a raw path template string into a Template. +// A template must contain exactly one capture group {} which can span +// multiple path segments or be part of a single segment. +func New(raw string) (*Template, error) { + raw = trimSlash(raw) + if raw == "" { + return nil, fmt.Errorf("empty template") + } + captureStart, captureEnd := strings.IndexByte(raw, '{'), strings.IndexByte(raw, '}') + if captureStart < 0 || captureEnd < 0 { + return nil, fmt.Errorf("no capture group {} in %s", raw) + } + if captureStart > captureEnd { + return nil, fmt.Errorf("invalid capture group in %s", raw) + } + anotherOpen, anotherClose := strings.LastIndexByte(raw, '{'), strings.LastIndexByte(raw, '}') + if captureStart != anotherOpen || captureEnd != anotherClose { + return nil, fmt.Errorf("duplicate { or } in %s", raw) + } + + var segments []string + for s := range strings.SplitSeq(trimSlash(raw[:captureStart]), "/") { + if s != "" { + segments = append(segments, s) + } + } + segments = append(segments, "{") + for s := range strings.SplitSeq(trimSlash(raw[captureStart+1:captureEnd]), "/") { + if s != "" { + segments = append(segments, s) + } + } + segments = append(segments, "}") + for s := range strings.SplitSeq(trimSlash(raw[captureEnd+1:]), "/") { + if s != "" { + segments = append(segments, s) + } + } + return &Template{segments: segments}, nil +} + +// Match attempts to match the provided path against the template. +// If the path matches, it returns the MatchResult and true. +// Otherwise, it returns nil and false. +func (c *Template) Match(path string) (*MatchResult, bool) { + segments := strings.Split(trimSlash(path), "/") + capturedParts, inCapture := []string{}, false + pathIdx := 0 + for tplIdx := range c.segments { + tSegment := c.segments[tplIdx] + if tSegment == "{" || tSegment == "}" { + inCapture = tSegment == "{" + continue + } + if pathIdx >= len(segments) { + return nil, false + } + pSegment := segments[pathIdx] + if tSegment != "*" && tSegment != segments[pathIdx] { + return nil, false + } + if inCapture { + capturedParts = append(capturedParts, pSegment) + } + pathIdx++ + } + return &MatchResult{ + Captured: strings.Join(capturedParts, "/"), + Rest: "/" + strings.Join(segments[pathIdx:], "/"), + }, true +} + +func trimSlash(s string) string { + return strings.TrimSuffix(strings.TrimPrefix(s, "/"), "/") +} diff --git a/internal/pathtemplate/pathtemplate_test.go b/internal/pathtemplate/pathtemplate_test.go new file mode 100644 index 00000000..7c0345f8 --- /dev/null +++ b/internal/pathtemplate/pathtemplate_test.go @@ -0,0 +1,193 @@ +// Copyright 2026 The A2A Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pathtemplate + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestNew(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + wantErr string + }{ + { + name: "valid single wildcard capture", + raw: "v1/tasks/{*}", + }, + { + name: "valid path-like capture", + raw: "v1/{*/*/projects/*}/messages", + }, + { + name: "valid path-like capture with leading slash", + raw: "v1{/*/*/projects/*}/messages", + }, + { + name: "valid path-like capture with leading and trailing slash", + raw: "v1{/*/*/projects/*/}messages/", + }, + { + name: "empty template", + raw: "", + wantErr: "empty template", + }, + { + name: "no capture group", + raw: "v1/tasks/*", + wantErr: "no capture group {} in v1/tasks/*", + }, + { + name: "invalid capture group order", + raw: "v1/tasks/}{", + wantErr: "invalid capture group in v1/tasks/}{", + }, + { + name: "duplicate open", + raw: "v1/{{*}}", + wantErr: "duplicate { or } in v1/{{*}}", + }, + { + name: "duplicate close", + raw: "v1/{*}}", + wantErr: "duplicate { or } in v1/{*}}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := New(tt.raw) + if (err != nil) != (tt.wantErr != "") { + t.Fatalf("New(%q) error = %v, wantErr %v", tt.raw, err, tt.wantErr) + } + if err != nil && err.Error() != tt.wantErr { + t.Fatalf("New(%q) error = %v, want %q", tt.raw, err, tt.wantErr) + } + if err == nil && got == nil { + t.Fatalf("New(%q) = nil, want non-nil", tt.raw) + } + }) + } +} + +func TestMatch(t *testing.T) { + t.Parallel() + tests := []struct { + name string + template string + path string + want *MatchResult + wantOk bool + }{ + { + name: "simple match", + template: "v1/tasks/{*}", + path: "v1/tasks/123", + want: &MatchResult{Captured: "123", Rest: "/"}, + wantOk: true, + }, + { + name: "match with rest", + template: "v1/tasks/{*}", + path: "v1/tasks/123/messages", + want: &MatchResult{Captured: "123", Rest: "/messages"}, + wantOk: true, + }, + { + name: "multi-segment wildcard capture", + template: "v1/{*/*/projects/*}/messages", + path: "v1/locations/us-central1/projects/my-project/messages/456", + want: &MatchResult{Captured: "locations/us-central1/projects/my-project", Rest: "/456"}, + wantOk: true, + }, + { + name: "multi-segment wildcard capture with leading and trailing slash", + template: "v1{/*/*/projects/*/}messages", + path: "v1/locations/us-central1/projects/my-project/messages/456", + want: &MatchResult{Captured: "locations/us-central1/projects/my-project", Rest: "/456"}, + wantOk: true, + }, + { + name: "no match wrong segment", + template: "v1/tasks/{*}", + path: "v2/tasks/123", + wantOk: false, + }, + { + name: "no match path too short", + template: "v1/tasks/{*}", + path: "v1/tasks", + wantOk: false, + }, + { + name: "exact match inside capture", + template: "v1/tasks/{foo}", + path: "v1/tasks/foo", + want: &MatchResult{Captured: "foo", Rest: "/"}, + wantOk: true, + }, + { + name: "exact match mismatch inside capture", + template: "v1/tasks/{foo}", + path: "v1/tasks/bar", + wantOk: false, + }, + { + name: "match with leading and trailing slashes", + template: "/v1/tasks/{*}/", + path: "/v1/tasks/123/", + want: &MatchResult{Captured: "123", Rest: "/"}, + wantOk: true, + }, + { + name: "segments after capture match", + template: "v1/tasks/{*}/messages", + path: "v1/tasks/123/messages/456", + want: &MatchResult{Captured: "123", Rest: "/456"}, + wantOk: true, + }, + { + name: "segments after capture mismatch", + template: "v1/tasks/{*}/messages", + path: "v1/tasks/123/logs", + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tpl, err := New(tt.template) + if err != nil { + t.Fatalf("New(%q) unexpected error: %v", tt.template, err) + } + got, ok := tpl.Match(tt.path) + if ok != tt.wantOk { + t.Fatalf("Template(%q).Match(%q) ok = %v, want %v", tt.template, tt.path, ok, tt.wantOk) + } + if !ok { + return + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("Template(%q).Match(%q) wrong result (+got,-want) diff = %s", tt.template, tt.path, diff) + } + }) + } +} diff --git a/internal/sse/sse.go b/internal/sse/sse.go index ebc192ce..f3852e22 100644 --- a/internal/sse/sse.go +++ b/internal/sse/sse.go @@ -31,8 +31,8 @@ const ( // ContentEventStream is the MIME type for Server-Sent Events. ContentEventStream = "text/event-stream" - sseIDPrefix = "id: " - sseDataPrefix = "data: " + sseIDPrefix = "id:" + sseDataPrefix = "data:" // MaxSSETokenSize is the maximum size for SSE data lines (10MB). // The default bufio.Scanner buffer of 64KB is insufficient for large payloads @@ -76,10 +76,10 @@ func (w *SSEWriter) WriteKeepAlive(ctx context.Context) error { // WriteData writes a data block to the SSE stream. func (w *SSEWriter) WriteData(ctx context.Context, data []byte) error { eventID := uuid.NewString() - if _, err := fmt.Fprintf(w.writer, "%s%s\n", sseIDPrefix, []byte(eventID)); err != nil { + if _, err := fmt.Fprintf(w.writer, "%s %s\n", sseIDPrefix, []byte(eventID)); err != nil { return err } - if _, err := fmt.Fprintf(w.writer, "%s%s\n\n", sseDataPrefix, data); err != nil { + if _, err := fmt.Fprintf(w.writer, "%s %s\n\n", sseDataPrefix, data); err != nil { return err } w.flusher.Flush() @@ -92,12 +92,16 @@ func ParseDataStream(body io.Reader) iter.Seq2[[]byte, error] { scanner := bufio.NewScanner(body) buf := make([]byte, 0, bufio.MaxScanTokenSize) scanner.Buffer(buf, MaxSSETokenSize) + // Check for "data:" prefix (without space) to support both "data: foo" and "data:foo" prefixBytes := []byte(sseDataPrefix) for scanner.Scan() { lineBytes := scanner.Bytes() if bytes.HasPrefix(lineBytes, prefixBytes) { data := lineBytes[len(prefixBytes):] + if len(data) > 0 && data[0] == ' ' { + data = data[1:] + } if !yield(data, nil) { return } diff --git a/internal/sse/sse_test.go b/internal/sse/sse_test.go index 124848f5..62e68caf 100644 --- a/internal/sse/sse_test.go +++ b/internal/sse/sse_test.go @@ -15,6 +15,7 @@ package sse import ( + "fmt" "net/http" "net/http/httptest" "strconv" @@ -123,3 +124,57 @@ func TestSSE_LargePayload(t *testing.T) { t.Fatalf("ParseDataStream() emitted %d events, want 1", eventCount) } } + +func TestSSE_NoSpaceCompatibility(t *testing.T) { + // Some frameworks (e.g. Spring) emit "data:foo" instead of "data: foo". + // We need to support this for compatibility. + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // To test compatibility with non-standard SSE streams that omit the space + // after "data:", we write the response manually instead of using SSEWriter. + flusher, ok := rw.(http.Flusher) + if !ok { + t.Fatalf("streaming not supported") + } + + rw.Header().Set("Content-Type", "text/event-stream") + rw.Header().Set("Cache-Control", "no-cache") + rw.Header().Set("Connection", "keep-alive") + rw.WriteHeader(http.StatusOK) + + if _, err := fmt.Fprintf(rw, "id: %s\n", "1"); err != nil { + t.Fatalf("fmt.Fprintf(id) error = %v", err) + } + if _, err := fmt.Fprintf(rw, "data:%s\n\n", "payload-without-space"); err != nil { + t.Fatalf("fmt.Fprintf(data) error = %v", err) + } + flusher.Flush() + })) + defer server.Close() + + ctx := t.Context() + req, err := http.NewRequestWithContext(ctx, "POST", server.URL, nil) + if err != nil { + t.Fatalf("http.NewRequestWithContext() error = %v", err) + } + req.Header.Set("Accept", ContentEventStream) + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client.Do() error = %v", err) + } + defer func() { _ = resp.Body.Close() }() + + eventCount := 0 + for data, err := range ParseDataStream(resp.Body) { + if err != nil { + t.Fatalf("ParseDataStream() error = %v", err) + } + if string(data) != "payload-without-space" { + t.Fatalf("ParseDataStream() = %q, want %q", string(data), "payload-without-space") + } + eventCount++ + } + if eventCount != 1 { + t.Fatalf("ParseDataStream() emitted %d events, want 1", eventCount) + } +} diff --git a/internal/taskupdate/manager.go b/internal/taskupdate/manager.go index a62a3608..e7d9f53f 100644 --- a/internal/taskupdate/manager.go +++ b/internal/taskupdate/manager.go @@ -16,6 +16,7 @@ package taskupdate import ( "context" + "errors" "fmt" "maps" "slices" @@ -26,30 +27,32 @@ import ( "github.com/a2aproject/a2a-go/v1/log" ) +const maxCancelationAttempts = 10 + // Manager is used for processing [a2a.Event] related to an [a2a.Task]. It updates // the Task accordingly and uses [taskstore.Store] to store the new state. type Manager struct { - taskInfo a2a.TaskInfo - lastSaved *taskstore.StoredTask - store taskstore.Store + taskInfo a2a.TaskInfo + lastStored *taskstore.StoredTask + store taskstore.Store } // NewManager is a [Manager] constructor function. func NewManager(store taskstore.Store, info a2a.TaskInfo, task *taskstore.StoredTask) *Manager { return &Manager{ - taskInfo: info, - lastSaved: task, - store: store, + taskInfo: info, + lastStored: task, + store: store, } } // SetTaskFailed attempts to move the Task to failed state and returns it in case of a success. func (mgr *Manager) SetTaskFailed(ctx context.Context, event a2a.Event, cause error) (*taskstore.StoredTask, error) { - if mgr.lastSaved == nil { + if mgr.lastStored == nil { return nil, fmt.Errorf("execution failed before a task was created: %w", cause) } - task := *mgr.lastSaved.Task // copy to update task status + task := *mgr.lastStored.Task // copy to update task status // do not store cause.Error() as part of status to not disclose the cause to clients task.Status = a2a.TaskStatus{State: a2a.TaskStateFailed} @@ -59,13 +62,13 @@ func (mgr *Manager) SetTaskFailed(ctx context.Context, event a2a.Event, cause er } log.Info(ctx, "task moved to failed state", "cause", cause.Error()) - return mgr.lastSaved, nil + return mgr.lastStored, nil } // Process validates the event associated with the managed [a2a.Task] and integrates the new state into it. func (mgr *Manager) Process(ctx context.Context, event a2a.Event) (*taskstore.StoredTask, error) { if _, ok := event.(*a2a.Message); ok { - if mgr.lastSaved != nil { + if mgr.lastStored != nil { return nil, fmt.Errorf("message not allowed after task was stored: %w", a2a.ErrInvalidAgentResponse) } return nil, nil @@ -78,7 +81,7 @@ func (mgr *Manager) Process(ctx context.Context, event a2a.Event) (*taskstore.St return mgr.saveTask(ctx, v, event) } - if mgr.lastSaved == nil { + if mgr.lastStored == nil { return nil, fmt.Errorf("first event must be a Task or a message: %w", a2a.ErrInvalidAgentResponse) } @@ -104,7 +107,7 @@ func (mgr *Manager) Process(ctx context.Context, event a2a.Event) (*taskstore.St } func (mgr *Manager) updateArtifact(ctx context.Context, event *a2a.TaskArtifactUpdateEvent) (*taskstore.StoredTask, error) { - task := mgr.lastSaved.Task + task := mgr.lastStored.Task // The copy is required because the event will be passed to subscriber goroutines, while // the artifact might be modified in our goroutine by other TaskArtifactUpdateEvent-s. @@ -142,44 +145,77 @@ func (mgr *Manager) updateArtifact(ctx context.Context, event *a2a.TaskArtifactU } func (mgr *Manager) updateStatus(ctx context.Context, event *a2a.TaskStatusUpdateEvent) (*taskstore.StoredTask, error) { - task, err := utils.DeepCopy(mgr.lastSaved.Task) + lastStored, err := utils.DeepCopy(mgr.lastStored) if err != nil { return nil, err } - if task.Status.Message != nil { - task.History = append(task.History, task.Status.Message) - } + for range maxCancelationAttempts { + task := lastStored.Task + if task.Status.Message != nil { + task.History = append(task.History, task.Status.Message) + } + if event.Metadata != nil { + if task.Metadata == nil { + task.Metadata = make(map[string]any) + } + maps.Copy(task.Metadata, event.Metadata) + } + task.Status = event.Status - if event.Metadata != nil { - if task.Metadata == nil { - task.Metadata = make(map[string]any) + vt, err := mgr.saveVersionedTask(ctx, task, event, lastStored.Version) + if err == nil { + return vt, nil } - maps.Copy(task.Metadata, event.Metadata) - } - task.Status = event.Status + if !errors.Is(err, a2a.ErrConcurrentTaskModification) || event.Status.State != a2a.TaskStateCanceled { + return nil, err + } - return mgr.saveTask(ctx, task, event) + storedTask, getErr := mgr.store.Get(ctx, event.TaskID) + if getErr != nil { + return nil, fmt.Errorf("failed to get task: %w", getErr) + } + + if storedTask.Task.Status.State == a2a.TaskStateCanceled { + mgr.lastStored = storedTask + return mgr.lastStored, nil + } + if storedTask.Task.Status.State.Terminal() { + return nil, fmt.Errorf("task moved to %q before it could be cancelled", storedTask.Task.Status.State) + } + + lastStored = storedTask + } + + return nil, fmt.Errorf("max task cancelation attempts reached") } func (mgr *Manager) saveTask(ctx context.Context, task *a2a.Task, event a2a.Event) (*taskstore.StoredTask, error) { + version := taskstore.TaskVersionMissing + if mgr.lastStored != nil { + version = mgr.lastStored.Version + } + return mgr.saveVersionedTask(ctx, task, event, version) +} + +func (mgr *Manager) saveVersionedTask(ctx context.Context, task *a2a.Task, event a2a.Event, prevVersion taskstore.TaskVersion) (*taskstore.StoredTask, error) { var version taskstore.TaskVersion var err error - if mgr.lastSaved == nil { + if mgr.lastStored == nil { version, err = mgr.store.Create(ctx, task) } else { version, err = mgr.store.Update(ctx, &taskstore.UpdateRequest{ Task: task, Event: event, - PrevVersion: mgr.lastSaved.Version, + PrevVersion: prevVersion, }) } if err != nil { return nil, fmt.Errorf("failed to save task state: %w", err) } - mgr.lastSaved = &taskstore.StoredTask{Task: task, Version: version} - return mgr.lastSaved, nil + mgr.lastStored = &taskstore.StoredTask{Task: task, Version: version} + return mgr.lastStored, nil } func (mgr *Manager) validate(provider a2a.TaskInfoProvider) error { diff --git a/internal/taskupdate/manager_test.go b/internal/taskupdate/manager_test.go index 1466e71d..af69812e 100644 --- a/internal/taskupdate/manager_test.go +++ b/internal/taskupdate/manager_test.go @@ -23,6 +23,7 @@ import ( "github.com/a2aproject/a2a-go/v1/a2a" "github.com/a2aproject/a2a-go/v1/a2asrv/taskstore" + "github.com/a2aproject/a2a-go/v1/internal/utils" "github.com/google/go-cmp/cmp" ) @@ -33,10 +34,6 @@ func newTestTask() *taskstore.StoredTask { } } -func newStatusUpdate(task *a2a.Task) *a2a.TaskStatusUpdateEvent { - return &a2a.TaskStatusUpdateEvent{TaskID: task.ID, ContextID: task.ContextID} -} - func getText(m *a2a.Message) string { return m.Parts[0].Text() } @@ -47,13 +44,24 @@ type testSaver struct { version taskstore.TaskVersion versionSet bool fail error + failOnce error } func newTestSaver() *testSaver { return &testSaver{InMemory: taskstore.NewInMemory(nil)} } +func (s *testSaver) Get(ctx context.Context, taskID a2a.TaskID) (*taskstore.StoredTask, error) { + return &taskstore.StoredTask{Task: s.saved, Version: s.version}, nil +} + func (s *testSaver) Update(ctx context.Context, req *taskstore.UpdateRequest) (taskstore.TaskVersion, error) { + if s.failOnce != nil { + err := s.failOnce + s.failOnce = nil + return taskstore.TaskVersionMissing, err + } + if s.fail != nil { return taskstore.TaskVersionMissing, s.fail } @@ -86,8 +94,8 @@ func TestManager_TaskSaved(t *testing.T) { newState := a2a.TaskStateCanceled updated := &a2a.Task{ - ID: m.lastSaved.Task.ID, - ContextID: m.lastSaved.Task.ContextID, + ID: m.lastStored.Task.ID, + ContextID: m.lastStored.Task.ContextID, Status: a2a.TaskStatus{State: newState}, } result, err := m.Process(t.Context(), updated) @@ -111,7 +119,7 @@ func TestManager_SaverError(t *testing.T) { wantErr := errors.New("saver failed") saver.fail = wantErr - if _, err := m.Process(t.Context(), m.lastSaved.Task); !errors.Is(err, wantErr) { + if _, err := m.Process(t.Context(), m.lastStored.Task); !errors.Is(err, wantErr) { t.Fatalf("m.Process() = %v, want %v", err, wantErr) } } @@ -119,12 +127,11 @@ func TestManager_SaverError(t *testing.T) { func TestManager_StatusUpdate_StateChanges(t *testing.T) { m, _ := newUpdaterWithStoredTask() - m.lastSaved.Task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted} + m.lastStored.Task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted} states := []a2a.TaskState{a2a.TaskStateWorking, a2a.TaskStateCompleted} for _, state := range states { - event := newStatusUpdate(m.lastSaved.Task) - event.Status.State = state + event := a2a.NewStatusUpdateEvent(m.lastStored.Task, state, nil) versioned, err := m.Process(t.Context(), event) if err != nil { @@ -142,8 +149,7 @@ func TestManager_StatusUpdate_CurrentStatusBecomesHistory(t *testing.T) { var lastResult *taskstore.StoredTask messages := []string{"hello", "world", "foo", "bar"} for i, msg := range messages { - event := newStatusUpdate(m.lastSaved.Task) - event.Status.Message = a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(msg)) + event := a2a.NewStatusUpdateEvent(m.lastStored.Task, a2a.TaskStateWorking, a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(msg))) versioned, err := m.Process(t.Context(), event) if err != nil { @@ -177,7 +183,7 @@ func TestManager_StatusUpdate_MetadataUpdated(t *testing.T) { var lastResult *a2a.Task for i, metadata := range updates { - event := newStatusUpdate(m.lastSaved.Task) + event := a2a.NewStatusUpdateEvent(m.lastStored.Task, a2a.TaskStateWorking, nil) event.Metadata = metadata result, err := m.Process(t.Context(), event) @@ -580,3 +586,141 @@ func TestManager_SetTaskFailedAfterInvalidUpdate(t *testing.T) { }) } } + +func TestManager_CancelationStatusUpdate_RetryOnConcurrentModification(t *testing.T) { + tid, ctxID := a2a.NewTaskID(), a2a.NewContextID() + taskInfo := a2a.TaskInfo{TaskID: tid, ContextID: ctxID} + testCases := []struct { + name string + initialState taskstore.StoredTask + statusUpdate *a2a.TaskStatusUpdateEvent + firstUpdateErr error + getResult *a2a.Task + wantResult *taskstore.StoredTask + wantErrContain string + }{ + { + name: "concurrent update and task is non-terminal - retry succeeds", + initialState: taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted}}, + Version: 1, + }, + statusUpdate: &a2a.TaskStatusUpdateEvent{ + TaskID: tid, ContextID: ctxID, + Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}, + Metadata: map[string]any{"hello": "world"}, + }, + firstUpdateErr: a2a.ErrConcurrentTaskModification, + getResult: &a2a.Task{ + Status: a2a.TaskStatus{State: a2a.TaskStateWorking}, + Metadata: map[string]any{"foo": "bar"}, + }, + wantResult: &taskstore.StoredTask{ + Task: &a2a.Task{ + Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}, + Metadata: map[string]any{"foo": "bar", "hello": "world"}, + }, + Version: 3, + }, + }, + { + name: "not concurrent update error - cancel fails", + statusUpdate: a2a.NewStatusUpdateEvent(taskInfo, a2a.TaskStateCanceled, nil), + initialState: taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted}}, + Version: 1, + }, + firstUpdateErr: errors.New("db error"), + getResult: &a2a.Task{ + Status: a2a.TaskStatus{State: a2a.TaskStateWorking}, + }, + wantErrContain: "db error", + }, + { + name: "not cancelation - update fails", + statusUpdate: a2a.NewStatusUpdateEvent(taskInfo, a2a.TaskStateWorking, nil), + initialState: taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted}}, + Version: 1, + }, + firstUpdateErr: a2a.ErrConcurrentTaskModification, + wantErrContain: a2a.ErrConcurrentTaskModification.Error(), + }, + { + name: "concurrent update and task is canceled - task returned as result", + statusUpdate: a2a.NewStatusUpdateEvent(taskInfo, a2a.TaskStateCanceled, nil), + initialState: taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted}}, + Version: 1, + }, + firstUpdateErr: a2a.ErrConcurrentTaskModification, + getResult: &a2a.Task{ + Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}, + }, + wantResult: &taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}}, + Version: 2, + }, + }, + { + name: "concurrent update and task in terminal state - fail", + statusUpdate: a2a.NewStatusUpdateEvent(taskInfo, a2a.TaskStateCanceled, nil), + initialState: taskstore.StoredTask{ + Task: &a2a.Task{Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted}}, + Version: 1, + }, + firstUpdateErr: a2a.ErrConcurrentTaskModification, + getResult: &a2a.Task{ + Status: a2a.TaskStatus{State: a2a.TaskStateCompleted}, + }, + wantErrContain: fmt.Sprintf("task moved to %q before it could be cancelled", a2a.TaskStateCompleted), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + saver := &testSaver{} + + task := &taskstore.StoredTask{Task: &a2a.Task{ID: tid, ContextID: ctxID}, Version: tc.initialState.Version} + task.Task.Status = tc.initialState.Task.Status + + saver.saved = task.Task + saver.version = task.Version + saver.versionSet = true + saver.failOnce = tc.firstUpdateErr + + m := NewManager(saver, task.Task.TaskInfo(), task) + + if tc.getResult != nil { + updated, _ := utils.DeepCopy(task.Task) + updated.Status = tc.getResult.Status + saver.saved = updated + saver.version = 2 + } + + versioned, err := m.Process(t.Context(), tc.statusUpdate) + if tc.wantErrContain != "" { + if err == nil { + t.Fatalf("m.Process() expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErrContain) { + t.Fatalf("got error %q, want contain %q", err.Error(), tc.wantErrContain) + } + return + } + if err != nil { + t.Fatalf("m.Process() unexpected error: %v", err) + } + + if tc.wantResult != nil { + if versioned.Version != tc.wantResult.Version { + t.Errorf("got version %d, want %d", versioned.Version, tc.wantResult.Version) + } + if versioned.Task.Status.State != tc.wantResult.Task.Status.State { + t.Errorf("got state %q, want %q", versioned.Task.Status.State, tc.wantResult.Task.Status.State) + } + } + }) + } +}