Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

"azure.ai.customtraining/internal/utils"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/azure/azure-dev/cli/azd/pkg/azdext"
"github.com/fatih/color"
)

// sanitizeEnvironmentName converts a project name to a valid azd environment name.
Expand Down Expand Up @@ -61,7 +63,26 @@ func parseProjectEndpoint(endpoint string) (accountName string, projectName stri
return accountName, projectName, nil
}

// validateOrInitEnvironment checks if environment is configured, and if not, attempts implicit initialization.
// validateOrInitEnvironment ensures the azd environment is populated with the
// account/project/tenant/subscription values needed by all `job` subcommands.
//
// Resolution priority (flags beat stored env values, per product direction):
//
// 1. If both --subscription (-s) and --project-endpoint (-e) are provided,
// parse them, look up the tenant, and **overwrite** those values in the
// current azd environment. A yellow warning is printed so the user is aware
// their stored env was modified.
// 2. If neither flag is provided, fall back to the existing azd environment
// (must already be configured).
// 3. If only one of the two flags is provided, error: both must come together.
// 4. If env is unconfigured AND both flags are provided AND no current env
// exists yet, run a one-time implicit initialization (creates the azd env,
// sets values, and `azd env new` makes it current).
// 5. If env is unconfigured AND no flags are provided, error directing the
// user to run init or pass both flags.
//
// Subcommands continue reading values via GetEnvironmentValues from the
// current env, so no subcommand changes are required.
func validateOrInitEnvironment(ctx context.Context, subscriptionId, projectEndpoint string) error {
ctx = azdext.WithAccessToken(ctx)

Expand All @@ -71,14 +92,19 @@ func validateOrInitEnvironment(ctx context.Context, subscriptionId, projectEndpo
}
defer azdClient.Close()

// If user explicitly provided -e and -s flags, always use them (re-initialize environment)
if projectEndpoint != "" && subscriptionId != "" {
return implicitInit(ctx, azdClient, subscriptionId, projectEndpoint)
// Reject mixed flag usage early — both must be provided together.
if (subscriptionId == "") != (projectEndpoint == "") {
return fmt.Errorf(
"--subscription (-s) and --project-endpoint (-e) must be provided together")
}

// No flags provided — check if environment is already configured
envValues, _ := utils.GetEnvironmentValues(ctx, azdClient)
required := []string{utils.EnvAzureTenantID, utils.EnvAzureSubscriptionID, utils.EnvAzureAccountName, utils.EnvAzureProjectName}
required := []string{
utils.EnvAzureTenantID,
utils.EnvAzureSubscriptionID,
utils.EnvAzureAccountName,
utils.EnvAzureProjectName,
}

allConfigured := true
for _, varName := range required {
Expand All @@ -88,11 +114,99 @@ func validateOrInitEnvironment(ctx context.Context, subscriptionId, projectEndpo
}
}

flagsProvided := subscriptionId != "" && projectEndpoint != ""

// Path 1: env already configured + flags provided → override stored values.
if allConfigured && flagsProvided {
return overrideEnvWithFlags(ctx, azdClient, subscriptionId, projectEndpoint)
}

// Path 2: env already configured, no flags → use as-is.
if allConfigured {
return nil
}

return fmt.Errorf("required environment variables not set. Either run 'azd ai training init' or provide both --subscription (-s) and --project-endpoint (-e) flags")
// Path 3: env not configured, no flags → error.
if !flagsProvided {
return fmt.Errorf(
"required environment variables not set. Either run 'azd ai training init' or " +
"provide both --subscription (-s) and --project-endpoint (-e) flags")
}

// Path 4: env not configured + flags provided → first-time implicit init.
fmt.Println("Environment not configured. Running implicit initialization...")
return implicitInit(ctx, azdClient, subscriptionId, projectEndpoint)
}

// overrideEnvWithFlags writes the values derived from --subscription /
// --project-endpoint into the current azd environment, replacing any
// previously stored values. A yellow warning is printed so the user knows
// their stored env was modified by this invocation.
//
// This mirrors the same setEnvValues payload that `azd ai training init`
// writes (tenant, subscription, resource group, location, account, project),
// but skips init's project-scaffolding and env-creation steps because both
// already exist by the time we reach this code path.
func overrideEnvWithFlags(
ctx context.Context,
azdClient *azdext.AzdClient,
subscriptionId, projectEndpoint string,
) error {
accountName, projectName, err := parseProjectEndpoint(projectEndpoint)
if err != nil {
return fmt.Errorf("failed to parse --project-endpoint: %w", err)
}

// Re-resolve tenant from the provided subscription so cross-tenant
// scenarios work (the previously stored AZURE_TENANT_ID may belong to a
// different subscription).
tenantResp, err := azdClient.Account().LookupTenant(ctx, &azdext.LookupTenantRequest{
SubscriptionId: subscriptionId,
})
if err != nil {
return fmt.Errorf("failed to look up tenant for subscription %q: %w", subscriptionId, err)
}

// Look up the project via ARM (same call init uses) to get the authoritative
// resource group + location. Without this, AZURE_LOCATION and
// AZURE_RESOURCE_GROUP_NAME would remain stale from the previous project.
credential, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
TenantID: tenantResp.TenantId,
AdditionallyAllowedTenants: []string{"*"},
})
if err != nil {
return fmt.Errorf("failed to create azure credential: %w", err)
}
project, err := findProjectByEndpoint(ctx, subscriptionId, accountName, projectName, credential)
if err != nil {
return fmt.Errorf("failed to find project for --project-endpoint: %w", err)
}

