Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions cli/azd/pkg/update/msi_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,35 +150,32 @@ func isStandardMSIInstall() error {
return nil
}

// versionFlag returns the install script parameter value for the given channel.
func versionFlag(channel Channel) string {
func escapeForPSSingleQuote(s string) string {
return strings.ReplaceAll(s, "'", "''")
}

// buildInstallScriptArgs constructs the PowerShell arguments to run install-azd.ps1.
// For all channels, the script is downloaded to a temp directory.
// For daily channel, additional parameters (-Version, -InstallFolder) are passed
Comment thread
hemarina marked this conversation as resolved.
Outdated
// to the script. The install folder is escaped for PowerShell single-quoted strings
// to handle paths containing apostrophes (e.g. O'Connor).
// Returns the arguments to pass to the "powershell" command.
Comment thread
hemarina marked this conversation as resolved.
func buildInstallScriptArgs(channel Channel) []string {
var scriptArgs string
switch channel {
case ChannelDaily:
return "daily"
case ChannelStable:
return "stable"
scriptArgs = fmt.Sprintf(" -Version 'daily' -InstallFolder '%s'",
escapeForPSSingleQuote(expectedPerUserInstallDir()))
default:
Comment thread
hemarina marked this conversation as resolved.
return "stable"
scriptArgs = " -Version 'stable'"
}
}

// buildInstallScriptArgs constructs the PowerShell arguments to download and run
// install-azd.ps1 with the appropriate -Version flag.
// The -SkipVerify flag is passed because Authenticode verification via
// Get-AuthenticodeSignature failed.
// The MSI is already downloaded over HTTPS from a Microsoft-controlled domain,
// so the transport-level integrity is sufficient.
// Returns the arguments to pass to the "powershell" command.
func buildInstallScriptArgs(channel Channel) []string {
version := versionFlag(channel)
// Download the script to a temp file, then invoke it with the appropriate -Version flag.
// Using -ExecutionPolicy Bypass ensures the script runs even if the system policy is restrictive.
script := fmt.Sprintf(
`$script = Join-Path $env:TEMP 'install-azd.ps1'; `+
`Invoke-RestMethod '%s' -OutFile $script; `+
`& $script -Version '%s' -SkipVerify; `+
`Remove-Item $script -Force -ErrorAction SilentlyContinue`,
installScriptURL, version,
"$tmpScript = Join-Path $env:TEMP 'azd-install.ps1'; "+
"Invoke-RestMethod '%s' -OutFile $tmpScript; "+
"& $tmpScript%s; "+
"Remove-Item $tmpScript -Force -ErrorAction SilentlyContinue",
installScriptURL, scriptArgs,
)
return []string{"-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", script}
}
143 changes: 118 additions & 25 deletions cli/azd/pkg/update/msi_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"testing"

"github.com/azure/azure-dev/cli/azd/pkg/exec"
"github.com/azure/azure-dev/cli/azd/test/mocks/mockexec"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -42,31 +43,16 @@ func TestExpectedPerUserInstallDir(t *testing.T) {
}
}

func TestVersionFlag(t *testing.T) {
tests := []struct {
name string
channel Channel
want string
}{
{"stable channel", ChannelStable, "stable"},
{"daily channel", ChannelDaily, "daily"},
{"unknown defaults to stable", Channel("nightly"), "stable"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := versionFlag(tt.channel)
require.Equal(t, tt.want, got)
})
}
}

func TestBuildInstallScriptArgs(t *testing.T) {
t.Setenv("LOCALAPPDATA", `C:\Users\testuser\AppData\Local`)
expectedDir := expectedPerUserInstallDir()

tests := []struct {
name string
channel Channel
// We check that certain substrings appear in the constructed args
wantContains []string
wantContains []string
wantNotContains []string
}{
{
name: "stable",
Expand All @@ -77,7 +63,10 @@ func TestBuildInstallScriptArgs(t *testing.T) {
"-Command",
installScriptURL,
"-Version 'stable'",
"-SkipVerify",
"Remove-Item",
},
wantNotContains: []string{
"-InstallFolder",
Comment thread
hemarina marked this conversation as resolved.
},
},
{
Expand All @@ -89,7 +78,9 @@ func TestBuildInstallScriptArgs(t *testing.T) {
"-Command",
installScriptURL,
"-Version 'daily'",
"-SkipVerify",
"-InstallFolder",
expectedDir,
"Remove-Item",
},
},
}
Expand All @@ -105,26 +96,74 @@ func TestBuildInstallScriptArgs(t *testing.T) {
for _, s := range tt.wantContains {
require.Contains(t, joined, s, "expected args to contain %q", s)
}
for _, s := range tt.wantNotContains {
require.NotContains(t, joined, s, "expected args NOT to contain %q", s)
}
})
}
}

func TestEscapeForPSSingleQuote(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"no quotes", `C:\Users\testuser`, `C:\Users\testuser`},
{"single apostrophe", `C:\Users\O'Connor`, `C:\Users\O''Connor`},
{"multiple apostrophes", `C:\it's\a'path`, `C:\it''s\a''path`},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeForPSSingleQuote(tt.input))
})
}
}

