diff --git a/cli/azd/extensions/azure.ai.agents/cspell.yaml b/cli/azd/extensions/azure.ai.agents/cspell.yaml index 8dfa8421459..34b685b78ed 100644 --- a/cli/azd/extensions/azure.ai.agents/cspell.yaml +++ b/cli/azd/extensions/azure.ai.agents/cspell.yaml @@ -50,6 +50,9 @@ words: - projectpkg - protocolversionrecord - Qdrant + - retarget + - retargeted + - retargets - Toolsets - Vnext - webp diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/pending_toolboxes.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/pending_toolboxes.go new file mode 100644 index 00000000000..1d3e6e443c4 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/pending_toolboxes.go @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/url" + "strings" + "time" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// pendingToolboxesPath is the UserConfig root for per-endpoint pending toolbox buckets. +const pendingToolboxesPath = configPathPrefix + ".pending-toolboxes" + +// PendingToolbox is the per-name record persisted under +// extensions.ai-agents.pending-toolboxes..items.. +type PendingToolbox struct { + Description string `json:"description,omitempty"` + CreatedAt string `json:"createdAt"` +} + +// pendingToolboxBucket is the value persisted per endpoint hash. It carries +// the plain-text endpoint as a sibling of items so the bucket is self-describing. +type pendingToolboxBucket struct { + Endpoint string `json:"endpoint"` + Items map[string]PendingToolbox `json:"items,omitempty"` +} + +// endpointBucketKey returns the 16-hex-char opaque key used to bucket pending +// records per endpoint. The key shape (hex.EncodeToString of the +// first 8 bytes of the sha256 digest) is part of the persisted config schema: +// changing it would orphan every existing record. +func endpointBucketKey(endpoint string) string { + normalized := normalizePendingEndpoint(endpoint) + h := sha256.Sum256([]byte(normalized)) + return hex.EncodeToString(h[:8]) +} + +// normalizePendingEndpoint canonicalizes the endpoint to ensure two equivalent +// endpoints land in the same bucket. Lower-cases the host, strips trailing slashes. +func normalizePendingEndpoint(endpoint string) string { + trimmed := strings.TrimRight(strings.TrimSpace(endpoint), "/") + u, err := url.Parse(trimmed) + if err != nil || u.Host == "" { + return strings.ToLower(trimmed) + } + u.Host = strings.ToLower(u.Host) + return strings.TrimRight(u.String(), "/") +} + +// pendingBucketPath builds the full UserConfig path for one endpoint bucket. +func pendingBucketPath(endpoint string) string { + return pendingToolboxesPath + "." + endpointBucketKey(endpoint) +} + +// getPendingBucket loads the pending bucket for an endpoint. Returns an empty +// (non-nil) bucket when no record exists. +func getPendingBucket( + ctx context.Context, azdClient *azdext.AzdClient, endpoint string, +) (*pendingToolboxBucket, error) { + ch, err := azdext.NewConfigHelper(azdClient) + if err != nil { + return nil, fmt.Errorf("pending toolbox bucket: %w", err) + } + + var bucket pendingToolboxBucket + found, err := ch.GetUserJSON(ctx, pendingBucketPath(endpoint), &bucket) + if err != nil { + return nil, fmt.Errorf("pending toolbox bucket: failed to read: %w", err) + } + + if !found { + return &pendingToolboxBucket{ + Endpoint: normalizePendingEndpoint(endpoint), + Items: map[string]PendingToolbox{}, + }, nil + } + if bucket.Items == nil { + bucket.Items = map[string]PendingToolbox{} + } + if bucket.Endpoint == "" { + bucket.Endpoint = normalizePendingEndpoint(endpoint) + } + return &bucket, nil +} + +// setPendingBucket persists a bucket. If the bucket is empty (no items), the +// whole bucket is left in place to preserve the endpoint mapping; callers that +// want full removal should use deletePendingBucket. +func setPendingBucket( + ctx context.Context, azdClient *azdext.AzdClient, endpoint string, bucket *pendingToolboxBucket, +) error { + ch, err := azdext.NewConfigHelper(azdClient) + if err != nil { + return fmt.Errorf("pending toolbox bucket: %w", err) + } + if err := ch.SetUserJSON(ctx, pendingBucketPath(endpoint), bucket); err != nil { + return fmt.Errorf("pending toolbox bucket: failed to write: %w", err) + } + return nil +} + +// getPendingToolbox returns the record for a single name, or (nil, nil) when absent. +func getPendingToolbox( + ctx context.Context, azdClient *azdext.AzdClient, endpoint, name string, +) (*PendingToolbox, error) { + bucket, err := getPendingBucket(ctx, azdClient, endpoint) + if err != nil { + return nil, err + } + if v, ok := bucket.Items[name]; ok { + return &v, nil + } + return nil, nil +} + +// setPendingToolbox creates or updates a pending record for one toolbox. +func setPendingToolbox( + ctx context.Context, azdClient *azdext.AzdClient, + endpoint, name string, record PendingToolbox, +) error { + bucket, err := getPendingBucket(ctx, azdClient, endpoint) + if err != nil { + return err + } + if record.CreatedAt == "" { + record.CreatedAt = time.Now().UTC().Format(time.RFC3339) + } + bucket.Items[name] = record + return setPendingBucket(ctx, azdClient, endpoint, bucket) +} + +// clearPendingToolbox removes a single pending record. +// Returns true when an entry existed and was removed. +func clearPendingToolbox( + ctx context.Context, azdClient *azdext.AzdClient, endpoint, name string, +) (bool, error) { + bucket, err := getPendingBucket(ctx, azdClient, endpoint) + if err != nil { + return false, err + } + if _, ok := bucket.Items[name]; !ok { + return false, nil + } + delete(bucket.Items, name) + return true, setPendingBucket(ctx, azdClient, endpoint, bucket) +} + +// listPendingToolboxes returns all pending records for an endpoint. +func listPendingToolboxes( + ctx context.Context, azdClient *azdext.AzdClient, endpoint string, +) (map[string]PendingToolbox, error) { + bucket, err := getPendingBucket(ctx, azdClient, endpoint) + if err != nil { + return nil, err + } + return bucket.Items, nil +} + +// pendingToolboxStore is the seam used by commands that need to read or clear +// pending records. The production implementation is azd-host-backed; tests +// substitute an in-memory stub. +type pendingToolboxStore interface { + // Get returns the pending record for (endpoint, name), or (nil, nil) when + // absent. A non-nil error means the store could not be consulted at all. + Get(ctx context.Context, endpoint, name string) (*PendingToolbox, error) + // Clear removes a single pending record. Reports whether an entry was present. + Clear(ctx context.Context, endpoint, name string) (bool, error) +} + +type azdPendingToolboxStore struct { + azdClient *azdext.AzdClient +} + +func (s *azdPendingToolboxStore) Get( + ctx context.Context, endpoint, name string, +) (*PendingToolbox, error) { + return getPendingToolbox(ctx, s.azdClient, endpoint, name) +} + +func (s *azdPendingToolboxStore) Clear( + ctx context.Context, endpoint, name string, +) (bool, error) { + return clearPendingToolbox(ctx, s.azdClient, endpoint, name) +} + +// newAzdPendingToolboxStore opens the production store. The caller must invoke +// the returned closer (via defer) to release the underlying azd client. +func newAzdPendingToolboxStore() (pendingToolboxStore, func(), error) { + c, err := azdext.NewAzdClient() + if err != nil { + return nil, func() {}, err + } + closer := func() { c.Close() } + return &azdPendingToolboxStore{azdClient: c}, closer, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go index c63ea180146..70d4ac394e7 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go @@ -61,6 +61,7 @@ func NewRootCommand() *cobra.Command { rootCmd.AddCommand(newFilesCommand(extCtx)) rootCmd.AddCommand(newSessionCommand(extCtx)) rootCmd.AddCommand(newProjectCommand(extCtx)) + rootCmd.AddCommand(newToolboxCommand(extCtx)) return rootCmd } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox.go new file mode 100644 index 00000000000..437cfd4e40a --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox.go @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "azureaiagent/internal/exterrors" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// toolboxFlags carries the cross-cutting flags shared by every `toolbox` verb. +type toolboxFlags struct { + projectEndpoint string + output string + noPrompt bool +} + +// toolboxNamePattern is the validation regex for toolbox and tool names. +var toolboxNamePattern = regexp.MustCompile(`^[A-Za-z0-9_-]+$`) + +// maxToolboxNameLength caps positional names. Mirrors the 63-char ceiling used +// by agent names (see parse.go:validNamePattern). +const maxToolboxNameLength = 63 + +// newToolboxCommand builds the `azd ai agent toolbox` parent. +func newToolboxCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + + cmd := &cobra.Command{ + Use: "toolbox", + Short: "Manage Foundry toolboxes (versioned collections of agent tools).", + Long: `Manage Foundry toolboxes. + +A toolbox is a versioned, named collection of connection-backed tools that +agents reference at run time. Each version is immutable and carries the full +tool list; mutations publish a new version and (after the first one) require +an explicit update to retarget the default.`, + } + + // --output and --no-prompt are reserved azd globals and are inherited + // automatically; only the extension-specific flag is registered here. + cmd.PersistentFlags().String( + "project-endpoint", "", + "Foundry project endpoint URL. When unset, falls back to the active azd "+ + "environment, azd user config, then FOUNDRY_PROJECT_ENDPOINT.", + ) + // Advertise the toolbox-specific --output allowed values + default on the + // parent so `azd ai agent toolbox --help` shows them too. Leaf commands + // re-register on themselves; cobra annotations don't propagate. + registerToolboxOutputFlag(cmd) + + cmd.AddCommand(newToolboxCreateCommand(extCtx)) + cmd.AddCommand(newToolboxUpdateCommand(extCtx)) + cmd.AddCommand(newToolboxDeleteCommand(extCtx)) + cmd.AddCommand(newToolboxShowCommand(extCtx)) + cmd.AddCommand(newToolboxListCommand(extCtx)) + cmd.AddCommand(newToolboxConnectionCommand(extCtx)) + + return cmd +} + +// readToolboxFlags extracts the persistent flag values from a subcommand. The +// reserved azd globals `--output` and `--no-prompt` come from extCtx. `output` +// is normalized to lowercase so downstream branches can compare with `== "json"`. +func readToolboxFlags(cmd *cobra.Command, extCtx *azdext.ExtensionContext) toolboxFlags { + pe, _ := cmd.Flags().GetString("project-endpoint") + out := "" + np := false + if extCtx != nil { + out = strings.ToLower(extCtx.OutputFormat) + np = extCtx.NoPrompt + } + return toolboxFlags{projectEndpoint: pe, output: out, noPrompt: np} +} + +// validateOutputFormat returns a structured error when --output is not table/json. +// The azd host normally enforces this via RegisterFlagOptions; the check stays +// for direct `azd x` invocation and for unit-test reach. +func validateOutputFormat(out string) error { + switch strings.ToLower(out) { + case "", "table", "json": + return nil + default: + return exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("invalid --output value %q", out), + "use table or json", + ) + } +} + +// registerToolboxOutputFlag attaches the --output annotations every toolbox +// leaf command shares. RegisterFlagOptions writes per-command annotations, so +// it must run on each leaf rather than the parent. +func registerToolboxOutputFlag(cmd *cobra.Command) { + azdext.RegisterFlagOptions(cmd, azdext.FlagOptions{ + Name: "output", + AllowedValues: []string{"table", "json"}, + Default: "table", + }) +} + +// validateToolboxName enforces ^[A-Za-z0-9_-]+$ on the positional `` +// and caps length at maxToolboxNameLength. +func validateToolboxName(name string) error { + if !toolboxNamePattern.MatchString(name) || len(name) > maxToolboxNameLength { + return exterrors.Validation( + exterrors.CodeInvalidToolboxName, + fmt.Sprintf("toolbox name %q is invalid", name), + fmt.Sprintf("names must match ^[A-Za-z0-9_-]+$ and be at most %d characters", maxToolboxNameLength), + ) + } + return nil +} + +// validateToolName enforces the same regex on tool-entry names. Failing here +// avoids a service round trip that would yield a generic 400. +func validateToolName(name string) error { + if !toolboxNamePattern.MatchString(name) || len(name) > maxToolboxNameLength { + return exterrors.Validation( + exterrors.CodeInvalidToolboxName, + fmt.Sprintf( + "tool entry name %q is invalid; the Foundry service requires names "+ + "to match ^[A-Za-z0-9_-]+$ (max %d characters)", + name, maxToolboxNameLength, + ), + "rename the project connection so its short name matches the regex", + ) + } + return nil +} + +// resolveToolboxAndClient walks the endpoint cascade, validates flags, and +// returns a toolbox client bound to the resolved endpoint. +func resolveToolboxAndClient( + ctx context.Context, flags toolboxFlags, +) (toolboxClient, *resolvedEndpoint, error) { + if err := validateOutputFormat(flags.output); err != nil { + return nil, nil, err + } + resolved, err := resolveProjectEndpoint(ctx, resolveProjectEndpointOpts{FlagValue: flags.projectEndpoint}) + if err != nil { + return nil, nil, err + } + client, err := newToolboxClient(resolved.Endpoint) + if err != nil { + return nil, nil, err + } + return client, resolved, nil +} + +// isAzureNotFound reports whether err originates from an Azure response with HTTP 404. +func isAzureNotFound(err error) bool { + if err == nil { + return false + } + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok { + return respErr.StatusCode == http.StatusNotFound + } + return false +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_client.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_client.go new file mode 100644 index 00000000000..9403ed19904 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_client.go @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + + "azureaiagent/internal/pkg/azure" +) + +// toolboxClient is the subset of *azure.FoundryToolboxClient that the toolbox +// command implementations rely on. Defining it as an interface lets unit tests +// inject mock implementations without spinning up an HTTP server. +// +// The real *azure.FoundryToolboxClient satisfies this interface directly. +type toolboxClient interface { + GetToolbox(ctx context.Context, name string) (*azure.ToolboxObject, error) + CreateToolboxVersion( + ctx context.Context, name string, req *azure.CreateToolboxVersionRequest, + ) (*azure.ToolboxVersionObject, error) + DeleteToolbox(ctx context.Context, name string) error + + ListToolboxes(ctx context.Context) ([]azure.ToolboxObject, error) + GetToolboxVersion( + ctx context.Context, name, version string, + ) (*azure.ToolboxVersionObject, error) + ListToolboxVersions( + ctx context.Context, name string, + ) ([]azure.ToolboxVersionObject, error) + DeleteToolboxVersion(ctx context.Context, name, version string) error + SetDefaultVersion( + ctx context.Context, name, version string, + ) (*azure.ToolboxObject, error) + + // Endpoint returns the project endpoint root this client is bound to. + // Used by `toolbox show` to compute the runtime MCP consumption URL. + Endpoint() string +} + +// compile-time guard: the real client must satisfy the interface. +var _ toolboxClient = (*azure.FoundryToolboxClient)(nil) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_commands_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_commands_test.go new file mode 100644 index 00000000000..a0fb0b85e89 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_commands_test.go @@ -0,0 +1,416 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "errors" + "testing" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRunToolboxDeleteWith covers the live + per-version branches of +// runToolboxDeleteWith that do not depend on the azd config store. The +// pending-only path is exercised indirectly through TestEndpointBucketKey. +func TestRunToolboxDeleteWith_Branches(t *testing.T) { + t.Run("not_found_no_pending_returns_validation_error", func(t *testing.T) { + // The default getResults returns NotFound for unknown names. + client := newMockToolboxClient("https://e/") + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "missing", + toolboxDeleteFlags{version: "1", force: true}, toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeToolboxNotFound) + assert.Empty(t, client.deleteVersionCalls) + }) + + t.Run("version_is_default_with_others_blocks_with_retarget_suggestion", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "2", + }} + client.listVersionsResults["tb"] = []azure.ToolboxVersionObject{ + {Name: "tb", Version: "1"}, {Name: "tb", Version: "2"}, + } + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{version: "2", force: true}, toolboxFlags{output: "table"}, + ) + le := requireLocalError(t, err, exterrors.CodeDefaultVersionDelete) + assert.Contains(t, le.Suggestion, "default-version") + assert.Empty(t, client.deleteVersionCalls, "service must not be called") + }) + + t.Run("version_is_only_remaining_without_force_blocks", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + client.listVersionsResults["tb"] = []azure.ToolboxVersionObject{ + {Name: "tb", Version: "1"}, + } + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{version: "1", force: false}, toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeOnlyVersionDelete) + assert.Empty(t, client.deleteVersionCalls) + }) + + t.Run("version_is_only_remaining_with_force_proceeds", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + client.listVersionsResults["tb"] = []azure.ToolboxVersionObject{ + {Name: "tb", Version: "1"}, + } + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{version: "1", force: true}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.deleteVersionCalls, 1) + assert.Equal(t, "tb", client.deleteVersionCalls[0].name) + assert.Equal(t, "1", client.deleteVersionCalls[0].version) + }) + + t.Run("non_default_version_with_force_proceeds", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "5", + }} + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{version: "3", force: true}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.deleteVersionCalls, 1) + assert.Equal(t, "3", client.deleteVersionCalls[0].version) + }) + + // Non-default version delete has no confirmation prompt. + t.Run("non_default_version_without_force_does_not_prompt", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "5", + }} + err := runDeleteToolboxVersion( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{version: "3", force: false}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.deleteVersionCalls, 1, + "non-default version delete must proceed without prompting") + }) +} + +func TestRunToolboxDelete_NoPromptWithoutForce(t *testing.T) { + client := newMockToolboxClient("https://e/") + // Parent-toolbox delete with --no-prompt and no --force must reject. + err := runDeleteToolbox( + t.Context(), client, "https://e/", "tb", + toolboxDeleteFlags{}, + toolboxFlags{output: "table", noPrompt: true}, + ) + requireLocalError(t, err, exterrors.CodeMissingForceFlag) +} + +func TestRunToolboxDelete_InvalidName(t *testing.T) { + err := runToolboxDelete( + t.Context(), "bad/name", + toolboxDeleteFlags{force: true}, + toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeInvalidToolboxName) +} + +func TestRunToolboxShowWith_LiveAndVersionMissing(t *testing.T) { + t.Run("default version live happy path", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "x", "project_connection_id": "/c/x"}, + }, + }} + err := runToolboxShowWith( + t.Context(), client, "https://e/", "tb", + toolboxShowFlags{}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + }) + + t.Run("explicit version missing returns validation error", func(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + err := runToolboxShowWith( + t.Context(), client, "https://e/", "tb", + toolboxShowFlags{version: "9"}, toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeToolboxNotFound) + }) +} + +func TestRunToolboxListWith_MergesNoPending(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.listToolboxesResult = []azure.ToolboxObject{ + {Name: "alpha", DefaultVersion: "1"}, + {Name: "beta", DefaultVersion: "2"}, + } + client.versionResults["alpha/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "alpha", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "t1"}, + }, + }} + client.versionResults["beta/2"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "beta", Version: "2", Tools: []map[string]any{}, + }} + + err := runToolboxListWith( + t.Context(), client, "https://e/", toolboxFlags{output: "json"}, + ) + require.NoError(t, err) +} + +func TestRunConnectionAddWith_DuplicateRejected(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "x", "project_connection_id": "/c/x"}, + }, + }} + + resolver := newStubConnectionResolver() + resolver.byName["x"] = &projectConnection{ + ID: "/c/x", Category: azure.ConnectionTypeRemoteTool, Name: "x", Target: "https://mcp", + } + + err := runConnectionAddWith( + t.Context(), client, resolver, newStubPendingStore(), "https://e/", + "tb", "x", connectionAddFlags{}, toolboxFlags{output: "json"}, + ) + requireLocalError(t, err, exterrors.CodeDuplicateConnection) + assert.Empty(t, client.createVersionCalls, "must not POST on duplicate") +} + +func TestRunConnectionAddWith_AppendsAndPromotesDefault(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{ + Name: "tb", DefaultVersion: "1", + }} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Description: "first", Tools: []map[string]any{ + {"type": "mcp", "name": "a", "project_connection_id": "/c/a"}, + }, + }} + + resolver := newStubConnectionResolver() + resolver.byName["b"] = &projectConnection{ + ID: "/c/b", Category: azure.ConnectionTypeRemoteTool, Name: "b", Target: "https://mcp-b", + } + + err := runConnectionAddWith( + t.Context(), client, resolver, newStubPendingStore(), "https://e/", + "tb", "b", connectionAddFlags{}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.createVersionCalls, 1) + assert.Equal(t, "first", client.createVersionCalls[0].req.Description, "description carried forward") + assert.Len(t, client.createVersionCalls[0].req.Tools, 2) + require.Len(t, client.setDefaultCalls, 1, "default_version must be retargeted") +} + +func TestRunConnectionAddWith_ConnectionNotFound(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{Name: "tb", DefaultVersion: "1"}} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{}, + }} + + resolver := newStubConnectionResolver() + err := runConnectionAddWith( + t.Context(), client, resolver, newStubPendingStore(), "https://e/", + "tb", "missing", connectionAddFlags{}, toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeConnectionNotFound) + assert.Empty(t, client.createVersionCalls) +} + +func TestRunConnectionRemoveWith_LastToolBlocks(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{Name: "tb", DefaultVersion: "1"}} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "a", "project_connection_id": "/c/a"}, + }, + }} + resolver := newStubConnectionResolver() + resolver.byName["a"] = &projectConnection{ + ID: "/c/a", Category: azure.ConnectionTypeRemoteTool, Name: "a", + } + + err := runConnectionRemoveWith( + t.Context(), client, resolver, "https://e/", + "tb", "a", toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeLastToolRemoval) + assert.Empty(t, client.createVersionCalls) +} + +func TestRunConnectionRemoveWith_FilteredAndPromoted(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{Name: "tb", DefaultVersion: "1"}} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "a", "project_connection_id": "/c/a"}, + {"type": "mcp", "name": "b", "project_connection_id": "/c/b"}, + }, + }} + resolver := newStubConnectionResolver() + resolver.byName["a"] = &projectConnection{ + ID: "/c/a", Category: azure.ConnectionTypeRemoteTool, Name: "a", + } + + err := runConnectionRemoveWith( + t.Context(), client, resolver, "https://e/", + "tb", "a", toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.createVersionCalls, 1) + assert.Len(t, client.createVersionCalls[0].req.Tools, 1) + require.Len(t, client.setDefaultCalls, 1) +} + +func TestRunConnectionRemoveWith_ConnectionNotInToolbox(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{Name: "tb", DefaultVersion: "1"}} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "other", "project_connection_id": "/c/other"}, + }, + }} + resolver := newStubConnectionResolver() + resolver.byName["a"] = &projectConnection{ + ID: "/c/a", Category: azure.ConnectionTypeRemoteTool, Name: "a", + } + + err := runConnectionRemoveWith( + t.Context(), client, resolver, "https://e/", + "tb", "a", toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeConnectionNotInToolbox) +} + +func TestRunConnectionListWith_EmitsAllShapes(t *testing.T) { + client := newMockToolboxClient("https://e/") + client.getResults["tb"] = toolboxGetResult{obj: &azure.ToolboxObject{Name: "tb", DefaultVersion: "1"}} + client.versionResults["tb/1"] = toolboxVersionResult{obj: &azure.ToolboxVersionObject{ + Name: "tb", Version: "1", Tools: []map[string]any{ + {"type": "mcp", "name": "m", "project_connection_id": "/conn/m"}, + { + "type": "azure_ai_search", + "name": "s", + "azure_ai_search": map[string]any{ + "indexes": []any{ + map[string]any{"project_connection_id": "/conn/s", "index_name": "i"}, + }, + }, + }, + {"type": "code_interpreter", "name": "ci"}, // not surfaced + }, + }} + + err := runConnectionListWith( + t.Context(), client, "tb", toolboxFlags{output: "json"}, + ) + require.NoError(t, err) +} + +func TestRunToolboxUpdate_MissingDefaultVersion(t *testing.T) { + err := runToolboxUpdate( + t.Context(), "tb", + toolboxUpdateFlags{}, + toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodeMissingUpdateField) +} + +// Pending-record promotion path: POST v1 with the carried-forward +// description, then clear the record. +func TestRunConnectionAddWith_PendingPromotion(t *testing.T) { + client := newMockToolboxClient("https://e/") + resolver := newStubConnectionResolver() + resolver.byName["my-mcp"] = &projectConnection{ + ID: "/c/my-mcp", + Category: azure.ConnectionTypeRemoteTool, + Name: "my-mcp", + Target: "https://mcp.example.com", + } + + store := newStubPendingStore() + store.records[store.key("https://e/", "tb")] = &PendingToolbox{ + Description: "Research-time toolset", + CreatedAt: "2026-05-12T10:23:00Z", + } + + err := runConnectionAddWith( + t.Context(), client, resolver, store, "https://e/", + "tb", "my-mcp", connectionAddFlags{}, toolboxFlags{output: "json"}, + ) + require.NoError(t, err) + require.Len(t, client.createVersionCalls, 1, "v1 must be POSTed") + assert.Equal(t, "Research-time toolset", client.createVersionCalls[0].req.Description, + "description from pending record must be carried forward") + assert.Len(t, client.createVersionCalls[0].req.Tools, 1) + assert.Empty(t, client.setDefaultCalls, "first version is default automatically; no PATCH") + assert.Equal(t, 1, store.clearCalls, "pending record must be cleared") + assert.Empty(t, store.records, "pending record must be removed after success") +} + +// A pending-store read failure must surface as Internal, not silently fall +// through to a misleading CodeToolboxNotFound. +func TestRunConnectionAddWith_PendingStoreFailureSurfaces(t *testing.T) { + client := newMockToolboxClient("https://e/") + resolver := newStubConnectionResolver() + resolver.byName["c"] = &projectConnection{ + ID: "/c/c", Category: azure.ConnectionTypeRemoteTool, Name: "c", + Target: "https://mcp.example.com", + } + + store := newStubPendingStore() + store.getErr = errors.New("config read failed") + + err := runConnectionAddWith( + t.Context(), client, resolver, store, "https://e/", + "tb", "c", connectionAddFlags{}, toolboxFlags{output: "table"}, + ) + requireLocalError(t, err, exterrors.CodePendingToolboxStoreFailed) + assert.Empty(t, client.createVersionCalls, + "existing-toolbox branch must not be entered when the pending store fails") +} + +// Client-side ^[A-Za-z0-9_-]+$ enforcement on tool entry names. +func TestBuildToolEntry_RejectsInvalidName(t *testing.T) { + _, err := buildToolEntry(&projectConnection{ + ID: "/c/x", + Category: azure.ConnectionTypeRemoteTool, + Name: "tools.v1", // dot is not in ^[A-Za-z0-9_-]+$ + Target: "https://mcp", + }, "") + le := requireLocalError(t, err, exterrors.CodeInvalidToolboxName) + assert.Contains(t, le.Message, "tool entry name") + assert.Contains(t, le.Message, "tools.v1") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection.go new file mode 100644 index 00000000000..b61d6323d8c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection.go @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "fmt" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// newToolboxConnectionCommand returns the `azd ai agent toolbox connection` parent. +func newToolboxConnectionCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + cmd := &cobra.Command{ + Use: "connection", + Short: "Manage the connection-backed tools attached to a toolbox.", + Long: `Manage the connection-backed tools attached to a toolbox. + +Tools are project connections (MCP servers via RemoteTool, or Azure AI Search +indexes via CognitiveSearch). Each mutation publishes a new immutable version +and retargets the toolbox default.`, + } + cmd.AddCommand(newToolboxConnectionAddCommand(extCtx)) + cmd.AddCommand(newToolboxConnectionRemoveCommand(extCtx)) + cmd.AddCommand(newToolboxConnectionListCommand(extCtx)) + return cmd +} + +// buildToolEntry returns the tool-entry map appropriate for the connection's +// category. Enforces the --index flag rules and the `tool.name` regex. +func buildToolEntry(conn *projectConnection, index string) (map[string]any, error) { + if err := validateToolName(conn.Name); err != nil { + return nil, err + } + switch conn.Category { + case azure.ConnectionTypeRemoteTool: + if index != "" { + return nil, exterrors.Validation( + exterrors.CodeUnsupportedIndexFlag, + fmt.Sprintf( + "--index is only valid for CognitiveSearch connections, "+ + "connection %q has category %q", + conn.Name, conn.Category, + ), + "omit --index for RemoteTool (MCP) connections", + ) + } + // Reject locally rather than letting the service produce a generic 400. + if strings.TrimSpace(conn.Target) == "" { + return nil, exterrors.Validation( + exterrors.CodeConnectionMissingTarget, + fmt.Sprintf( + "connection %q is a RemoteTool but has no target URL", + conn.Name, + ), + "set the target on the project connection (this is the MCP server URL)", + ) + } + return map[string]any{ + "type": "mcp", + "name": conn.Name, + "server_label": conn.Name, + "server_url": conn.Target, + "project_connection_id": conn.ID, + }, nil + + case azure.ConnectionTypeCognitiveSearch: + if strings.TrimSpace(index) == "" { + return nil, exterrors.Validation( + exterrors.CodeMissingIndex, + fmt.Sprintf( + "connection %q is a CognitiveSearch connection; --index is required", + conn.Name, + ), + "pass --index with the search index to attach", + ) + } + return map[string]any{ + "type": "azure_ai_search", + "name": conn.Name, + "azure_ai_search": map[string]any{ + "indexes": []any{ + map[string]any{ + "project_connection_id": conn.ID, + "index_name": index, + }, + }, + }, + }, nil + + default: + return nil, exterrors.Validation( + exterrors.CodeUnsupportedConnectionCategory, + fmt.Sprintf( + "connection %q has category %q which is not supported as a toolbox tool today; "+ + "v1 supports RemoteTool (MCP) and CognitiveSearch (Azure AI Search) only", + conn.Name, conn.Category, + ), + "use a RemoteTool (MCP) or CognitiveSearch (Azure AI Search) connection, "+ + "or file an issue requesting support for the connection category you need", + ) + } +} + +// duplicateConnectionInTools reports whether any tool entry already references +// the given project_connection_id. +func duplicateConnectionInTools(tools []map[string]any, connID string) bool { + found := false + forEachToolConnectionID(tools, func(id string) bool { + if id == connID { + found = true + return true + } + return false + }) + return found +} + +// filterOutConnection returns tools[] with every entry whose +// project_connection_id matches connID stripped (top-level and nested forms). +// `removed` reports whether at least one entry was filtered. +func filterOutConnection(tools []map[string]any, connID string) (result []map[string]any, removed bool) { + for _, t := range tools { + if toolEntryReferences(t, func(id string) bool { return id == connID }) { + removed = true + continue + } + result = append(result, t) + } + return result, removed +} + +// shortConnectionName extracts the connection's short name from the trailing +// segment of its ARM ID (e.g. ".../connections/my-mcp" → "my-mcp"). Falls back +// to the full id when no slash is present. +func shortConnectionName(id string) string { + if id == "" { + return "" + } + if i := strings.LastIndex(id, "/"); i >= 0 && i < len(id)-1 { + return id[i+1:] + } + return id +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_add.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_add.go new file mode 100644 index 00000000000..36fbda1c1c8 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_add.go @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// connectionAddFlags carries the verb-specific flags for `connection add`. +type connectionAddFlags struct { + index string +} + +// newToolboxConnectionAddCommand returns the `connection add` command. +func newToolboxConnectionAddCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + flags := &connectionAddFlags{} + + cmd := &cobra.Command{ + Use: "add ", + Short: "Attach a project connection to a toolbox.", + Long: `Attach a project connection to a toolbox. + +The tool entry shape is inferred from the connection's category: + RemoteTool → mcp tool wired to the connection's MCP server URL + CognitiveSearch → azure_ai_search tool (requires --index) + Note: "CognitiveSearch" is the category for Azure AI Search. +Other categories are rejected. + +If the toolbox has a local pending record (from 'toolbox create'), v1 is +published with this connection as the only tool. Otherwise the current +default version is fetched, the tool entry is appended, and a new default +version is published.`, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + return runConnectionAdd( + cmd.Context(), args[0], args[1], *flags, + readToolboxFlags(cmd, extCtx), + defaultConnectionResolver{}, + ) + }, + } + + cmd.Flags().StringVar( + &flags.index, "index", "", + "Index name (required when the connection's category is CognitiveSearch, i.e. Azure AI Search).", + ) + registerToolboxOutputFlag(cmd) + return cmd +} + +func runConnectionAdd( + ctx context.Context, toolboxName, connName string, + verb connectionAddFlags, parent toolboxFlags, + resolver connectionResolver, +) error { + if err := validateToolboxName(toolboxName); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + if strings.TrimSpace(connName) == "" { + return exterrors.Validation( + exterrors.CodeInvalidPositionalArg, + " must not be empty", + "pass the short name of a project connection", + ) + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox connection add", resolved) + + store, closer, err := newAzdPendingToolboxStore() + if err != nil { + return exterrors.Internal(exterrors.CodeAzdClientFailed, + fmt.Sprintf("failed to open the pending-toolbox store: %s", err)) + } + defer closer() + + return runConnectionAddWith(ctx, client, resolver, store, resolved.Endpoint, + toolboxName, connName, verb, parent) +} + +// runConnectionAddWith is the testable core. +func runConnectionAddWith( + ctx context.Context, client toolboxClient, resolver connectionResolver, + store pendingToolboxStore, + endpoint, toolboxName, connName string, + verb connectionAddFlags, parent toolboxFlags, +) error { + conn, err := resolver.resolveConnection(ctx, endpoint, connName) + if err != nil { + return err + } + + entry, err := buildToolEntry(conn, verb.index) + if err != nil { + return err + } + + // Pending-promotion path: if a pending record exists, POST v1 directly. + // A store-read failure must not silently fall through to the live-toolbox + // branch (which would 404 and report CodeToolboxNotFound). + pending, err := store.Get(ctx, endpoint, toolboxName) + if err != nil { + return exterrors.Internal( + exterrors.CodePendingToolboxStoreFailed, + fmt.Sprintf("failed to read pending toolbox state: %s", err), + ) + } + if pending != nil { + req := &azure.CreateToolboxVersionRequest{ + Description: pending.Description, + Tools: []map[string]any{entry}, + } + created, err := client.CreateToolboxVersion(ctx, toolboxName, req) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpCreateToolboxVersion) + } + // The version has been published. A pending-store clear failure is + // non-fatal: log it but proceed so the user sees the success path. + if _, err := store.Clear(ctx, endpoint, toolboxName); err != nil { + log.Printf( + "toolbox connection add: %q v%s was published, but the local pending record could not be cleared: %v", + toolboxName, created.Version, err, + ) + } + return emitConnectionAddResult(toolboxName, created.Version, conn, parent.output, true, endpoint) + } + + // Existing-toolbox path: fetch default → append → POST → PATCH default_version. + tb, err := client.GetToolbox(ctx, toolboxName) + if err != nil { + if isAzureNotFound(err) { + return exterrors.Dependency( + exterrors.CodeToolboxNotFound, + fmt.Sprintf("toolbox %q not found", toolboxName), + fmt.Sprintf( + "run 'azd ai agent toolbox create %q' first, then re-run 'connection add'", + toolboxName, + ), + ) + } + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolbox) + } + + current, err := client.GetToolboxVersion(ctx, toolboxName, tb.DefaultVersion) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolboxVersion) + } + + if duplicateConnectionInTools(current.Tools, conn.ID) { + return exterrors.Validation( + exterrors.CodeDuplicateConnection, + fmt.Sprintf( + "connection %q (%s) is already attached to toolbox %q", + connName, conn.ID, toolboxName, + ), + fmt.Sprintf("use 'connection list %q' to inspect current tools", toolboxName), + ) + } + + newTools := slices.Clone(current.Tools) + newTools = append(newTools, entry) + + req := &azure.CreateToolboxVersionRequest{ + Description: current.Description, + Metadata: current.Metadata, + Tools: newTools, + } + created, err := client.CreateToolboxVersion(ctx, toolboxName, req) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpCreateToolboxVersion) + } + + if _, err := client.SetDefaultVersion(ctx, toolboxName, created.Version); err != nil { + // The new version exists but isn't the default. Surface this so the + // user can recover with `toolbox update --default-version ` rather + // than silently losing the connection add. + return exterrors.Dependency( + exterrors.CodeSetDefaultVersionFailed, + fmt.Sprintf( + "toolbox %q version %q was created but could not be promoted to default: %s", + toolboxName, created.Version, err, + ), + fmt.Sprintf( + "run `azd ai agent toolbox update %q --default-version %q` to retarget the default", + toolboxName, created.Version, + ), + ) + } + + return emitConnectionAddResult(toolboxName, created.Version, conn, parent.output, false, endpoint) +} + +// emitConnectionAddResult prints the standard output for a successful add. The +// resolved endpoint is included so the user can paste it into agent code. +func emitConnectionAddResult( + toolboxName, newVersion string, conn *projectConnection, output string, promoted bool, endpoint string, +) error { + mcpURL := buildToolboxMcpURL(endpoint, toolboxName, newVersion) + if output == "json" { + payload := map[string]any{ + "toolbox": toolboxName, + "version": newVersion, + "connection": conn.Name, + "connectionId": conn.ID, + "category": string(conn.Category), + "promotedFromPending": promoted, + "endpoint": mcpURL, + } + return emitJSON(payload) + } + if promoted { + fmt.Printf("Published toolbox %s version %s with connection %s.\n", + toolboxName, newVersion, conn.Name) + } else { + fmt.Printf("Attached connection %s to toolbox %s (now at version %s).\n", + conn.Name, toolboxName, newVersion) + } + // Surface the MCP endpoint so the dev can wire it into agent code. Suggest + // `azd env set` when running inside an azd project. + envVar := strings.ReplaceAll(strings.ToUpper(toolboxName), "-", "_") + "_MCP_ENDPOINT" + fmt.Printf("\nEndpoint: %s\n", mcpURL) + fmt.Printf("Save it as an env var:\n azd env set %s %s\n", envVar, mcpURL) + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_list.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_list.go new file mode 100644 index 00000000000..fd0ec9a877b --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_list.go @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "os" + "text/tabwriter" + + "azureaiagent/internal/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// newToolboxConnectionListCommand returns the `connection list` command. +func newToolboxConnectionListCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + + cmd := &cobra.Command{ + Use: "list ", + Short: "List the connection-backed tools attached to a toolbox.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runConnectionList(cmd.Context(), args[0], readToolboxFlags(cmd, extCtx)) + }, + } + registerToolboxOutputFlag(cmd) + return cmd +} + +func runConnectionList(ctx context.Context, toolboxName string, parent toolboxFlags) error { + if err := validateToolboxName(toolboxName); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox connection list", resolved) + + return runConnectionListWith(ctx, client, toolboxName, parent) +} + +func runConnectionListWith( + ctx context.Context, client toolboxClient, toolboxName string, parent toolboxFlags, +) error { + tb, err := client.GetToolbox(ctx, toolboxName) + if err != nil { + return toolboxNotFoundOrService(err, toolboxName, exterrors.OpGetToolbox) + } + + version, err := client.GetToolboxVersion(ctx, toolboxName, tb.DefaultVersion) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolboxVersion) + } + + connections := extractConnectionTools(version.Tools) + + if parent.output == "json" { + return emitJSON(map[string]any{"connections": connections}) + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "NAME\tCONNECTION\tTYPE") + fmt.Fprintln(w, "----\t----------\t----") + for _, c := range connections { + fmt.Fprintf(w, "%s\t%s\t%s\n", c["name"], c["connection"], c["type"]) + } + return w.Flush() +} + +// extractConnectionTools collapses the tool list to one row per connection-backed +// entry, surfacing the short connection name parsed from the trailing segment +// of the connection ARM ID (the `connection` column in `connection list`). +func extractConnectionTools(tools []map[string]any) []map[string]string { + rows := []map[string]string{} + for _, t := range tools { + toolType, _ := t["type"].(string) + toolName, _ := t["name"].(string) + switch toolType { + case "mcp": + if id, ok := t["project_connection_id"].(string); ok && id != "" { + rows = append(rows, map[string]string{ + "name": toolName, + "connection": shortConnectionName(id), + "connectionId": id, + "type": toolType, + }) + } + case "azure_ai_search": + if search, ok := t["azure_ai_search"].(map[string]any); ok { + if indexes, ok := search["indexes"].([]any); ok { + for _, idx := range indexes { + m, _ := idx.(map[string]any) + if m == nil { + continue + } + id, _ := m["project_connection_id"].(string) + idxName, _ := m["index_name"].(string) + rows = append(rows, map[string]string{ + "name": toolName, + "connection": shortConnectionName(id), + "connectionId": id, + "type": toolType, + "index": idxName, + }) + } + } + } + } + } + return rows +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_remove.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_remove.go new file mode 100644 index 00000000000..a3300e71db6 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_remove.go @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// newToolboxConnectionRemoveCommand returns the `connection remove` command. +func newToolboxConnectionRemoveCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + + cmd := &cobra.Command{ + Use: "remove ", + Short: "Detach a project connection from a toolbox.", + Long: `Detach a project connection from a toolbox. + +Publishes a new default version with the named connection's tool entry +removed. Refuses to leave the toolbox with zero tools (use 'toolbox delete' +instead).`, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + return runConnectionRemove( + cmd.Context(), args[0], args[1], + readToolboxFlags(cmd, extCtx), + defaultConnectionResolver{}, + ) + }, + } + registerToolboxOutputFlag(cmd) + return cmd +} + +func runConnectionRemove( + ctx context.Context, toolboxName, connName string, + parent toolboxFlags, resolver connectionResolver, +) error { + if err := validateToolboxName(toolboxName); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + if strings.TrimSpace(connName) == "" { + return exterrors.Validation( + exterrors.CodeInvalidPositionalArg, + " must not be empty", + "pass the short name of a project connection", + ) + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox connection remove", resolved) + + return runConnectionRemoveWith(ctx, client, resolver, resolved.Endpoint, + toolboxName, connName, parent) +} + +func runConnectionRemoveWith( + ctx context.Context, client toolboxClient, resolver connectionResolver, + endpoint, toolboxName, connName string, parent toolboxFlags, +) error { + conn, err := resolver.resolveConnection(ctx, endpoint, connName) + if err != nil { + return err + } + + tb, err := client.GetToolbox(ctx, toolboxName) + if err != nil { + return toolboxNotFoundOrService(err, toolboxName, exterrors.OpGetToolbox) + } + + current, err := client.GetToolboxVersion(ctx, toolboxName, tb.DefaultVersion) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolboxVersion) + } + + filtered, removed := filterOutConnection(current.Tools, conn.ID) + if !removed { + return exterrors.Validation( + exterrors.CodeConnectionNotInToolbox, + fmt.Sprintf( + "connection %q is not attached to toolbox %q's current default version", + connName, toolboxName, + ), + fmt.Sprintf("run 'azd ai agent toolbox connection list %q'", toolboxName), + ) + } + if len(filtered) == 0 { + return exterrors.Validation( + exterrors.CodeLastToolRemoval, + fmt.Sprintf( + "removing %q would leave toolbox %q with zero tools", + connName, toolboxName, + ), + fmt.Sprintf( + "delete the toolbox with `azd ai agent toolbox delete %q` instead", + toolboxName, + ), + ) + } + + req := &azure.CreateToolboxVersionRequest{ + Description: current.Description, + Metadata: current.Metadata, + Tools: filtered, + } + created, err := client.CreateToolboxVersion(ctx, toolboxName, req) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpCreateToolboxVersion) + } + if _, err := client.SetDefaultVersion(ctx, toolboxName, created.Version); err != nil { + return exterrors.Dependency( + exterrors.CodeSetDefaultVersionFailed, + fmt.Sprintf( + "toolbox %q version %q was created but could not be promoted to default: %s", + toolboxName, created.Version, err, + ), + fmt.Sprintf( + "run `azd ai agent toolbox update %q --default-version %q` to retarget the default", + toolboxName, created.Version, + ), + ) + } + + return emitConnectionRemoveResult(toolboxName, created.Version, conn, parent.output) +} + +func emitConnectionRemoveResult( + toolboxName, newVersion string, conn *projectConnection, output string, +) error { + if output == "json" { + payload := map[string]any{ + "toolbox": toolboxName, + "version": newVersion, + "connection": conn.Name, + "connectionId": conn.ID, + } + return emitJSON(payload) + } + fmt.Printf( + "Detached connection %s from toolbox %s (now at version %s).\n", + conn.Name, toolboxName, newVersion, + ) + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_resolver.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_resolver.go new file mode 100644 index 00000000000..d38335d3120 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_connection_resolver.go @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" +) + +// projectConnection is the minimal slice of an Azure project connection that +// toolbox commands need: the ARM `id` (used as `project_connection_id`), the +// category (drives the tool-entry shape), the short name, and the data-plane +// `target` (becomes `server_url` on MCP tool entries). +type projectConnection struct { + ID string + Category azure.ConnectionType + Name string + Target string +} + +// connectionResolver is the seam that tests substitute with stubConnectionResolver. +type connectionResolver interface { + resolveConnection(ctx context.Context, endpoint, name string) (*projectConnection, error) +} + +type defaultConnectionResolver struct{} + +func (defaultConnectionResolver) resolveConnection( + ctx context.Context, endpoint, name string, +) (*projectConnection, error) { + client, err := newProjectsClientFromEndpoint(endpoint) + if err != nil { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("failed to build a project client for %s: %s", endpoint, err), + "verify the project endpoint is well-formed", + ) + } + + conn, err := client.GetConnection(ctx, name) + if err != nil { + if isAzureNotFound(err) { + return nil, connectionNotFoundError(name) + } + return nil, exterrors.ServiceFromAzure(err, exterrors.OpResolveProjectConnection) + } + + return &projectConnection{ + ID: conn.ID, + Category: conn.Type, + Name: conn.Name, + Target: conn.Target, + }, nil +} + +func connectionNotFoundError(name string) error { + return exterrors.Validation( + exterrors.CodeConnectionNotFound, + fmt.Sprintf("connection %q was not found on the project", name), + "run `azd ai connection list` to see available connections", + ) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_context.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_context.go new file mode 100644 index 00000000000..da01c87b20a --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_context.go @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "fmt" + "log" + "strings" + + "azureaiagent/internal/pkg/azure" +) + +func trimEndpoint(s string) string { + return strings.TrimRight(strings.TrimSpace(s), "/") +} + +// newToolboxClient builds a FoundryToolboxClient bound to the resolved endpoint. +func newToolboxClient(endpoint string) (*azure.FoundryToolboxClient, error) { + cred, err := newAgentCredential() + if err != nil { + return nil, err + } + return azure.NewFoundryToolboxClient(endpoint, cred), nil +} + +// newProjectsClientFromEndpoint builds a FoundryProjectsClient bound to the +// account+project parsed out of the toolbox endpoint URL. +func newProjectsClientFromEndpoint(endpoint string) (*azure.FoundryProjectsClient, error) { + account, project, err := parseAccountProjectFromEndpoint(endpoint) + if err != nil { + return nil, err + } + cred, err := newAgentCredential() + if err != nil { + return nil, err + } + return azure.NewFoundryProjectsClient(account, project, cred) +} + +// parseAccountProjectFromEndpoint extracts account + project names from an endpoint +// formatted as `https://.services.ai.azure.com/api/projects/` (with +// optional trailing path). +func parseAccountProjectFromEndpoint(endpoint string) (account, project string, err error) { + trimmed := trimEndpoint(endpoint) + const marker = ".services.ai.azure.com/api/projects/" + before, after, ok := strings.Cut(trimmed, marker) + if !ok { + return "", "", fmt.Errorf( + "endpoint %q does not match the expected pattern .services.ai.azure.com/api/projects/", + endpoint, + ) + } + hostPart := before + if schemeIdx := strings.Index(hostPart, "://"); schemeIdx >= 0 { + hostPart = hostPart[schemeIdx+3:] + } + rest := after + projectName := rest + if before, _, ok := strings.Cut(rest, "/"); ok { + projectName = before + } + if hostPart == "" || projectName == "" { + return "", "", fmt.Errorf("endpoint %q is missing the account or project segment", endpoint) + } + return hostPart, projectName, nil +} + +// logResolvedEndpoint records the resolved endpoint and source to --debug. +func logResolvedEndpoint(verb string, r *resolvedEndpoint) { + if r == nil { + return + } + log.Printf("%s: resolved project endpoint %s (source=%s)", verb, r.Endpoint, r.Source) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_create.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_create.go new file mode 100644 index 00000000000..c5f59dbdf75 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_create.go @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "time" + + "azureaiagent/internal/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// toolboxCreateFlags holds the verb-specific flags for `toolbox create`. +type toolboxCreateFlags struct { + description string +} + +// newToolboxCreateCommand returns the `azd ai agent toolbox create ` command. +// `create` records a local pending entry; v1 is POSTed on the first +// `connection add`. +func newToolboxCreateCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + flags := &toolboxCreateFlags{} + + cmd := &cobra.Command{ + Use: "create ", + Short: "Register a new toolbox locally (publishes on first `connection add`).", + Long: `Register a new toolbox locally. + +A toolbox must have at least one tool before it can be published, so 'create' +only records a local pending entry. The first 'connection add' against the +same toolbox name publishes v1 and clears the pending record.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runToolboxCreate(cmd.Context(), args[0], *flags, readToolboxFlags(cmd, extCtx)) + }, + } + + cmd.Flags().StringVar( + &flags.description, "description", "", + "Optional description recorded with the toolbox.", + ) + registerToolboxOutputFlag(cmd) + + return cmd +} + +func runToolboxCreate( + ctx context.Context, name string, verb toolboxCreateFlags, parent toolboxFlags, +) error { + if err := validateToolboxName(name); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + resolved, err := resolveProjectEndpoint(ctx, resolveProjectEndpointOpts{FlagValue: parent.projectEndpoint}) + if err != nil { + return err + } + logResolvedEndpoint("toolbox create", resolved) + + // Check whether the toolbox already exists on the service. + client, err := newToolboxClient(resolved.Endpoint) + if err != nil { + return err + } + + if _, err := client.GetToolbox(ctx, name); err == nil { + return emitCreateResult(name, true /* alreadyExists */, parent.output, verb, resolved.Endpoint) + } else if !isAzureNotFound(err) { + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolbox) + } + + // New name → record a pending entry. + if err := withAzdClient(func(azdClient *azdext.AzdClient) error { + record := PendingToolbox{ + Description: verb.description, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + } + if err := setPendingToolbox(ctx, azdClient, resolved.Endpoint, name, record); err != nil { + return exterrors.Internal(exterrors.CodePendingToolboxStoreFailed, err.Error()) + } + return nil + }); err != nil { + return err + } + + return emitCreateResult(name, false, parent.output, verb, resolved.Endpoint) +} + +// emitCreateResult prints the standard one-liner or JSON envelope. +func emitCreateResult( + name string, alreadyExists bool, output string, verb toolboxCreateFlags, endpoint string, +) error { + if output == "json" { + payload := map[string]any{ + "toolbox": map[string]any{ + "name": name, + "pending": !alreadyExists, + "description": verb.description, + }, + "endpoint": endpoint, + "alreadyExists": alreadyExists, + } + return emitJSON(payload) + } + + if alreadyExists { + fmt.Printf("Toolbox %s already exists.\n", name) + fmt.Println("Next steps:") + fmt.Println(" - Run 'azd ai agent toolbox connection add' to publish a new version.") + fmt.Println(" - Run 'azd ai agent toolbox update --default-version ' to retarget the default.") + return nil + } + fmt.Printf("Registered toolbox %s (pending tools).\n", name) + fmt.Println("Next step:") + fmt.Printf(" Run 'azd ai agent toolbox connection add %s ' to publish v1.\n", name) + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_delete.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_delete.go new file mode 100644 index 00000000000..4a6c4b326a6 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_delete.go @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "log" + + "azureaiagent/internal/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// toolboxDeleteFlags carries the verb-specific flags for `toolbox delete`. +type toolboxDeleteFlags struct { + version string + force bool +} + +// newToolboxDeleteCommand returns the `azd ai agent toolbox delete ` command. +func newToolboxDeleteCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + flags := &toolboxDeleteFlags{} + + cmd := &cobra.Command{ + Use: "delete ", + Short: "Delete a toolbox or a single version.", + Long: `Delete a toolbox or one of its versions. + +Without --version the whole toolbox is removed (cascades to every version). +With --version the named version is deleted; the CLI refuses to delete the +default version while others exist (retarget first) or — without --force — +when it is the only remaining version (which would cascade and remove the +toolbox).`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runToolboxDelete(cmd.Context(), args[0], *flags, readToolboxFlags(cmd, extCtx)) + }, + } + + cmd.Flags().StringVar( + &flags.version, "version", "", + "Delete a single version instead of the whole toolbox.", + ) + cmd.Flags().BoolVar( + &flags.force, "force", false, + "Skip confirmation prompts and override safety checks where allowed.", + ) + registerToolboxOutputFlag(cmd) + + return cmd +} + +func runToolboxDelete( + ctx context.Context, name string, verb toolboxDeleteFlags, parent toolboxFlags, +) error { + if err := validateToolboxName(name); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox delete", resolved) + + return runToolboxDeleteWith(ctx, client, resolved.Endpoint, name, verb, parent) +} + +// runToolboxDeleteWith is the testable core. It accepts a toolboxClient +// interface so unit tests can drive the branches without an HTTP server. +func runToolboxDeleteWith( + ctx context.Context, client toolboxClient, endpoint, name string, + verb toolboxDeleteFlags, parent toolboxFlags, +) error { + if verb.version == "" { + return runDeleteToolbox(ctx, client, endpoint, name, verb, parent) + } + return runDeleteToolboxVersion(ctx, client, endpoint, name, verb, parent) +} + +// runDeleteToolbox handles `toolbox delete ` (no --version). +func runDeleteToolbox( + ctx context.Context, client toolboxClient, endpoint, name string, + verb toolboxDeleteFlags, parent toolboxFlags, +) error { + // Only the parent-toolbox delete prompts for confirmation; --no-prompt + // without --force is rejected here, not in runDeleteToolboxVersion which + // does not prompt. + if parent.noPrompt && !verb.force { + return exterrors.Validation( + exterrors.CodeMissingForceFlag, + "--no-prompt requires --force when deleting a toolbox", + "add --force to confirm the deletion non-interactively", + ) + } + return withAzdClient(func(azdClient *azdext.AzdClient) error { + // Best-effort pending lookup; a read failure is logged but non-fatal. + pending, err := getPendingToolbox(ctx, azdClient, endpoint, name) + if err != nil { + log.Printf("toolbox delete: pending-toolbox read failed for %q: %v", name, err) + } + + _, getErr := client.GetToolbox(ctx, name) + switch { + case getErr == nil: + // Live toolbox. + if !verb.force { + confirmed, err := confirmToolboxDelete(ctx, azdClient, + fmt.Sprintf("Delete toolbox %q (cascades to every version)?", name)) + if err != nil { + return err + } + if !confirmed { + fmt.Println("Aborted.") + return nil + } + } + if err := client.DeleteToolbox(ctx, name); err != nil && !isAzureNotFound(err) { + return exterrors.ServiceFromAzure(err, exterrors.OpDeleteToolbox) + } + // Best-effort clear of any local pending record (non-fatal). + if _, err := clearPendingToolbox(ctx, azdClient, endpoint, name); err != nil { + log.Printf("toolbox delete: failed to clear pending record for %q: %v", name, err) + } + return emitDeleteResult(name, "", "deleted", parent.output) + + case isAzureNotFound(getErr): + if pending != nil { + if _, err := clearPendingToolbox(ctx, azdClient, endpoint, name); err != nil { + return exterrors.Internal(exterrors.CodePendingToolboxStoreFailed, err.Error()) + } + if parent.output == "json" { + return emitDeleteResult(name, "", "pending_cleared", parent.output) + } + fmt.Printf("Cleared pending toolbox %s.\n", name) + return nil + } + return exterrors.Dependency( + exterrors.CodeToolboxNotFound, + fmt.Sprintf("toolbox %q not found at %s", name, endpoint), + "run 'azd ai agent toolbox list' to see available toolboxes", + ) + + default: + return exterrors.ServiceFromAzure(getErr, exterrors.OpGetToolbox) + } + }) +} + +// runDeleteToolboxVersion handles `toolbox delete --version `. +func runDeleteToolboxVersion( + ctx context.Context, client toolboxClient, endpoint, name string, + verb toolboxDeleteFlags, parent toolboxFlags, +) error { + tb, err := client.GetToolbox(ctx, name) + if err != nil { + return toolboxNotFoundOrService(err, name, exterrors.OpGetToolbox) + } + + cascaded := false + if verb.version == tb.DefaultVersion { + versions, err := client.ListToolboxVersions(ctx, name) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpListToolboxVersions) + } + + if len(versions) > 1 { + return exterrors.Validation( + exterrors.CodeDefaultVersionDelete, + fmt.Sprintf( + "version %q is the default for toolbox %q and other versions exist", + verb.version, name, + ), + "retarget the default with `azd ai agent toolbox update --default-version ` first", + ) + } + + // Only remaining version → cascading delete; require --force. + if !verb.force { + return exterrors.Validation( + exterrors.CodeOnlyVersionDelete, + fmt.Sprintf( + "version %q is the only remaining version of toolbox %q; "+ + "deleting it removes the toolbox", verb.version, name, + ), + fmt.Sprintf( + "run `azd ai agent toolbox delete %q` to delete the toolbox, "+ + "or pass --force to confirm", + name, + ), + ) + } + cascaded = true + } + // NOTE: non-default version delete has no confirmation prompt by design. + // We intentionally do not add one here even without --force. + + if err := client.DeleteToolboxVersion(ctx, name, verb.version); err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpDeleteToolboxVersion) + } + + if cascaded { + // Server cascaded the parent toolbox away — best-effort clear of any + // local pending record so the name doesn't linger in `toolbox list`. + _ = withAzdClient(func(azdClient *azdext.AzdClient) error { + if _, err := clearPendingToolbox(ctx, azdClient, endpoint, name); err != nil { + log.Printf( + "toolbox delete: failed to clear pending record after cascade for %q: %v", + name, err, + ) + } + return nil + }) + if parent.output == "json" { + return emitDeleteResult(name, verb.version, "toolbox_cascaded", parent.output) + } + fmt.Printf("Deleted toolbox %s (last version removed).\n", name) + return nil + } + return emitDeleteResult(name, verb.version, "version_deleted", parent.output) +} + +// confirmToolboxDelete shows a destructive-action confirmation prompt. +func confirmToolboxDelete(ctx context.Context, azdClient *azdext.AzdClient, message string) (bool, error) { + resp, err := azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: message, + DefaultValue: new(false), + }, + }) + if err != nil { + return false, exterrors.FromPrompt(err, "delete confirmation") + } + if resp == nil || resp.Value == nil { + return false, nil + } + return *resp.Value, nil +} + +func emitDeleteResult(name, version, outcome, output string) error { + if output == "json" { + payload := map[string]any{ + "name": name, + "version": version, + "outcome": outcome, + } + return emitJSON(payload) + } + switch outcome { + case "deleted": + fmt.Printf("Deleted toolbox %s.\n", name) + case "version_deleted": + fmt.Printf("Deleted version %s of toolbox %s.\n", version, name) + } + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_helpers_test.go new file mode 100644 index 00000000000..459a37e71e4 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_helpers_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "errors" + "strings" + "testing" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// requireLocalError asserts err is an *azdext.LocalError with the given code. +func requireLocalError(t *testing.T, err error, code string) *azdext.LocalError { + t.Helper() + require.Error(t, err) + le, ok := errors.AsType[*azdext.LocalError](err) + require.True(t, ok, "expected LocalError, got %T: %v", err, err) + assert.Equal(t, code, le.Code, "code mismatch in %v", le) + return le +} + +func TestValidateToolboxName(t *testing.T) { + cases := []struct { + name string + input string + wantErr bool + }{ + {"simple", "research", false}, + {"with dash", "my-tools", false}, + {"with underscore", "my_tools", false}, + {"mixed", "Tools_v2-alpha", false}, + {"max length", strings.Repeat("a", maxToolboxNameLength), false}, + {"empty", "", true}, + {"slash", "a/b", true}, + {"space", "my tools", true}, + {"dot", "tools.v1", true}, + {"too long", strings.Repeat("a", maxToolboxNameLength+1), true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateToolboxName(tc.input) + if tc.wantErr { + requireLocalError(t, err, exterrors.CodeInvalidToolboxName) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateOutputFormat(t *testing.T) { + for _, ok := range []string{"", "table", "json", "Table", "JSON"} { + assert.NoError(t, validateOutputFormat(ok), "expected %q to be accepted", ok) + } + err := validateOutputFormat("yaml") + requireLocalError(t, err, exterrors.CodeInvalidParameter) +} + +func TestParseAccountProjectFromEndpoint(t *testing.T) { + account, project, err := parseAccountProjectFromEndpoint( + "https://my-acct.services.ai.azure.com/api/projects/my-project", + ) + require.NoError(t, err) + assert.Equal(t, "my-acct", account) + assert.Equal(t, "my-project", project) + + _, _, err = parseAccountProjectFromEndpoint("https://wrong.example.com/") + assert.Error(t, err) +} + +func TestBuildToolEntry(t *testing.T) { + t.Run("RemoteTool builds mcp entry", func(t *testing.T) { + entry, err := buildToolEntry(&projectConnection{ + ID: "/subs/x/.../connections/my-mcp", + Category: azure.ConnectionTypeRemoteTool, + Name: "my-mcp", + Target: "https://mcp.example.com", + }, "") + require.NoError(t, err) + assert.Equal(t, "mcp", entry["type"]) + assert.Equal(t, "my-mcp", entry["name"]) + assert.Equal(t, "my-mcp", entry["server_label"]) + assert.Equal(t, "https://mcp.example.com", entry["server_url"]) + assert.Equal(t, "/subs/x/.../connections/my-mcp", entry["project_connection_id"]) + }) + + t.Run("RemoteTool rejects --index", func(t *testing.T) { + _, err := buildToolEntry(&projectConnection{ + Category: azure.ConnectionTypeRemoteTool, + Name: "my-mcp", + }, "idx") + requireLocalError(t, err, exterrors.CodeUnsupportedIndexFlag) + }) + + t.Run("RemoteTool rejects empty target", func(t *testing.T) { + _, err := buildToolEntry(&projectConnection{ + ID: "/c/x", + Category: azure.ConnectionTypeRemoteTool, + Name: "x", + Target: " ", // whitespace-only is treated as empty + }, "") + le := requireLocalError(t, err, exterrors.CodeConnectionMissingTarget) + assert.Contains(t, le.Message, "target URL") + }) + + t.Run("CognitiveSearch requires --index", func(t *testing.T) { + _, err := buildToolEntry(&projectConnection{ + Category: azure.ConnectionTypeCognitiveSearch, + Name: "search", + }, "") + requireLocalError(t, err, exterrors.CodeMissingIndex) + }) + + t.Run("CognitiveSearch builds azure_ai_search entry", func(t *testing.T) { + entry, err := buildToolEntry(&projectConnection{ + ID: "/subs/x/.../connections/search", + Category: azure.ConnectionTypeCognitiveSearch, + Name: "search", + }, "products") + require.NoError(t, err) + assert.Equal(t, "azure_ai_search", entry["type"]) + search := entry["azure_ai_search"].(map[string]any) + indexes := search["indexes"].([]any) + require.Len(t, indexes, 1) + first := indexes[0].(map[string]any) + assert.Equal(t, "products", first["index_name"]) + assert.Equal(t, "/subs/x/.../connections/search", first["project_connection_id"]) + }) + + t.Run("unsupported category rejected", func(t *testing.T) { + for _, cat := range []azure.ConnectionType{ + azure.ConnectionTypeApiKey, + azure.ConnectionTypeCustomKeys, + azure.ConnectionTypeAppInsights, + } { + _, err := buildToolEntry(&projectConnection{Category: cat, Name: "x"}, "") + le := requireLocalError(t, err, exterrors.CodeUnsupportedConnectionCategory) + assert.Contains(t, le.Message, string(cat), + "expected category in message") + } + }) +} + +func TestDuplicateConnectionInTools(t *testing.T) { + tools := []map[string]any{ + {"type": "mcp", "project_connection_id": "/conn/a"}, + { + "type": "azure_ai_search", + "azure_ai_search": map[string]any{ + "indexes": []any{ + map[string]any{"project_connection_id": "/conn/b", "index_name": "x"}, + }, + }, + }, + } + assert.True(t, duplicateConnectionInTools(tools, "/conn/a")) + assert.True(t, duplicateConnectionInTools(tools, "/conn/b")) + assert.False(t, duplicateConnectionInTools(tools, "/conn/c")) +} + +func TestFilterOutConnection(t *testing.T) { + tools := []map[string]any{ + {"type": "mcp", "project_connection_id": "/conn/a", "name": "a"}, + {"type": "code_interpreter", "name": "ci"}, // built-in carries through + {"type": "mcp", "project_connection_id": "/conn/b", "name": "b"}, + { + "type": "azure_ai_search", + "name": "s", + "azure_ai_search": map[string]any{ + "indexes": []any{ + map[string]any{"project_connection_id": "/conn/c"}, + }, + }, + }, + } + got, removed := filterOutConnection(tools, "/conn/a") + assert.True(t, removed) + assert.Len(t, got, 3) + for _, e := range got { + assert.NotEqual(t, "/conn/a", e["project_connection_id"]) + } + + // Removing missing connection: removed=false, slice unchanged in length. + got2, removed2 := filterOutConnection(tools, "/conn/zzz") + assert.False(t, removed2) + assert.Len(t, got2, 4) + + // Removing nested search connection. + got3, removed3 := filterOutConnection(tools, "/conn/c") + assert.True(t, removed3) + assert.Len(t, got3, 3) +} + +func TestShortConnectionName(t *testing.T) { + assert.Equal(t, "my-mcp", shortConnectionName("/subs/x/connections/my-mcp")) + assert.Equal(t, "plain", shortConnectionName("plain")) + assert.Equal(t, "", shortConnectionName("")) +} + +func TestBuildToolboxMcpURL(t *testing.T) { + got := buildToolboxMcpURL("https://acct.services.ai.azure.com/api/projects/p", "research", "3") + assert.Equal(t, + "https://acct.services.ai.azure.com/api/projects/p/toolboxes/research/versions/3/mcp?api-version=v1", + got, + ) + + // Service-supplied version strings could in theory contain unsafe URL chars. + // Both segments must be PathEscaped so downstream consumers can use the URL + // without parsing surprises. + escaped := buildToolboxMcpURL( + "https://acct.services.ai.azure.com/api/projects/p", + "research", + "v 1/2", // space and slash require escaping + ) + assert.Contains(t, escaped, "versions/v%201%2F2/mcp") +} + +func TestEndpointBucketKey(t *testing.T) { + a := endpointBucketKey("https://acct.example.com/api/projects/p") + b := endpointBucketKey("https://acct.example.com/api/projects/p/") // trailing slash + assert.Equal(t, a, b, "trailing slash must not change bucket key") + assert.Len(t, a, 16, "bucket key length is pinned to 16 hex chars") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_list.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_list.go new file mode 100644 index 00000000000..e2ed916e9a8 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_list.go @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "cmp" + "context" + "fmt" + "log" + "maps" + "os" + "slices" + "text/tabwriter" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// newToolboxListCommand returns the `azd ai agent toolbox list` command. +func newToolboxListCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + cmd := &cobra.Command{ + Use: "list", + Short: "List toolboxes on the project, plus any local pending records.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return runToolboxList(cmd.Context(), readToolboxFlags(cmd, extCtx)) + }, + } + registerToolboxOutputFlag(cmd) + return cmd +} + +func runToolboxList(ctx context.Context, parent toolboxFlags) error { + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox list", resolved) + + return runToolboxListWith(ctx, client, resolved.Endpoint, parent) +} + +// runToolboxListWith is the testable core. +func runToolboxListWith( + ctx context.Context, client toolboxClient, endpoint string, parent toolboxFlags, +) error { + live, err := client.ListToolboxes(ctx) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpListToolboxes) + } + + // Best-effort merge of pending records; failures are non-fatal. + var pending map[string]PendingToolbox + azdClient, azdErr := azdext.NewAzdClient() + if azdErr != nil { + log.Printf("toolbox list: azd client unavailable, skipping pending merge: %v", azdErr) + } else { + defer azdClient.Close() + items, readErr := listPendingToolboxes(ctx, azdClient, endpoint) + if readErr != nil { + log.Printf("toolbox list: pending-toolbox read failed: %v", readErr) + } else { + pending = items + } + } + + // Drop pending records that already exist live-side. The pending record + // is normally cleared when `connection add` publishes v1, but a clear + // failure (logged in connection add) can leave a stale entry. This dedup + // makes the list output self-healing. + liveNames := map[string]struct{}{} + for _, t := range live { + liveNames[t.Name] = struct{}{} + } + for k := range pending { + if _, dup := liveNames[k]; dup { + delete(pending, k) + } + } + + if parent.output == "json" { + return emitListJSON(live, pending) + } + return emitListTable(ctx, client, live, pending) +} + +func emitListJSON(live []azure.ToolboxObject, pending map[string]PendingToolbox) error { + toolboxes := make([]map[string]any, 0, len(live)+len(pending)) + for _, t := range live { + toolboxes = append(toolboxes, map[string]any{ + "id": t.ID, + "name": t.Name, + "default_version": t.DefaultVersion, + "pending": false, + }) + } + for _, k := range slices.Sorted(maps.Keys(pending)) { + p := pending[k] + toolboxes = append(toolboxes, map[string]any{ + "name": k, + "pending": true, + "description": p.Description, + "createdAt": p.CreatedAt, + }) + } + return emitJSON(map[string]any{"toolboxes": toolboxes}) +} + +// emitListTable produces NAME / DEFAULT-VERSION / STATE / CREATED. The table +// intentionally omits a TOOLS count to avoid an extra fetch per row; use +// `toolbox show` to see tools for a single toolbox. +func emitListTable( + _ context.Context, _ toolboxClient, + live []azure.ToolboxObject, pending map[string]PendingToolbox, +) error { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "NAME\tDEFAULT-VERSION\tSTATE\tCREATED") + fmt.Fprintln(w, "----\t---------------\t-----\t-------") + + sortedLive := slices.Clone(live) + slices.SortFunc(sortedLive, func(a, b azure.ToolboxObject) int { + return cmp.Compare(a.Name, b.Name) + }) + + for _, t := range sortedLive { + fmt.Fprintf(w, "%s\t%s\tlive\t-\n", t.Name, t.DefaultVersion) + } + + for _, name := range slices.Sorted(maps.Keys(pending)) { + fmt.Fprintf(w, "%s\t-\tpending\t%s\n", name, pending[name].CreatedAt) + } + + return w.Flush() +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_shared.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_shared.go new file mode 100644 index 00000000000..16fd1750f4e --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_shared.go @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "encoding/json" + "fmt" + + "azureaiagent/internal/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// toolboxNotFoundOrService maps a GetToolbox / GetToolboxVersion error to the +// right structured error: Dependency(CodeToolboxNotFound) on 404, ServiceError +// otherwise. +func toolboxNotFoundOrService(err error, name, op string) error { + if isAzureNotFound(err) { + return exterrors.Dependency( + exterrors.CodeToolboxNotFound, + fmt.Sprintf("toolbox %q not found", name), + "run 'azd ai agent toolbox list' to see available toolboxes", + ) + } + return exterrors.ServiceFromAzure(err, op) +} + +// forEachToolConnectionID invokes fn for every project_connection_id reference +// in tools[] (top-level on mcp entries, nested under azure_ai_search.indexes +// on search entries). fn returns true to stop early. +func forEachToolConnectionID(tools []map[string]any, fn func(connID string) bool) { + for _, t := range tools { + if toolEntryReferences(t, func(id string) bool { return fn(id) }) { + return + } + } +} + +// toolEntryReferences runs match against every connection ID referenced by a +// single tool entry and returns true on the first hit. +func toolEntryReferences(t map[string]any, match func(connID string) bool) bool { + if id, ok := t["project_connection_id"].(string); ok && id != "" && match(id) { + return true + } + search, ok := t["azure_ai_search"].(map[string]any) + if !ok { + return false + } + indexes, ok := search["indexes"].([]any) + if !ok { + return false + } + for _, idx := range indexes { + m, ok := idx.(map[string]any) + if !ok { + continue + } + if id, ok := m["project_connection_id"].(string); ok && id != "" && match(id) { + return true + } + } + return false +} + +// emitJSON marshals payload as indented JSON to stdout. +func emitJSON(payload any) error { + data, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal result: %w", err) + } + fmt.Println(string(data)) + return nil +} + +// withAzdClient opens the azd client, invokes fn, and closes the client. +// A client-open failure is surfaced as Internal(CodeAzdClientFailed). +func withAzdClient(fn func(c *azdext.AzdClient) error) error { + c, err := azdext.NewAzdClient() + if err != nil { + return exterrors.Internal( + exterrors.CodeAzdClientFailed, + fmt.Sprintf("failed to create azd client: %s", err), + ) + } + defer c.Close() + return fn(c) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_show.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_show.go new file mode 100644 index 00000000000..d06487acaad --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_show.go @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "strings" + "text/tabwriter" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/azure" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// toolboxShowFlags carries the verb-specific flags for `toolbox show`. +type toolboxShowFlags struct { + version string +} + +// newToolboxShowCommand returns the `azd ai agent toolbox show ` command. +func newToolboxShowCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + flags := &toolboxShowFlags{} + + cmd := &cobra.Command{ + Use: "show ", + Short: "Show a toolbox version, including its computed MCP endpoint.", + Long: `Show a toolbox. + +By default shows the default version. Use --version to inspect a specific +version. The output includes the toolbox's runtime MCP endpoint, which agents +consume via the TOOLBOX__ENDPOINT environment variable convention. + +If the toolbox exists only as a pending local record (no version published +yet), the command emits a pending-toolbox view and rejects --version.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runToolboxShow(cmd.Context(), args[0], *flags, readToolboxFlags(cmd, extCtx)) + }, + } + + cmd.Flags().StringVar( + &flags.version, "version", "", + "Specific version to show. Defaults to the server's default_version.", + ) + registerToolboxOutputFlag(cmd) + + return cmd +} + +func runToolboxShow( + ctx context.Context, name string, verb toolboxShowFlags, parent toolboxFlags, +) error { + if err := validateToolboxName(name); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox show", resolved) + + return runToolboxShowWith(ctx, client, resolved.Endpoint, name, verb, parent) +} + +// runToolboxShowWith is the testable core. +func runToolboxShowWith( + ctx context.Context, client toolboxClient, endpoint, name string, + verb toolboxShowFlags, parent toolboxFlags, +) error { + tb, err := client.GetToolbox(ctx, name) + if err != nil { + if isAzureNotFound(err) { + return showPendingOrNotFound(ctx, endpoint, name, verb, parent) + } + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolbox) + } + + shownVersion := verb.version + if shownVersion == "" { + shownVersion = tb.DefaultVersion + } + + version, err := client.GetToolboxVersion(ctx, name, shownVersion) + if err != nil { + if isAzureNotFound(err) { + return exterrors.Dependency( + exterrors.CodeToolboxNotFound, + fmt.Sprintf("version %q of toolbox %q not found", shownVersion, name), + fmt.Sprintf("run 'azd ai agent toolbox show %q' to see the default version", name), + ) + } + return exterrors.ServiceFromAzure(err, exterrors.OpGetToolboxVersion) + } + + mcpURL := buildToolboxMcpURL(endpoint, name, shownVersion) + + if parent.output == "json" { + return emitShowJSON(tb, version, mcpURL) + } + return emitShowTable(tb, version, mcpURL) +} + +// showPendingOrNotFound handles the 404 branch: either render the pending-toolbox +// view or surface a structured Dependency(CodeToolboxNotFound). +func showPendingOrNotFound( + ctx context.Context, endpoint, name string, + verb toolboxShowFlags, parent toolboxFlags, +) error { + return withAzdClient(func(azdClient *azdext.AzdClient) error { + pending, err := getPendingToolbox(ctx, azdClient, endpoint, name) + if err != nil { + log.Printf("toolbox show: pending-toolbox read failed for %q: %v", name, err) + } + if pending == nil { + return exterrors.Dependency( + exterrors.CodeToolboxNotFound, + fmt.Sprintf("toolbox %q not found at %s", name, endpoint), + "run 'azd ai agent toolbox list' to see available toolboxes", + ) + } + + return renderPendingShow(name, verb, parent, pending) + }) +} + +// renderPendingShow emits the pending-toolbox view. +func renderPendingShow( + name string, verb toolboxShowFlags, parent toolboxFlags, pending *PendingToolbox, +) error { + + if verb.version != "" { + return exterrors.Validation( + exterrors.CodeMissingUpdateField, + fmt.Sprintf( + "toolbox %q has no published versions yet; --version cannot be used", + name, + ), + fmt.Sprintf( + "run 'azd ai agent toolbox connection add %q ' to publish v1 first", + name, + ), + ) + } + + if parent.output == "json" { + payload := map[string]any{ + "toolbox": map[string]any{ + "name": name, + "pending": true, + "description": pending.Description, + "createdAt": pending.CreatedAt, + }, + "version": nil, + "endpoint": nil, + } + return emitJSON(payload) + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "FIELD\tVALUE") + fmt.Fprintln(w, "-----\t-----") + fmt.Fprintf(w, "Name\t%s\n", name) + fmt.Fprintf(w, "State\tpending\n") + fmt.Fprintf(w, "Description\t%s\n", pending.Description) + fmt.Fprintf(w, "Created\t%s\n", pending.CreatedAt) + if err := w.Flush(); err != nil { + return err + } + fmt.Printf( + "\nRun `azd ai agent toolbox connection add %q ` to publish v1.\n", + name, + ) + return nil +} + +// buildToolboxMcpURL computes the runtime MCP consumption URL. +// version is service-supplied so both segments are PathEscaped. +func buildToolboxMcpURL(endpoint, name, version string) string { + return fmt.Sprintf( + "%s/toolboxes/%s/versions/%s/mcp?api-version=v1", + strings.TrimRight(endpoint, "/"), + url.PathEscape(name), + url.PathEscape(version), + ) +} + +// emitShowJSON prints the JSON envelope for `toolbox show`. +func emitShowJSON( + tb *azure.ToolboxObject, version *azure.ToolboxVersionObject, mcpURL string, +) error { + return emitJSON(map[string]any{ + "toolbox": tb, + "version": version, + "endpoint": mcpURL, + }) +} + +// emitShowTable renders the table form of `toolbox show`. +func emitShowTable( + tb *azure.ToolboxObject, version *azure.ToolboxVersionObject, mcpURL string, +) error { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "FIELD\tVALUE") + fmt.Fprintln(w, "-----\t-----") + fmt.Fprintf(w, "Name\t%s\n", tb.Name) + fmt.Fprintf(w, "Default version\t%s\n", tb.DefaultVersion) + fmt.Fprintf(w, "Shown version\t%s\n", version.Version) + fmt.Fprintf(w, "Description\t%s\n", version.Description) + fmt.Fprintf(w, "Endpoint\t%s\n", mcpURL) + fmt.Fprintf(w, "Tools\t%d\n", len(version.Tools)) + if err := w.Flush(); err != nil { + return err + } + + if len(version.Tools) > 0 { + fmt.Println() + tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "TOOL\tTYPE\tDETAIL") + fmt.Fprintln(tw, "----\t----\t------") + for _, tool := range version.Tools { + toolName, _ := tool["name"].(string) + toolType, _ := tool["type"].(string) + detail := describeToolDetail(toolType, tool) + fmt.Fprintf(tw, "%s\t%s\t%s\n", toolName, toolType, detail) + } + if err := tw.Flush(); err != nil { + return err + } + } + return nil +} + +// describeToolDetail returns the per-tool annotation used in the show table: +// "(builtin)" for first-party tools and "(connection:)" for connection-backed entries. +func describeToolDetail(toolType string, tool map[string]any) string { + switch toolType { + case "code_interpreter", "web_search", "file_search": + return "(builtin)" + case "mcp", "azure_ai_search": + if id := firstConnectionID(tool); id != "" { + return "(connection:" + id + ")" + } + } + return "" +} + +// firstConnectionID returns the first project_connection_id referenced by a +// tool entry — top-level on `mcp` tools, or nested under +// azure_ai_search.indexes[] for search tools. +func firstConnectionID(tool map[string]any) string { + var found string + toolEntryReferences(tool, func(id string) bool { + found = id + return true // stop at the first hit + }) + return found +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_test_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_test_helpers_test.go new file mode 100644 index 00000000000..c6a1e82ea88 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_test_helpers_test.go @@ -0,0 +1,259 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "net/http" + "net/url" + "sync" + + "azureaiagent/internal/pkg/azure" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// mockToolboxClient is a test stub for the toolboxClient interface. Each +// method returns a configured value/error and records call shape; mu keeps +// race-detector runs clean. +type mockToolboxClient struct { + mu sync.Mutex + + endpoint string + + getResults map[string]toolboxGetResult + versionResults map[string]toolboxVersionResult + listToolboxesResult []azure.ToolboxObject + listToolboxesErr error + listVersionsResults map[string][]azure.ToolboxVersionObject + listVersionsErr error + createVersionResult *azure.ToolboxVersionObject + createVersionErr error + setDefaultResult *azure.ToolboxObject + setDefaultErr error + deleteToolboxErr error + deleteToolboxVersionErr error + + createVersionCalls []createVersionCall + setDefaultCalls []setDefaultCall + deleteCalls []deleteCall + deleteVersionCalls []deleteVersionCall +} + +type toolboxGetResult struct { + obj *azure.ToolboxObject + err error +} + +type toolboxVersionResult struct { + obj *azure.ToolboxVersionObject + err error +} + +type createVersionCall struct { + name string + req *azure.CreateToolboxVersionRequest +} + +type setDefaultCall struct { + name, version string +} + +type deleteCall struct { + name string +} + +type deleteVersionCall struct { + name, version string +} + +// newMockToolboxClient seeds an empty mock bound to the given endpoint. +func newMockToolboxClient(endpoint string) *mockToolboxClient { + return &mockToolboxClient{ + endpoint: endpoint, + getResults: map[string]toolboxGetResult{}, + versionResults: map[string]toolboxVersionResult{}, + listVersionsResults: map[string][]azure.ToolboxVersionObject{}, + } +} + +func (m *mockToolboxClient) Endpoint() string { return m.endpoint } + +func (m *mockToolboxClient) GetToolbox(_ context.Context, name string) (*azure.ToolboxObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + r, ok := m.getResults[name] + if !ok { + return nil, notFoundResponseError("toolbox " + name + " not found") + } + return r.obj, r.err +} + +func (m *mockToolboxClient) CreateToolboxVersion( + _ context.Context, name string, req *azure.CreateToolboxVersionRequest, +) (*azure.ToolboxVersionObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.createVersionCalls = append(m.createVersionCalls, createVersionCall{name: name, req: req}) + if m.createVersionErr != nil { + return nil, m.createVersionErr + } + if m.createVersionResult != nil { + return m.createVersionResult, nil + } + // Default: synthesize a new version object based on the request length. + return &azure.ToolboxVersionObject{ + Name: name, + Version: fmt.Sprintf("v%d", len(m.createVersionCalls)), + Description: req.Description, + Metadata: req.Metadata, + Tools: req.Tools, + }, nil +} + +func (m *mockToolboxClient) DeleteToolbox(_ context.Context, name string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.deleteCalls = append(m.deleteCalls, deleteCall{name: name}) + return m.deleteToolboxErr +} + +func (m *mockToolboxClient) ListToolboxes(_ context.Context) ([]azure.ToolboxObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.listToolboxesResult, m.listToolboxesErr +} + +func (m *mockToolboxClient) GetToolboxVersion( + _ context.Context, name, version string, +) (*azure.ToolboxVersionObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + key := name + "/" + version + r, ok := m.versionResults[key] + if !ok { + return nil, notFoundResponseError("version " + key + " not found") + } + return r.obj, r.err +} + +func (m *mockToolboxClient) ListToolboxVersions( + _ context.Context, name string, +) ([]azure.ToolboxVersionObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.listVersionsErr != nil { + return nil, m.listVersionsErr + } + return m.listVersionsResults[name], nil +} + +func (m *mockToolboxClient) DeleteToolboxVersion(_ context.Context, name, version string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.deleteVersionCalls = append(m.deleteVersionCalls, deleteVersionCall{name: name, version: version}) + return m.deleteToolboxVersionErr +} + +func (m *mockToolboxClient) SetDefaultVersion( + _ context.Context, name, version string, +) (*azure.ToolboxObject, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.setDefaultCalls = append(m.setDefaultCalls, setDefaultCall{name: name, version: version}) + if m.setDefaultErr != nil { + return nil, m.setDefaultErr + } + if m.setDefaultResult != nil { + return m.setDefaultResult, nil + } + return &azure.ToolboxObject{Name: name, DefaultVersion: version}, nil +} + +// notFoundResponseError builds a synthetic *azcore.ResponseError with HTTP 404 +// and a fully-populated http.Request so isAzureNotFound returns true and +// downstream URL-aware formatters do not panic. +func notFoundResponseError(message string) error { + stubURL, _ := url.Parse("https://stub.test/synthetic-404") + return &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: message, + RawResponse: &http.Response{ + StatusCode: http.StatusNotFound, + Request: &http.Request{ + Host: "stub.test", + Method: http.MethodGet, + URL: stubURL, + }, + }, + } +} + +// stubConnectionResolver is the connectionResolver test fake. +type stubConnectionResolver struct { + byName map[string]*projectConnection + err map[string]error +} + +func newStubConnectionResolver() *stubConnectionResolver { + return &stubConnectionResolver{ + byName: map[string]*projectConnection{}, + err: map[string]error{}, + } +} + +func (s *stubConnectionResolver) resolveConnection( + _ context.Context, _ string, name string, +) (*projectConnection, error) { + if e, ok := s.err[name]; ok { + return nil, e + } + if c, ok := s.byName[name]; ok { + return c, nil + } + return nil, connectionNotFoundError(name) +} + +// compile-time guard. +var _ toolboxClient = (*mockToolboxClient)(nil) +var _ connectionResolver = (*stubConnectionResolver)(nil) +var _ pendingToolboxStore = (*stubPendingStore)(nil) + +// stubPendingStore is the in-memory pendingToolboxStore for unit tests. +// getErr/clearErr inject failures to exercise error-handling branches. +type stubPendingStore struct { + records map[string]*PendingToolbox + getErr error + clearErr error + getCalls int + clearCalls int +} + +func newStubPendingStore() *stubPendingStore { + return &stubPendingStore{records: map[string]*PendingToolbox{}} +} + +func (s *stubPendingStore) key(endpoint, name string) string { + return endpoint + "::" + name +} + +func (s *stubPendingStore) Get(_ context.Context, endpoint, name string) (*PendingToolbox, error) { + s.getCalls++ + if s.getErr != nil { + return nil, s.getErr + } + return s.records[s.key(endpoint, name)], nil +} + +func (s *stubPendingStore) Clear(_ context.Context, endpoint, name string) (bool, error) { + s.clearCalls++ + if s.clearErr != nil { + return false, s.clearErr + } + k := s.key(endpoint, name) + _, ok := s.records[k] + delete(s.records, k) + return ok, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_update.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_update.go new file mode 100644 index 00000000000..80ea436d756 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/toolbox_update.go @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "strings" + + "azureaiagent/internal/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// toolboxUpdateFlags carries the verb-specific flags for `toolbox update`. +type toolboxUpdateFlags struct { + defaultVersion string +} + +// newToolboxUpdateCommand returns the `azd ai agent toolbox update ` command. +// Only --default-version is supported. +func newToolboxUpdateCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + extCtx = ensureExtensionContext(extCtx) + flags := &toolboxUpdateFlags{} + + cmd := &cobra.Command{ + Use: "update ", + Short: "Update a toolbox (currently: retarget the default version).", + Long: `Update a toolbox. + +Only --default-version is supported today. To change the tool list, publish a +new version with 'connection add' or 'connection remove'.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runToolboxUpdate(cmd.Context(), args[0], *flags, readToolboxFlags(cmd, extCtx)) + }, + } + + cmd.Flags().StringVar( + &flags.defaultVersion, "default-version", "", + "Version string to mark as the default for this toolbox.", + ) + registerToolboxOutputFlag(cmd) + + return cmd +} + +func runToolboxUpdate( + ctx context.Context, name string, verb toolboxUpdateFlags, parent toolboxFlags, +) error { + if err := validateToolboxName(name); err != nil { + return err + } + if err := validateOutputFormat(parent.output); err != nil { + return err + } + + if strings.TrimSpace(verb.defaultVersion) == "" { + return exterrors.Validation( + exterrors.CodeMissingUpdateField, + "no fields to update", + "specify --default-version", + ) + } + + client, resolved, err := resolveToolboxAndClient(ctx, parent) + if err != nil { + return err + } + logResolvedEndpoint("toolbox update", resolved) + + result, err := client.SetDefaultVersion(ctx, name, verb.defaultVersion) + if err != nil { + return toolboxNotFoundOrService(err, name, exterrors.OpSetDefaultVersion) + } + + if parent.output == "json" { + return emitJSON(result) + } + fmt.Printf("Toolbox %s default version set to %s.\n", name, result.DefaultVersion) + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index 54a63bc10a3..f4a0874d850 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -104,8 +104,24 @@ const ( // Error codes for toolbox operations. const ( - CodeInvalidToolbox = "invalid_toolbox" - CodeCreateToolboxVersionFailed = "create_toolbox_version_failed" + CodeInvalidToolbox = "invalid_toolbox" + CodeCreateToolboxVersionFailed = "create_toolbox_version_failed" + CodeToolboxNotFound = "toolbox_not_found" + CodeMissingUpdateField = "missing_update_field" + CodeDefaultVersionDelete = "default_version_delete" + CodeOnlyVersionDelete = "only_version_delete" + CodeUnsupportedConnectionCategory = "unsupported_connection_category" + CodeMissingIndex = "missing_index" + CodeUnsupportedIndexFlag = "unsupported_index_flag" + CodeDuplicateConnection = "duplicate_connection" + CodeConnectionNotFound = "connection_not_found" + CodeConnectionNotInToolbox = "connection_not_in_toolbox" + CodeConnectionMissingTarget = "connection_missing_target" + CodeLastToolRemoval = "last_tool_removal" + CodeMissingForceFlag = "missing_force_flag" + CodeInvalidToolboxName = "invalid_toolbox_name" + CodePendingToolboxStoreFailed = "pending_toolbox_store_failed" + CodeSetDefaultVersionFailed = "set_default_version_failed" ) // Error codes for connection operations. @@ -144,17 +160,24 @@ const ( // Operation names for [ServiceFromAzure] errors. // These are prefixed to the Azure error code (e.g., "create_agent.NotFound"). const ( - OpGetFoundryProject = "get_foundry_project" - OpContainerBuild = "container_build" - OpContainerPackage = "container_package" - OpContainerPublish = "container_publish" - OpCreateAgent = "create_agent" - OpStartContainer = "start_container" - OpGetContainerOperation = "get_container_operation" - OpCreateSession = "create_session" - OpGetSession = "get_session" - OpDeleteSession = "delete_session" - OpListSessions = "list_sessions" - OpCreateToolboxVersion = "create_toolbox_version" - OpGetToolbox = "get_toolbox" + OpGetFoundryProject = "get_foundry_project" + OpContainerBuild = "container_build" + OpContainerPackage = "container_package" + OpContainerPublish = "container_publish" + OpCreateAgent = "create_agent" + OpStartContainer = "start_container" + OpGetContainerOperation = "get_container_operation" + OpCreateSession = "create_session" + OpGetSession = "get_session" + OpDeleteSession = "delete_session" + OpListSessions = "list_sessions" + OpCreateToolboxVersion = "create_toolbox_version" + OpGetToolbox = "get_toolbox" + OpDeleteToolbox = "delete_toolbox" + OpDeleteToolboxVersion = "delete_toolbox_version" + OpSetDefaultVersion = "set_default_version" + OpListToolboxes = "list_toolboxes" + OpGetToolboxVersion = "get_toolbox_version" + OpListToolboxVersions = "list_toolbox_versions" + OpResolveProjectConnection = "resolve_project_connection" ) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go index 9bcbafc02e4..69cad9d68b2 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go @@ -161,6 +161,41 @@ func (c *FoundryProjectsClient) GetPagedConnections(ctx context.Context) (*Paged return &pagedConnections, nil } +// GetConnection retrieves a specific connection by name without surfacing +// credential material. Returns an azcore.ResponseError with StatusCode 404 +// when the connection is missing on the project. +func (c *FoundryProjectsClient) GetConnection(ctx context.Context, name string) (*Connection, error) { + targetEndpoint := fmt.Sprintf( + "%s/connections/%s?api-version=%s", + c.baseEndpoint, url.PathEscape(name), c.apiVersion) + + req, err := runtime.NewRequest(ctx, http.MethodGet, targetEndpoint) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var connection Connection + if err := json.Unmarshal(body, &connection); err != nil { + return nil, fmt.Errorf("failed to unmarshal connection response: %w", err) + } + return &connection, nil +} + // GetConnectionWithCredentials retrieves a specific connection with its credentials func (c *FoundryProjectsClient) GetConnectionWithCredentials(ctx context.Context, name string) (*Connection, error) { targetEndpoint := fmt.Sprintf( diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go index 7c0ae152a82..7b409a43ee2 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "net/url" "strings" @@ -17,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" "azureaiagent/internal/version" @@ -65,6 +67,121 @@ func NewFoundryToolboxClient( } } +// Endpoint returns the toolbox endpoint root used by this client (without trailing slash). +// Used by the CLI to compute the runtime MCP consumption URL surfaced by `toolbox show`. +func (c *FoundryToolboxClient) Endpoint() string { + return c.endpoint +} + +// doJSON sends `method url` with an optional JSON body and decodes the response +// body into `out` (pass nil to discard). `okCodes` selects which HTTP status +// codes count as success; defaults to {200} when empty. The Foundry-Features +// header is set on every request. +func (c *FoundryToolboxClient) doJSON( + ctx context.Context, method, target string, body any, out any, okCodes ...int, +) error { + if len(okCodes) == 0 { + okCodes = []int{http.StatusOK} + } + + req, err := runtime.NewRequest(ctx, method, target) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) + + if body != nil { + payload, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), + "application/json", + ); err != nil { + return fmt.Errorf("failed to set request body: %w", err) + } + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, okCodes...) { + return runtime.NewResponseError(resp) + } + + if out == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if len(raw) == 0 { + return nil + } + if err := json.Unmarshal(raw, out); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + return nil +} + +// listPagedFromClient walks `cursor`-style pagination on a toolbox endpoint. +// Subsequent pages use the last item's id as the `after` cursor while +// `has_more=true`. Capped at maxPaginationPages to guard against a server that +// keeps reporting has_more. +func listPagedFromClient[T any]( + ctx context.Context, c *FoundryToolboxClient, initialURL string, + pickLastID func(t T) string, +) ([]T, error) { + type page struct { + Data []T `json:"data"` + HasMore bool `json:"has_more,omitempty"` + LastID string `json:"last_id,omitempty"` + } + + out := []T{} + target := initialURL + for range maxPaginationPages { + var p page + if err := c.doJSON(ctx, http.MethodGet, target, nil, &p); err != nil { + return nil, err + } + out = append(out, p.Data...) + if !p.HasMore || len(p.Data) == 0 { + return out, nil + } + last := p.LastID + if last == "" && pickLastID != nil { + last = pickLastID(p.Data[len(p.Data)-1]) + } + if last == "" { + // HasMore=true with no cursor: log a warning and return the partial + // results rather than spin. Callers may receive incomplete data. + log.Printf( + "foundry_toolsets_client: pagination has_more=true but no cursor for %s; returning %d items", + initialURL, len(out), + ) + return out, nil + } + sep := "&" + if !strings.Contains(target, "?") { + sep = "?" + } + target = initialURL + sep + "after=" + url.QueryEscape(last) + } + return out, fmt.Errorf( + "pagination cap reached: more than %d pages returned for %s", + maxPaginationPages, initialURL, + ) +} + +const maxPaginationPages = 1000 + // CreateToolboxVersionRequest is the request body for creating a new toolbox version. // The toolbox name is provided in the URL path, not in the body. type CreateToolboxVersionRequest struct { @@ -75,14 +192,14 @@ type CreateToolboxVersionRequest struct { // ToolboxObject is the lightweight response for a toolbox (no tools list). type ToolboxObject struct { - Id string `json:"id"` + ID string `json:"id"` Name string `json:"name"` DefaultVersion string `json:"default_version"` } // ToolboxVersionObject is the response for a specific toolbox version. type ToolboxVersionObject struct { - Id string `json:"id"` + ID string `json:"id"` Name string `json:"name"` Version string `json:"version"` Description string `json:"description,omitempty"` @@ -91,126 +208,111 @@ type ToolboxVersionObject struct { Tools []map[string]any `json:"tools"` } +// toolboxURL builds the canonical toolboxes URL with the api-version query. +// Path segments are escaped; callers must not pre-escape. +func (c *FoundryToolboxClient) toolboxURL(parts ...string) string { + escaped := make([]string, len(parts)) + for i, p := range parts { + escaped[i] = url.PathEscape(p) + } + tail := strings.Join(escaped, "/") + if tail != "" { + tail = "/" + tail + } + return fmt.Sprintf("%s/toolboxes%s?api-version=%s", c.endpoint, tail, toolboxesApiVersion) +} + // CreateToolboxVersion creates a new version of a toolbox. // If the toolbox does not exist, it will be created automatically. func (c *FoundryToolboxClient) CreateToolboxVersion( - ctx context.Context, - toolboxName string, - request *CreateToolboxVersionRequest, + ctx context.Context, toolboxName string, request *CreateToolboxVersionRequest, ) (*ToolboxVersionObject, error) { - targetUrl := fmt.Sprintf( - "%s/toolboxes/%s/versions?api-version=%s", - c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, - ) - - payload, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := runtime.NewRequest(ctx, http.MethodPost, targetUrl) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) - - if err := req.SetBody( - streaming.NopCloser(bytes.NewReader(payload)), - "application/json", - ); err != nil { - return nil, fmt.Errorf("failed to set request body: %w", err) - } - - resp, err := c.pipeline.Do(req) - if err != nil { - return nil, fmt.Errorf("HTTP request failed: %w", err) - } - defer resp.Body.Close() - - if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { - return nil, runtime.NewResponseError(resp) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - + target := c.toolboxURL(toolboxName, "versions") var result ToolboxVersionObject - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + if err := c.doJSON( + ctx, http.MethodPost, target, request, &result, + http.StatusOK, http.StatusCreated, + ); err != nil { + return nil, err } - return &result, nil } // GetToolbox retrieves a toolbox by name. func (c *FoundryToolboxClient) GetToolbox( - ctx context.Context, - toolboxName string, + ctx context.Context, toolboxName string, ) (*ToolboxObject, error) { - targetUrl := fmt.Sprintf( - "%s/toolboxes/%s?api-version=%s", - c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, - ) - - req, err := runtime.NewRequest(ctx, http.MethodGet, targetUrl) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) - - resp, err := c.pipeline.Do(req) - if err != nil { - return nil, fmt.Errorf("HTTP request failed: %w", err) - } - defer resp.Body.Close() - - if !runtime.HasStatusCode(resp, http.StatusOK) { - return nil, runtime.NewResponseError(resp) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - var result ToolboxObject - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + if err := c.doJSON( + ctx, http.MethodGet, c.toolboxURL(toolboxName), nil, &result, + ); err != nil { + return nil, err } - return &result, nil } // DeleteToolbox deletes a toolbox and all its versions. -func (c *FoundryToolboxClient) DeleteToolbox( - ctx context.Context, - toolboxName string, -) error { - targetUrl := fmt.Sprintf( - "%s/toolboxes/%s?api-version=%s", - c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, +func (c *FoundryToolboxClient) DeleteToolbox(ctx context.Context, toolboxName string) error { + return c.doJSON( + ctx, http.MethodDelete, c.toolboxURL(toolboxName), nil, nil, + http.StatusOK, http.StatusNoContent, ) +} - req, err := runtime.NewRequest(ctx, http.MethodDelete, targetUrl) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) +// ListToolboxes returns every toolbox visible on the project endpoint by walking pagination. +func (c *FoundryToolboxClient) ListToolboxes(ctx context.Context) ([]ToolboxObject, error) { + return listPagedFromClient( + ctx, c, c.toolboxURL(), + func(t ToolboxObject) string { return t.ID }, + ) +} + +// GetToolboxVersion fetches the full version body, including tools[]. +func (c *FoundryToolboxClient) GetToolboxVersion( + ctx context.Context, toolboxName, version string, +) (*ToolboxVersionObject, error) { + var result ToolboxVersionObject + if err := c.doJSON( + ctx, http.MethodGet, c.toolboxURL(toolboxName, "versions", version), + nil, &result, + ); err != nil { + return nil, err } + return &result, nil +} - req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) +// ListToolboxVersions returns all version summaries for the named toolbox. +func (c *FoundryToolboxClient) ListToolboxVersions( + ctx context.Context, toolboxName string, +) ([]ToolboxVersionObject, error) { + return listPagedFromClient( + ctx, c, c.toolboxURL(toolboxName, "versions"), + func(v ToolboxVersionObject) string { return v.ID }, + ) +} - resp, err := c.pipeline.Do(req) - if err != nil { - return fmt.Errorf("HTTP request failed: %w", err) - } - defer resp.Body.Close() +// DeleteToolboxVersion deletes a single version. Service returns 400 with +// `bad_request` if the version is the current `default_version` and other +// versions exist; the CLI guards this pre-flight. +func (c *FoundryToolboxClient) DeleteToolboxVersion( + ctx context.Context, toolboxName, version string, +) error { + return c.doJSON( + ctx, http.MethodDelete, c.toolboxURL(toolboxName, "versions", version), nil, nil, + http.StatusOK, http.StatusNoContent, + ) +} - if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusNoContent) { - return runtime.NewResponseError(resp) +// SetDefaultVersion PATCHes the toolbox to mark a different version as default. +func (c *FoundryToolboxClient) SetDefaultVersion( + ctx context.Context, toolboxName, version string, +) (*ToolboxObject, error) { + var result ToolboxObject + if err := c.doJSON( + ctx, http.MethodPatch, c.toolboxURL(toolboxName), + map[string]string{"default_version": version}, &result, + ); err != nil { + return nil, err } - - return nil + return &result, nil }