currentEnv, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{})
if err != nil || currentEnv.Environment == nil {
return fmt.Errorf("failed to determine current azd environment: %w", err)
}
envName := currentEnv.Environment.Name

if err := setEnvValues(ctx, azdClient, envName, map[string]string{
utils.EnvAzureTenantID: tenantResp.TenantId,
utils.EnvAzureSubscriptionID: project.SubscriptionId,
utils.EnvAzureResourceGroup: project.ResourceGroupName,
utils.EnvAzureLocation: project.Location,
utils.EnvAzureAccountName: project.AiAccountName,
utils.EnvAzureProjectName: project.AiProjectName,
}); err != nil {
return fmt.Errorf("failed to update azd environment %q: %w", envName, err)
}

color.Yellow(
"Warning: --subscription and --project-endpoint overrode azd environment %q "+
"(subscription, project endpoint, and the derived tenant, resource group, "+
"location, and account name). These changes persist for subsequent commands. "+
"Run 'azd ai training init' to reconfigure interactively.\n",
envName,
)
return nil
}

// implicitInit performs a lightweight initialization using the provided subscription and project endpoint flags.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package cmd

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestParseProjectEndpoint covers the endpoint parser used by both the init
// flow and the new --subscription / --project-endpoint override path in
// validateOrInitEnvironment. Failures here propagate as user-visible errors,
// so all the success and failure modes are pinned down.
func TestParseProjectEndpoint(t *testing.T) {
tests := []struct {
name string
endpoint string
wantAccount string
wantProject string
wantErr bool
wantErrContains string
}{
{
name: "services.ai.azure.com endpoint",
endpoint: "https://my-account.services.ai.azure.com/api/projects/my-project",
wantAccount: "my-account",
wantProject: "my-project",
},
{
name: "cognitiveservices.azure.com endpoint",
endpoint: "https://other-account.cognitiveservices.azure.com/api/projects/other-project",
wantAccount: "other-account",
wantProject: "other-project",
},
{
name: "trailing slash on project segment is tolerated",
endpoint: "https://acc.services.ai.azure.com/api/projects/proj/",
wantAccount: "acc",
wantProject: "proj",
},
{
name: "missing /api/projects/ segment",
endpoint: "https://acc.services.ai.azure.com/foo/bar/baz",
wantErr: true,
wantErrContains: "expected format /api/projects/{project-name}",
},
{
name: "wrong path order",
endpoint: "https://acc.services.ai.azure.com/projects/api/proj",
wantErr: true,
wantErrContains: "expected format /api/projects/{project-name}",
},
{
name: "missing project name",
endpoint: "https://acc.services.ai.azure.com/api/projects/",
wantErr: true,
wantErrContains: "expected format /api/projects/{project-name}",
},
{
name: "no path at all",
endpoint: "https://acc.services.ai.azure.com",
wantErr: true,
wantErrContains: "expected format /api/projects/{project-name}",
},
{
name: "empty hostname",
endpoint: "https:///api/projects/proj",
wantErr: true,
wantErrContains: "cannot extract account name",
},
{
name: "http scheme accepted (parser is scheme-agnostic)",
endpoint: "http://acc.services.ai.azure.com/api/projects/proj",
wantAccount: "acc",
wantProject: "proj",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
account, project, err := parseProjectEndpoint(tc.endpoint)
if tc.wantErr {
require.Error(t, err)
if tc.wantErrContains != "" {
require.Contains(t, err.Error(), tc.wantErrContains)
}
return
}
require.NoError(t, err)
require.Equal(t, tc.wantAccount, account)
require.Equal(t, tc.wantProject, project)
})
}
}

// TestSanitizeEnvironmentName covers the helper used by implicitInit to
// derive an azd environment name from a project name. azd env names must be
// lowercase letters, numbers, and hyphens only, and must start/end with an
// alphanumeric character.
func TestSanitizeEnvironmentName(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{name: "already valid", input: "my-project", want: "my-project"},
{name: "uppercase lowered", input: "MyProject", want: "myproject"},
{name: "spaces become hyphens", input: "my project name", want: "my-project-name"},
{name: "underscores become hyphens", input: "my_project_name", want: "my-project-name"},
{name: "special chars stripped", input: "my.project!name@123", want: "myprojectname123"},
{name: "consecutive hyphens collapsed", input: "my---project", want: "my-project"},
{name: "leading/trailing hyphens trimmed", input: "-my-project-", want: "my-project"},
{name: "all special chars falls back to default", input: "!@#$%", want: "training-env"},
{name: "empty string falls back to default", input: "", want: "training-env"},
{name: "mixed messy input", input: " My_Crazy.Project!Name ", want: "my-crazyprojectname"},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := sanitizeEnvironmentName(tc.input)
require.Equal(t, tc.want, got)
})
}
}