func TestBuildInstallScriptArgs_ApostropheInPath(t *testing.T) {
t.Setenv("LOCALAPPDATA", `C:\Users\O'Connor\AppData\Local`)

args := buildInstallScriptArgs(ChannelDaily)
script := args[4]

// The apostrophe must be doubled for a valid PowerShell single-quoted string.
require.Contains(t, script, `O''Connor`)
// Must NOT contain unescaped apostrophe inside the -InstallFolder value.
require.NotContains(t, script, `-InstallFolder 'C:\Users\O'Connor`)
}

func TestBuildInstallScriptArgs_Structure(t *testing.T) {
t.Setenv("LOCALAPPDATA", `C:\Users\testuser\AppData\Local`)
expectedDir := expectedPerUserInstallDir()

args := buildInstallScriptArgs(ChannelStable)

// The args should be: ["-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", <script>]
require.Equal(t, 5, len(args), "expected exactly 5 args")
require.Equal(t, "-NoProfile", args[0])
require.Equal(t, "-ExecutionPolicy", args[1])
require.Equal(t, "Bypass", args[2])
require.Equal(t, "-Command", args[3])

// The script (args[4]) should be a single string containing the full PowerShell pipeline
// Stable downloads to temp file — passes -Version 'stable' explicitly
script := args[4]
require.Contains(t, script, "Invoke-RestMethod")
require.Contains(t, script, installScriptURL)
require.Contains(t, script, "-SkipVerify")
require.Contains(t, script, "Remove-Item")
require.Contains(t, script, "-Version 'stable'")
require.NotContains(t, script, "-InstallFolder")

// Daily downloads to temp file with -Version 'daily'
argsDaily := buildInstallScriptArgs(ChannelDaily)
require.Equal(t, 5, len(argsDaily))
require.Equal(t, "Bypass", argsDaily[2])
scriptDaily := argsDaily[4]
require.Contains(t, scriptDaily, "Invoke-RestMethod")
require.Contains(t, scriptDaily, installScriptURL)
require.Contains(t, scriptDaily, "-Version 'daily'")
require.Contains(t, scriptDaily, "-InstallFolder")
require.Contains(t, scriptDaily, expectedDir)
require.Contains(t, scriptDaily, "Remove-Item")
}

func TestIsStandardMSIInstall_StandardPath(t *testing.T) {
Expand Down Expand Up @@ -243,3 +282,57 @@ func TestUpdateViaMSI_NonStandardInstallBlocks(t *testing.T) {
require.True(t, errors.As(err, &updateErr))
require.Equal(t, CodeNonStandardInstall, updateErr.Code)
}

// TestUpdateViaMSI_InvokesPowerShellWithCorrectArgs verifies that updateViaMSI calls
// "powershell" with the arguments produced by buildInstallScriptArgs.
func TestUpdateViaMSI_InvokesPowerShellWithCorrectArgs(t *testing.T) {
// Point LOCALAPPDATA at the test binary so isStandardMSIInstall passes.
Comment thread
hemarina marked this conversation as resolved.
Outdated
exePath, err := os.Executable()
require.NoError(t, err)
exePath, err = filepath.EvalSymlinks(exePath)
require.NoError(t, err)

actualDir := filepath.Dir(exePath)
suffix := filepath.Join("Programs", "Azure Dev CLI")
if !strings.HasSuffix(strings.ToLower(filepath.Clean(actualDir)), strings.ToLower(suffix)) {
t.Skipf("test binary dir %q does not end with %q; skipping", actualDir, suffix)
}

localAppData := strings.TrimSuffix(filepath.Clean(actualDir), filepath.Clean(suffix))
localAppData = strings.TrimRight(localAppData, string(filepath.Separator))
t.Setenv("LOCALAPPDATA", localAppData)

for _, channel := range []Channel{ChannelStable, ChannelDaily} {
t.Run(string(channel), func(t *testing.T) {
var capturedArgs exec.RunArgs
captured := false

mockRunner := mockexec.NewMockCommandRunner()
mockRunner.When(func(args exec.RunArgs, command string) bool {
return strings.HasPrefix(command, "powershell")
}).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) {
capturedArgs = args
captured = true
return exec.NewRunResult(0, "", ""), nil
})

m := NewManager(mockRunner, nil)
var buf strings.Builder
cfg := &UpdateConfig{Channel: channel}

// updateViaMSI will likely fail at backup/hash stage in test, but
// if the mock is reached we can validate the invocation.
_ = m.updateViaMSI(context.Background(), cfg, &buf)

if !captured {
t.Skip("powershell mock was not reached (backupCurrentExe failed in test env)")
}

require.Equal(t, "powershell", capturedArgs.Cmd, "expected powershell executable")

// Verify args match buildInstallScriptArgs output.
expectedArgs := buildInstallScriptArgs(channel)
require.Equal(t, expectedArgs, capturedArgs.Args)
})
}
}
Loading