diff --git a/internal/cmd/common/common.go b/internal/cmd/common/common.go index b726cea6..45aa5d89 100644 --- a/internal/cmd/common/common.go +++ b/internal/cmd/common/common.go @@ -34,6 +34,8 @@ var ( LogDebug bool LogLevel string IgnoreProxy bool + // logFileCloser holds the currently open log file so it can be closed on cleanup. + logFileCloser io.Closer ) // TerminalRestorer is a function that can be called to restore terminal state @@ -131,6 +133,7 @@ func InitLogger(cmd *cobra.Command) { if err != nil { panic(err) } + logFileCloser = runLogFile defaultOut = &CustomWriter{Writer: runLogFile} rootFlags := cmd.Root().PersistentFlags() diff --git a/internal/cmd/common/common_test.go b/internal/cmd/common/common_test.go new file mode 100644 index 00000000..2ca33fb7 --- /dev/null +++ b/internal/cmd/common/common_test.go @@ -0,0 +1,421 @@ +package common + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFatalHook_FatalLevelCallsTerminalRestorer(t *testing.T) { + orig := TerminalRestorer + defer func() { TerminalRestorer = orig }() + + called := false + TerminalRestorer = func() { called = true } + + hook := FatalHook{} + var buf bytes.Buffer + logger := zerolog.New(&buf) + event := logger.Fatal() + hook.Run(event, zerolog.FatalLevel, "test message") + + assert.True(t, called, "TerminalRestorer should be called on fatal level") +} + +func TestFatalHook_NonFatalLevelsDoNotCallTerminalRestorer(t *testing.T) { + orig := TerminalRestorer + defer func() { TerminalRestorer = orig }() + + levels := []zerolog.Level{ + zerolog.TraceLevel, + zerolog.DebugLevel, + zerolog.InfoLevel, + zerolog.WarnLevel, + zerolog.ErrorLevel, + } + + hook := FatalHook{} + var buf bytes.Buffer + logger := zerolog.New(&buf) + + for _, lvl := range levels { + called := false + TerminalRestorer 
= func() { called = true } + + event := logger.WithLevel(lvl) + hook.Run(event, lvl, "test") + assert.False(t, called, "TerminalRestorer should not be called for level %v", lvl) + } +} + +func TestFatalHook_NilTerminalRestorerDoesNotPanic(t *testing.T) { + orig := TerminalRestorer + defer func() { TerminalRestorer = orig }() + + TerminalRestorer = nil + hook := FatalHook{} + + assert.NotPanics(t, func() { + var buf bytes.Buffer + logger := zerolog.New(&buf) + event := logger.Fatal() + hook.Run(event, zerolog.FatalLevel, "test") + }) +} + +func TestCustomWriter_WritesCorrectly(t *testing.T) { + tmpDir := t.TempDir() + logFile := filepath.Join(tmpDir, "test.log") + + f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY, 0644) + require.NoError(t, err) + defer func() { _ = f.Close() }() + + writer := &CustomWriter{Writer: f} + testData := []byte(`{"level":"info","message":"hello"}` + "\n") + + n, err := writer.Write(testData) + assert.NoError(t, err) + assert.Equal(t, len(testData), n, "Write should return original length") + + require.NoError(t, f.Sync()) + content, err := os.ReadFile(logFile) + require.NoError(t, err) + assert.Contains(t, string(content), "hello", "Written content should be readable") +} + +func TestCustomWriter_HandlesDataWithoutTrailingNewline(t *testing.T) { + tmpDir := t.TempDir() + f, err := os.OpenFile(filepath.Join(tmpDir, "noeol.log"), os.O_CREATE|os.O_WRONLY, 0644) + require.NoError(t, err) + defer func() { _ = f.Close() }() + + writer := &CustomWriter{Writer: f} + data := []byte(`no trailing newline`) + n, err := writer.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) +} + +func TestSetGlobalLogLevel(t *testing.T) { + tests := []struct { + name string + logDebug bool + logLevel string + expectedLevel zerolog.Level + }{ + { + name: "verbose flag sets debug level", + logDebug: true, + logLevel: "", + expectedLevel: zerolog.DebugLevel, + }, + { + name: "log-level trace", + logDebug: false, + logLevel: "trace", + 
expectedLevel: zerolog.TraceLevel, + }, + { + name: "log-level debug", + logDebug: false, + logLevel: "debug", + expectedLevel: zerolog.DebugLevel, + }, + { + name: "log-level info", + logDebug: false, + logLevel: "info", + expectedLevel: zerolog.InfoLevel, + }, + { + name: "log-level warn", + logDebug: false, + logLevel: "warn", + expectedLevel: zerolog.WarnLevel, + }, + { + name: "log-level error", + logDebug: false, + logLevel: "error", + expectedLevel: zerolog.ErrorLevel, + }, + { + name: "default (no flags) is info", + logDebug: false, + logLevel: "", + expectedLevel: zerolog.InfoLevel, + }, + { + name: "invalid log-level defaults to info", + logDebug: false, + logLevel: "garbage", + expectedLevel: zerolog.InfoLevel, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origLevel := zerolog.GlobalLevel() + origLogger := log.Logger + origDebug := LogDebug + origLogLevel := LogLevel + defer func() { + zerolog.SetGlobalLevel(origLevel) + log.Logger = origLogger + LogDebug = origDebug + LogLevel = origLogLevel + }() + + // Redirect logger to discard to avoid test output pollution + log.Logger = zerolog.New(zerolog.MultiLevelWriter()).Level(zerolog.TraceLevel) + + LogDebug = tt.logDebug + LogLevel = tt.logLevel + SetGlobalLogLevel(nil) + + assert.Equal(t, tt.expectedLevel, zerolog.GlobalLevel(), + "global log level should be %v when logDebug=%v, logLevel=%q", + tt.expectedLevel, tt.logDebug, tt.logLevel) + }) + } +} + +func TestAddCommonFlags(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + AddCommonFlags(cmd) + + flags := cmd.PersistentFlags() + + t.Run("json flag", func(t *testing.T) { + f := flags.Lookup("json") + assert.NotNil(t, f, "json flag should be registered") + assert.Equal(t, "false", f.DefValue, "json flag default should be false") + }) + + t.Run("logfile flag", func(t *testing.T) { + f := flags.Lookup("logfile") + assert.NotNil(t, f, "logfile flag should be registered") + assert.Equal(t, "", f.DefValue, "logfile flag 
default should be empty") + }) + + t.Run("verbose flag", func(t *testing.T) { + f := flags.Lookup("verbose") + assert.NotNil(t, f, "verbose flag should be registered") + assert.Equal(t, "false", f.DefValue, "verbose flag default should be false") + assert.Equal(t, "v", f.Shorthand, "verbose flag shorthand should be 'v'") + }) + + t.Run("log-level flag", func(t *testing.T) { + f := flags.Lookup("log-level") + assert.NotNil(t, f, "log-level flag should be registered") + assert.Equal(t, "", f.DefValue, "log-level flag default should be empty string") + }) + + t.Run("color flag", func(t *testing.T) { + f := flags.Lookup("color") + assert.NotNil(t, f, "color flag should be registered") + assert.Equal(t, "true", f.DefValue, "color flag default should be true") + }) + + t.Run("ignore-proxy flag", func(t *testing.T) { + f := flags.Lookup("ignore-proxy") + assert.NotNil(t, f, "ignore-proxy flag should be registered") + assert.Equal(t, "false", f.DefValue, "ignore-proxy flag default should be false") + }) +} + +func TestFormatLevelWithHitColor(t *testing.T) { + tests := []struct { + name string + colorEnabled bool + level string + expectColor bool + expectLevel string + }{ + { + name: "hit level with color", + colorEnabled: true, + level: "hit", + expectColor: true, + expectLevel: "hit", + }, + { + name: "hit level without color", + colorEnabled: false, + level: "hit", + expectColor: false, + expectLevel: "hit", + }, + { + name: "info level with color", + colorEnabled: true, + level: "info", + expectColor: true, + expectLevel: "info", + }, + { + name: "info level without color", + colorEnabled: false, + level: "info", + expectColor: false, + expectLevel: "info", + }, + { + name: "unknown level", + colorEnabled: true, + level: "custom", + expectColor: false, + expectLevel: "custom", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := formatLevelWithHitColor(tt.colorEnabled) + result := formatter(tt.level) + + assert.Contains(t, result, 
tt.expectLevel, + "result should contain the level string") + if tt.expectColor { + assert.Contains(t, result, "\x1b[", + "result should contain ANSI escape code when color enabled") + } else { + assert.NotContains(t, result, "\x1b[", + "result should not contain ANSI escape code when color disabled") + } + }) + } +} + +func TestFormatLevelWithHitColor_NonStringInput(t *testing.T) { + formatter := formatLevelWithHitColor(true) + result := formatter(42) + assert.Equal(t, "", result, "non-string input should return empty string") +} + +func TestSetupPersistentPreRun(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + assert.Nil(t, cmd.PersistentPreRun, "PersistentPreRun should be nil before setup") + + SetupPersistentPreRun(cmd) + assert.NotNil(t, cmd.PersistentPreRun, "PersistentPreRun should be set after setup") +} + +// TestInitLogger_ConsoleMode verifies InitLogger sets up a console zerolog writer. +func TestInitLogger_ConsoleMode(t *testing.T) { + origLogFile := LogFile + origJson := JsonLogoutput + origColor := LogColor + origLogger := log.Logger + origLogFileCloser := logFileCloser + defer func() { + log.Logger = origLogger + if logFileCloser != nil && logFileCloser != origLogFileCloser { + _ = logFileCloser.Close() + } + logFileCloser = origLogFileCloser + LogFile = origLogFile + JsonLogoutput = origJson + LogColor = origColor + }() + + tmpDir := t.TempDir() + LogFile = filepath.Join(tmpDir, "console.log") + JsonLogoutput = false + LogColor = false + + cmd := &cobra.Command{Use: "test"} + InitLogger(cmd) + + log.Info().Msg("console-mode-test-msg") + + content, err := os.ReadFile(LogFile) + require.NoError(t, err) + assert.Contains(t, string(content), "console-mode-test-msg") +} + +// TestInitLogger_JSONMode verifies InitLogger outputs JSON when JsonLogoutput=true. 
+func TestInitLogger_JSONMode(t *testing.T) { + origLogFile := LogFile + origJson := JsonLogoutput + origLogger := log.Logger + origLogFileCloser := logFileCloser + defer func() { + log.Logger = origLogger + if logFileCloser != nil && logFileCloser != origLogFileCloser { + _ = logFileCloser.Close() + } + logFileCloser = origLogFileCloser + LogFile = origLogFile + JsonLogoutput = origJson + }() + + tmpDir := t.TempDir() + LogFile = filepath.Join(tmpDir, "json.log") + JsonLogoutput = true + + cmd := &cobra.Command{Use: "test"} + InitLogger(cmd) + + log.Info().Msg("json-mode-test-msg") + + content, err := os.ReadFile(LogFile) + require.NoError(t, err) + s := string(content) + assert.Contains(t, s, "json-mode-test-msg") + assert.Contains(t, s, `"level"`, "JSON output should contain level field") +} + +// TestInitLogger_NoLogFile verifies InitLogger does not panic when LogFile is empty. +func TestInitLogger_NoLogFile(t *testing.T) { + origLogFile := LogFile + origJson := JsonLogoutput + origLogger := log.Logger + defer func() { + LogFile = origLogFile + JsonLogoutput = origJson + log.Logger = origLogger + }() + + LogFile = "" + JsonLogoutput = false + + cmd := &cobra.Command{Use: "test"} + assert.NotPanics(t, func() { + InitLogger(cmd) + }) +} + +// TestSaveAndRestoreTerminalState verifies that SaveTerminalState and RestoreTerminalState +// do not panic regardless of whether stdin is a terminal. +func TestSaveAndRestoreTerminalState(t *testing.T) { + origState := originalTermState + defer func() { originalTermState = origState }() + + assert.NotPanics(t, func() { + SaveTerminalState() + }) + + assert.NotPanics(t, func() { + RestoreTerminalState() + }) +} + +// TestRestoreTerminalState_NilState verifies no panic when terminal state was never saved. 
+func TestRestoreTerminalState_NilState(t *testing.T) { + origState := originalTermState + defer func() { originalTermState = origState }() + + originalTermState = nil + assert.NotPanics(t, func() { + RestoreTerminalState() + }) +} diff --git a/internal/cmd/devops/devops_test.go b/internal/cmd/devops/devops_test.go index 18f8a111..d4a5d3fc 100644 --- a/internal/cmd/devops/devops_test.go +++ b/internal/cmd/devops/devops_test.go @@ -4,85 +4,58 @@ import ( "testing" "github.com/CompassSecurity/pipeleek/internal/cmd/devops/scan" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAzureDevOpsRootCmd(t *testing.T) { cmd := NewAzureDevOpsRootCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "ad [command]" { - t.Errorf("Expected Use to be 'ad [command]', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } - - if cmd.GroupID != "AzureDevOps" { - t.Errorf("Expected GroupID 'AzureDevOps', got %q", cmd.GroupID) - } - - if len(cmd.Commands()) < 1 { - t.Errorf("Expected at least 1 subcommand, got %d", len(cmd.Commands())) - } + require.NotNil(t, cmd, "NewAzureDevOpsRootCmd should return non-nil command") + assert.Equal(t, "ad [command]", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") + assert.Equal(t, "AzureDevOps", cmd.GroupID) + assert.GreaterOrEqual(t, len(cmd.Commands()), 1, "should have at least 1 subcommand") scanCmd := cmd.Commands()[0] - if scanCmd.Use != "scan [no options!]" { - t.Errorf("Expected first subcommand Use to be 'scan [no options!]', got %q", scanCmd.Use) - } + assert.Equal(t, "scan [no options!]", scanCmd.Use) } func TestNewScanCmd(t *testing.T) { cmd := scan.NewScanCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } + require.NotNil(t, cmd, "NewScanCmd should return non-nil command") + assert.Equal(t, "scan [no options!]", cmd.Use) + assert.NotEmpty(t, 
cmd.Short, "Short description should not be empty") + assert.NotEmpty(t, cmd.Long, "Long description should not be empty") + assert.NotEmpty(t, cmd.Example, "Example should not be empty") - if cmd.Use != "scan [no options!]" { - t.Errorf("Expected Use to be 'scan [no options!]', got %q", cmd.Use) - } + flags := cmd.Flags() - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + tokenFlag := flags.Lookup("token") + assert.NotNil(t, tokenFlag, "'token' flag should be registered") + assert.Equal(t, "", tokenFlag.DefValue, "'token' flag default should be empty") + assert.Equal(t, "t", tokenFlag.Shorthand, "'token' flag shorthand should be 't'") - if cmd.Long == "" { - t.Error("Expected non-empty Long description") - } + orgFlag := flags.Lookup("organization") + assert.NotNil(t, orgFlag, "'organization' flag should be registered") + assert.Equal(t, "", orgFlag.DefValue, "'organization' flag default should be empty") - if cmd.Example == "" { - t.Error("Expected non-empty Example") - } + projectFlag := flags.Lookup("project") + assert.NotNil(t, projectFlag, "'project' flag should be registered") + assert.Equal(t, "", projectFlag.DefValue, "'project' flag default should be empty") - flags := cmd.Flags() - if flags.Lookup("token") == nil { - t.Error("Expected 'token' flag to exist") - } - if flags.Lookup("organization") == nil { - t.Error("Expected 'organization' flag to exist") - } - if flags.Lookup("project") == nil { - t.Error("Expected 'project' flag to exist") - } - if flags.Lookup("confidence") == nil { - t.Error("Expected 'confidence' flag to exist") - } - if flags.Lookup("threads") == nil { - t.Error("Expected 'threads' flag to exist") - } - if flags.Lookup("truffle-hog-verification") == nil { - t.Error("Expected 'truffle-hog-verification' flag to exist") - } - if flags.Lookup("max-builds") == nil { - t.Error("Expected 'max-builds' flag to exist") - } - if flags.Lookup("max-artifact-size") == nil { - t.Error("Expected 'max-artifact-size' flag 
to exist") - } + devopsFlag := flags.Lookup("devops") + assert.NotNil(t, devopsFlag, "'devops' flag should be registered") + assert.Equal(t, "https://dev.azure.com", devopsFlag.DefValue, + "'devops' flag default should be https://dev.azure.com") + + maxBuildsFlag := flags.Lookup("max-builds") + assert.NotNil(t, maxBuildsFlag, "'max-builds' flag should be registered") + assert.Equal(t, "-1", maxBuildsFlag.DefValue, "'max-builds' flag default should be -1") + + assert.NotNil(t, flags.Lookup("confidence"), "'confidence' flag should be registered") + assert.NotNil(t, flags.Lookup("threads"), "'threads' flag should be registered") + assert.NotNil(t, flags.Lookup("truffle-hog-verification"), "'truffle-hog-verification' flag should be registered") + assert.NotNil(t, flags.Lookup("max-artifact-size"), "'max-artifact-size' flag should be registered") } diff --git a/internal/cmd/gitlab/gitlab_test.go b/internal/cmd/gitlab/gitlab_test.go index bda9157e..e221053d 100644 --- a/internal/cmd/gitlab/gitlab_test.go +++ b/internal/cmd/gitlab/gitlab_test.go @@ -8,150 +8,82 @@ import ( "github.com/CompassSecurity/pipeleek/internal/cmd/gitlab/shodan" "github.com/CompassSecurity/pipeleek/internal/cmd/gitlab/variables" "github.com/CompassSecurity/pipeleek/internal/cmd/gitlab/vuln" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewGitLabRootCmd(t *testing.T) { cmd := NewGitLabRootCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "gl [command]" { - t.Errorf("Expected Use to be 'gl [command]', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } - - if cmd.Long == "" { - t.Error("Expected non-empty Long description") - } - - if cmd.GroupID != "GitLab" { - t.Errorf("Expected GroupID 'GitLab', got %q", cmd.GroupID) - } + require.NotNil(t, cmd, "NewGitLabRootCmd should return non-nil command") + assert.Equal(t, "gl [command]", cmd.Use) + 
assert.NotEmpty(t, cmd.Short, "Short description should not be empty") + assert.NotEmpty(t, cmd.Long, "Long description should not be empty") + assert.Equal(t, "GitLab", cmd.GroupID) + assert.GreaterOrEqual(t, len(cmd.Commands()), 8, + "should have at least 8 subcommands") flags := cmd.PersistentFlags() - if flags.Lookup("gitlab") == nil { - t.Error("Expected 'gitlab' persistent flag to exist") - } - if flags.Lookup("token") == nil { - t.Error("Expected 'token' persistent flag to exist") - } - - if len(cmd.Commands()) < 8 { - t.Errorf("Expected at least 8 subcommands, got %d", len(cmd.Commands())) - } + gitlabFlag := flags.Lookup("gitlab") + assert.NotNil(t, gitlabFlag, "'gitlab' persistent flag should be registered") + assert.Equal(t, "", gitlabFlag.DefValue, + "'gitlab' flag default should be empty") + + tokenFlag := flags.Lookup("token") + assert.NotNil(t, tokenFlag, "'token' persistent flag should be registered") + assert.Equal(t, "", tokenFlag.DefValue, "'token' flag default should be empty") } func TestNewVulnCmd(t *testing.T) { cmd := vuln.NewVulnCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "vuln" { - t.Errorf("Expected Use to be 'vuln', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + require.NotNil(t, cmd, "NewVulnCmd should return non-nil command") + assert.Equal(t, "vuln", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") } func TestNewVariablesCmd(t *testing.T) { cmd := variables.NewVariablesCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "variables" { - t.Errorf("Expected Use to be 'variables', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + require.NotNil(t, cmd, "NewVariablesCmd should return non-nil command") + assert.Equal(t, "variables", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") 
flags := cmd.Flags() - if flags.Lookup("gitlab") == nil { - t.Error("Expected 'gitlab' flag to exist") - } + assert.NotNil(t, flags.Lookup("gitlab"), "'gitlab' flag should be registered") } func TestNewEnumCmd(t *testing.T) { cmd := enum.NewEnumCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "enum" { - t.Errorf("Expected Use to be 'enum', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + require.NotNil(t, cmd, "NewEnumCmd should return non-nil command") + assert.Equal(t, "enum", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") } func TestNewRegisterCmd(t *testing.T) { cmd := register.NewRegisterCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - return - } - - if cmd.Use != "register" { - t.Errorf("Expected Use to be 'register', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + require.NotNil(t, cmd, "NewRegisterCmd should return non-nil command") + assert.Equal(t, "register", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") flags := cmd.Flags() - if flags.Lookup("username") == nil { - t.Error("Expected 'username' flag to exist") - } - if flags.Lookup("email") == nil { - t.Error("Expected 'email' flag to exist") - } - if flags.Lookup("password") == nil { - t.Error("Expected 'password' flag to exist") - } - if flags.Lookup("gitlab") == nil { - t.Error("Expected 'gitlab' flag to exist") - } + assert.NotNil(t, flags.Lookup("username"), "'username' flag should be registered") + assert.NotNil(t, flags.Lookup("email"), "'email' flag should be registered") + assert.NotNil(t, flags.Lookup("password"), "'password' flag should be registered") + assert.NotNil(t, flags.Lookup("gitlab"), "'gitlab' flag should be registered") } func TestNewShodanCmd(t *testing.T) { cmd := shodan.NewShodanCmd() - if cmd == nil { - t.Fatal("Expected non-nil command") - 
return - } - - if cmd.Use != "shodan" { - t.Errorf("Expected Use to be 'shodan', got %q", cmd.Use) - } - - if cmd.Short == "" { - t.Error("Expected non-empty Short description") - } + require.NotNil(t, cmd, "NewShodanCmd should return non-nil command") + assert.Equal(t, "shodan", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") flags := cmd.Flags() - if flags.Lookup("json") == nil { - t.Error("Expected 'json' flag to exist") - } + jsonFlag := flags.Lookup("json") + assert.NotNil(t, jsonFlag, "'json' flag should be registered") + assert.Equal(t, "", jsonFlag.DefValue, "'json' flag default should be empty string (path to Shodan JSON file)") } diff --git a/internal/cmd/gitlab/jobToken/jobtoken_test.go b/internal/cmd/gitlab/jobToken/jobtoken_test.go index 398370a9..308cba37 100644 --- a/internal/cmd/gitlab/jobToken/jobtoken_test.go +++ b/internal/cmd/gitlab/jobToken/jobtoken_test.go @@ -1,32 +1,28 @@ package jobtoken -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) func TestNewJobTokenRootCmd(t *testing.T) { cmd := NewJobTokenRootCmd() - if cmd == nil { - t.Fatal("expected non-nil command") - } - - if cmd.Use != "jobToken" { - t.Fatalf("expected Use to be jobToken, got %q", cmd.Use) - } + require.NotNil(t, cmd, "NewJobTokenRootCmd should return non-nil command") - if cmd.Short == "" { - t.Fatal("expected non-empty Short description") - } - - if cmd.Long == "" { - t.Fatal("expected non-empty Long description") - } + assert.Equal(t, "jobToken", cmd.Use) + assert.NotEmpty(t, cmd.Short, "Short description should not be empty") + assert.NotEmpty(t, cmd.Long, "Long description should not be empty") flags := cmd.PersistentFlags() - if flags.Lookup("gitlab") == nil { - t.Fatal("expected gitlab flag to exist") - } - if flags.Lookup("token") == nil { - t.Fatal("expected token flag to exist") - } + gitlabFlag := flags.Lookup("gitlab") + assert.NotNil(t, gitlabFlag, 
"'gitlab' persistent flag should be registered") + assert.Equal(t, "", gitlabFlag.DefValue, "'gitlab' flag default should be empty") + + tokenFlag := flags.Lookup("token") + assert.NotNil(t, tokenFlag, "'token' persistent flag should be registered") + assert.Equal(t, "", tokenFlag.DefValue, "'token' flag default should be empty") foundExploit := false for _, sub := range cmd.Commands() { @@ -35,7 +31,5 @@ func TestNewJobTokenRootCmd(t *testing.T) { break } } - if !foundExploit { - t.Fatal("expected exploit subcommand to be registered") - } + assert.True(t, foundExploit, "jobToken command should have 'exploit' subcommand") } diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index 11c0e1c7..335d7122 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -14,116 +14,128 @@ import ( func TestGlobalVerboseFlagRegistered(t *testing.T) { flag := rootCmd.PersistentFlags().Lookup("verbose") - if flag == nil { - t.Fatal("Global verbose flag not registered") - } + assert.NotNil(t, flag, "Global verbose flag should be registered") + assert.Equal(t, "false", flag.DefValue, "verbose flag default should be false") } func TestGlobalLogLevelFlagRegistered(t *testing.T) { flag := rootCmd.PersistentFlags().Lookup("log-level") - if flag == nil { - t.Fatal("Global log-level flag not registered") - } -} - -func TestSetGlobalLogLevel_VerboseFlag(t *testing.T) { - LogDebug = true - LogLevel = "" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.DebugLevel { - t.Errorf("Expected DebugLevel with -v flag, got %v", zerolog.GlobalLevel()) - } - LogDebug = false -} - -func TestSetGlobalLogLevel_LogLevelDebug(t *testing.T) { - LogDebug = false - LogLevel = "debug" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.DebugLevel { - t.Errorf("Expected DebugLevel, got %v", zerolog.GlobalLevel()) - } - LogLevel = "" -} - -func TestSetGlobalLogLevel_Info(t *testing.T) { - LogDebug = false - LogLevel = "info" - setGlobalLogLevel(nil) - if 
zerolog.GlobalLevel() != zerolog.InfoLevel { - t.Errorf("Expected InfoLevel, got %v", zerolog.GlobalLevel()) - } - LogLevel = "" -} - -func TestSetGlobalLogLevel_Warn(t *testing.T) { - LogDebug = false - LogLevel = "warn" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.WarnLevel { - t.Errorf("Expected WarnLevel, got %v", zerolog.GlobalLevel()) - } - LogLevel = "" -} - -func TestSetGlobalLogLevel_Error(t *testing.T) { - LogDebug = false - LogLevel = "error" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.ErrorLevel { - t.Errorf("Expected ErrorLevel, got %v", zerolog.GlobalLevel()) - } - LogLevel = "" + assert.NotNil(t, flag, "Global log-level flag should be registered") + assert.Equal(t, "", flag.DefValue, "log-level flag default should be empty string") } -func TestSetGlobalLogLevel_Default(t *testing.T) { - LogDebug = false - LogLevel = "" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.InfoLevel { - t.Errorf("Expected InfoLevel for default, got %v", zerolog.GlobalLevel()) +func TestSetGlobalLogLevel(t *testing.T) { + tests := []struct { + name string + logDebug bool + logLevel string + expectedLevel zerolog.Level + }{ + { + name: "verbose flag sets debug level", + logDebug: true, + logLevel: "", + expectedLevel: zerolog.DebugLevel, + }, + { + name: "log-level debug sets debug level", + logDebug: false, + logLevel: "debug", + expectedLevel: zerolog.DebugLevel, + }, + { + name: "log-level info sets info level", + logDebug: false, + logLevel: "info", + expectedLevel: zerolog.InfoLevel, + }, + { + name: "log-level warn sets warn level", + logDebug: false, + logLevel: "warn", + expectedLevel: zerolog.WarnLevel, + }, + { + name: "log-level error sets error level", + logDebug: false, + logLevel: "error", + expectedLevel: zerolog.ErrorLevel, + }, + { + name: "default (no flags) sets info level", + logDebug: false, + logLevel: "", + expectedLevel: zerolog.InfoLevel, + }, + { + name: "invalid log-level defaults to info", + 
logDebug: false, + logLevel: "invalid", + expectedLevel: zerolog.InfoLevel, + }, + { + name: "log-level takes precedence over verbose flag when both set", + logDebug: true, + logLevel: "info", + expectedLevel: zerolog.InfoLevel, + }, } -} -func TestSetGlobalLogLevel_Invalid(t *testing.T) { - LogDebug = false - LogLevel = "invalid" - setGlobalLogLevel(nil) - if zerolog.GlobalLevel() != zerolog.InfoLevel { - t.Errorf("Expected InfoLevel for invalid, got %v", zerolog.GlobalLevel()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore all global state via defer to prevent test contamination + origLevel := zerolog.GlobalLevel() + origDebug := LogDebug + origLogLevel := LogLevel + defer func() { + zerolog.SetGlobalLevel(origLevel) + LogDebug = origDebug + LogLevel = origLogLevel + }() + + LogDebug = tt.logDebug + LogLevel = tt.logLevel + setGlobalLogLevel(nil) + + assert.Equal(t, tt.expectedLevel, zerolog.GlobalLevel(), + "log level should be %v for logDebug=%v, logLevel=%q", + tt.expectedLevel, tt.logDebug, tt.logLevel) + }) } } func TestGlobalColorFlagRegistered(t *testing.T) { flag := rootCmd.PersistentFlags().Lookup("color") - if flag == nil { - t.Fatal("Global color flag not registered") - return - } - - if flag.DefValue != "true" { - t.Errorf("Expected default value 'true' for color flag, got %s", flag.DefValue) - } + assert.NotNil(t, flag, "Global color flag should be registered") + assert.Equal(t, "true", flag.DefValue, "color flag default should be true") } func TestGlobalConfigFlagRegistered(t *testing.T) { flag := rootCmd.PersistentFlags().Lookup("config") - if flag == nil { - t.Fatal("Global config flag not registered") - } + assert.NotNil(t, flag, "Global config flag should be registered") + assert.Equal(t, "", flag.DefValue, "config flag default should be empty string") } func TestGlobalLogFileFlagRegistered(t *testing.T) { flag := rootCmd.PersistentFlags().Lookup("logfile") - if flag == nil { - t.Fatal("Global logfile 
flag not registered") - } + assert.NotNil(t, flag, "Global logfile flag should be registered") + assert.Equal(t, "", flag.DefValue, "logfile flag default should be empty string") +} + +func TestGlobalIgnoreProxyFlagRegistered(t *testing.T) { + flag := rootCmd.PersistentFlags().Lookup("ignore-proxy") + assert.NotNil(t, flag, "Global ignore-proxy flag should be registered") + assert.Equal(t, "false", flag.DefValue, "ignore-proxy flag default should be false") +} + +func TestGlobalJSONFlagRegistered(t *testing.T) { + flag := rootCmd.PersistentFlags().Lookup("json") + assert.NotNil(t, flag, "Global json flag should be registered") + assert.Equal(t, "false", flag.DefValue, "json flag default should be false") } func TestPersistentPreRunRegistered(t *testing.T) { - if rootCmd.PersistentPreRun == nil { - t.Fatal("PersistentPreRun should be registered") - } + assert.NotNil(t, rootCmd.PersistentPreRun, "PersistentPreRun should be registered") } func TestTerminalRestorer(t *testing.T) { @@ -337,3 +349,71 @@ func TestGetVersion(t *testing.T) { func TestRootCmdHasVersion(t *testing.T) { assert.NotEmpty(t, rootCmd.Version, "rootCmd should have a version") } + +// TestFormatLevelWithHitColor_ColorEnabled verifies each log level gets the correct +// color escape code when color output is enabled. 
+func TestFormatLevelWithHitColor_ColorEnabled(t *testing.T) { + formatter := formatLevelWithHitColor(true) + + tests := []struct { + level string + wantCode string + }{ + {"hit", "\x1b[35m"}, // magenta + {"trace", "\x1b[90m"}, // dark grey + {"info", "\x1b[32m"}, // green + {"warn", "\x1b[33m"}, // yellow + {"error", "\x1b[31m"}, // red + {"fatal", "\x1b[31m"}, // red + {"panic", "\x1b[31m"}, // red + {"debug", "debug"}, // no color for debug - returned as-is + } + + for _, tt := range tests { + t.Run("level_"+tt.level, func(t *testing.T) { + result := formatter(tt.level) + assert.Contains(t, result, tt.wantCode, "level=%q should contain color code %q", tt.level, tt.wantCode) + }) + } +} + +// TestFormatLevelWithHitColor_ColorDisabled verifies that every level is returned +// unchanged (no escape codes) when color output is disabled. +func TestFormatLevelWithHitColor_ColorDisabled(t *testing.T) { + formatter := formatLevelWithHitColor(false) + + levels := []string{"hit", "trace", "debug", "info", "warn", "error", "fatal", "panic", "unknown"} + for _, level := range levels { + t.Run("level_"+level, func(t *testing.T) { + result := formatter(level) + assert.Equal(t, level, result, "color disabled: level=%q should be returned unchanged", level) + }) + } +} + +// TestFormatLevelWithHitColor_UnknownLevel verifies that unknown levels are passed through. +func TestFormatLevelWithHitColor_UnknownLevel(t *testing.T) { + formatter := formatLevelWithHitColor(true) + result := formatter("custom-level") + // Unknown levels fall through to the default case which returns the level unchanged + assert.Equal(t, "custom-level", result) +} + +// TestFormatLevelWithHitColor_NonStringInput verifies that non-string input returns "". 
+func TestFormatLevelWithHitColor_NonStringInput(t *testing.T) { + formatter := formatLevelWithHitColor(true) + result := formatter(42) + assert.Equal(t, "", result) +} + +// TestLoadConfigFile_NoConfigFile verifies that loadConfigFile does not panic or error +// when no config file path is set (default empty string). +func TestLoadConfigFile_NoConfigFile(t *testing.T) { + origConfigFile := ConfigFile + defer func() { ConfigFile = origConfigFile }() + + ConfigFile = "" + assert.NotPanics(t, func() { + loadConfigFile(rootCmd) + }) +} diff --git a/pkg/archive/strings_test.go b/pkg/archive/strings_test.go index b67c41a4..90a2ae4d 100644 --- a/pkg/archive/strings_test.go +++ b/pkg/archive/strings_test.go @@ -10,6 +10,7 @@ import ( ) func TestExtractPrintableStrings(t *testing.T) { + t.Parallel() tests := []struct { name string input []byte @@ -153,6 +154,7 @@ func TestExtractPrintableStrings(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := ExtractPrintableStrings(tt.input, tt.minLength) // Split result by newlines and filter empty strings @@ -169,6 +171,7 @@ func TestExtractPrintableStrings(t *testing.T) { } func TestExtractPrintableStrings_LargeBinary(t *testing.T) { + t.Parallel() // Create a large binary file with embedded secrets var largeBinary bytes.Buffer @@ -193,6 +196,7 @@ func TestExtractPrintableStrings_LargeBinary(t *testing.T) { } func TestExtractPrintableStrings_ASCII(t *testing.T) { + t.Parallel() // Test with ASCII-only strings (UTF-8 bytes are treated as non-printable) input := []byte{0x00, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', 0x00} result := ExtractPrintableStrings(input, 4) @@ -202,6 +206,7 @@ func TestExtractPrintableStrings_ASCII(t *testing.T) { } func TestExtractPrintableStrings_RealWorldBinary(t *testing.T) { + t.Parallel() // Simulate a real-world scenario: a compiled binary with embedded config binary := []byte{ // Some binary header @@ -250,6 +255,7 @@ func 
TestExtractPrintableStrings_EdgeCases(t *testing.T) { } func TestIsPrintableByte(t *testing.T) { + t.Parallel() tests := []struct { name string b byte @@ -280,6 +286,7 @@ func TestIsPrintableByte(t *testing.T) { } func TestExtractPrintableStrings_MinStringLength(t *testing.T) { + t.Parallel() // Verify the constant value assert.Equal(t, 4, MinStringLength, "MinStringLength should be 4 to match Unix strings command") } @@ -303,6 +310,7 @@ func BenchmarkExtractPrintableStrings(b *testing.B) { } func TestExtractPrintableStrings_SecretPatterns(t *testing.T) { + t.Parallel() // Test extraction of various secret patterns that might be found in binaries tests := []struct { name string @@ -346,6 +354,7 @@ func TestExtractPrintableStrings_SecretPatterns(t *testing.T) { } func TestExtractPrintableStrings_Reproducibility(t *testing.T) { + t.Parallel() // Ensure the function produces consistent results input := []byte{0x00, 0x01, 'T', 'e', 's', 't', 0xFF, 'D', 'a', 't', 'a', 0x00} diff --git a/pkg/config/loader_bind_test.go b/pkg/config/loader_bind_test.go index b23a9cc2..583bc06d 100644 --- a/pkg/config/loader_bind_test.go +++ b/pkg/config/loader_bind_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // resetViper resets the global viper instance for tests. 
@@ -74,3 +76,99 @@ func TestBindCommandFlags_InheritedFlags(t *testing.T) { t.Fatalf("expected inherited flag value, got %q", got) } } + +func TestAutoBindFlags_LocalFlag(t *testing.T) { + resetViper(t) + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("token", "", "API token") + + err := AutoBindFlags(cmd, map[string]string{"token": "gitlab.token"}) + require.NoError(t, err) + + require.NoError(t, cmd.Flags().Set("token", "my-token")) + assert.Equal(t, "my-token", GetString("gitlab.token")) +} + +func TestAutoBindFlags_InheritedFlag(t *testing.T) { + resetViper(t) + + root := &cobra.Command{Use: "root"} + root.PersistentFlags().String("url", "", "GitLab URL") + + child := &cobra.Command{Use: "child"} + root.AddCommand(child) + + err := AutoBindFlags(child, map[string]string{"url": "gitlab.url"}) + require.NoError(t, err) + + require.NoError(t, root.PersistentFlags().Set("url", "https://gitlab.example.com")) + assert.Equal(t, "https://gitlab.example.com", GetString("gitlab.url")) +} + +func TestAutoBindFlags_UnknownFlagIsIgnored(t *testing.T) { + resetViper(t) + + cmd := &cobra.Command{Use: "test"} + + // A flag that doesn't exist should be silently ignored, not error + err := AutoBindFlags(cmd, map[string]string{"nonexistent-flag": "some.key"}) + assert.NoError(t, err) +} + +func TestAutoBindFlags_MultipleFlags(t *testing.T) { + resetViper(t) + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("token", "", "API token") + cmd.Flags().String("url", "", "URL") + cmd.Flags().Int("threads", 4, "Thread count") + + err := AutoBindFlags(cmd, map[string]string{ + "token": "gitlab.token", + "url": "gitlab.url", + "threads": "common.threads", + }) + require.NoError(t, err) + + require.NoError(t, cmd.Flags().Set("token", "abc123")) + require.NoError(t, cmd.Flags().Set("url", "https://gitlab.com")) + require.NoError(t, cmd.Flags().Set("threads", "8")) + + assert.Equal(t, "abc123", GetString("gitlab.token")) + assert.Equal(t, "https://gitlab.com", 
GetString("gitlab.url")) + assert.Equal(t, 8, GetInt("common.threads")) +} + +func TestUnmarshalConfig_Defaults(t *testing.T) { + resetViper(t) + + cfg, err := UnmarshalConfig() + require.NoError(t, err) + require.NotNil(t, cfg) + + // Verify default values are populated + assert.Equal(t, 4, cfg.Common.Threads) + assert.True(t, cfg.Common.TruffleHogVerification) + assert.Equal(t, "500Mb", cfg.Common.MaxArtifactSize) + assert.Equal(t, "https://api.github.com", cfg.GitHub.URL) + assert.Equal(t, "https://bitbucket.org", cfg.BitBucket.URL) + assert.Equal(t, "https://dev.azure.com", cfg.AzureDevOps.URL) +} + +func TestUnmarshalConfig_WithSetValues(t *testing.T) { + resetViper(t) + + v := GetViper() + v.Set("gitlab.url", "https://mygitlab.com") + v.Set("gitlab.token", "glpat-secret") + v.Set("common.threads", 16) + + cfg, err := UnmarshalConfig() + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, "https://mygitlab.com", cfg.GitLab.URL) + assert.Equal(t, "glpat-secret", cfg.GitLab.Token) + assert.Equal(t, 16, cfg.Common.Threads) +} diff --git a/pkg/config/validation_test.go b/pkg/config/validation_test.go index 4c76a9ac..1794d1e8 100644 --- a/pkg/config/validation_test.go +++ b/pkg/config/validation_test.go @@ -6,6 +6,7 @@ import ( ) func TestValidateURL(t *testing.T) { + t.Parallel() tests := []struct { name string url string @@ -49,6 +50,7 @@ func TestValidateURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() err := ValidateURL(tt.url, tt.fieldName) if tt.wantError { if err == nil { @@ -66,6 +68,7 @@ func TestValidateURL(t *testing.T) { } func TestParseMaxArtifactSize(t *testing.T) { + t.Parallel() tests := []struct { name string sizeStr string @@ -99,6 +102,7 @@ func TestParseMaxArtifactSize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() got, err := ParseMaxArtifactSize(tt.sizeStr) if tt.wantError { if err == nil { @@ -117,6 +121,7 @@ func 
TestParseMaxArtifactSize(t *testing.T) { } func TestValidateToken(t *testing.T) { + t.Parallel() tests := []struct { name string token string @@ -139,6 +144,7 @@ func TestValidateToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() err := ValidateToken(tt.token, tt.fieldName) if tt.wantError && err == nil { t.Errorf("ValidateToken() expected error but got none") @@ -151,6 +157,7 @@ func TestValidateToken(t *testing.T) { } func TestValidateThreadCount(t *testing.T) { + t.Parallel() tests := []struct { name string threads int @@ -190,6 +197,7 @@ func TestValidateThreadCount(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() err := ValidateThreadCount(tt.threads) if tt.wantError && err == nil { t.Errorf("ValidateThreadCount() expected error but got none") diff --git a/pkg/docs/generator_test.go b/pkg/docs/generator_test.go index 9678aad4..d89e9d6e 100644 --- a/pkg/docs/generator_test.go +++ b/pkg/docs/generator_test.go @@ -3,10 +3,12 @@ package docs import ( "os" "path/filepath" + "runtime" "testing" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) @@ -173,3 +175,268 @@ func TestWriteMkdocsYaml_GithubPagesPrefix(t *testing.T) { firstItem := introItems[0].(map[string]interface{}) assert.Equal(t, "/pipeleek/introduction/getting_started/", firstItem["Getting Started"]) } + +func TestCopyFile(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "source.txt") + dstPath := filepath.Join(tmpDir, "dest.txt") + + content := []byte("hello, test content") + require.NoError(t, os.WriteFile(srcPath, content, 0644)) + + err := copyFile(srcPath, dstPath) + assert.NoError(t, err) + + got, err := os.ReadFile(dstPath) + assert.NoError(t, err) + assert.Equal(t, content, got) +} + +func TestCopyFile_SourceNotExist(t *testing.T) { + tmpDir := t.TempDir() + err := 
copyFile(filepath.Join(tmpDir, "nonexistent.txt"), filepath.Join(tmpDir, "dst.txt")) + assert.Error(t, err) +} + +func TestCopyFile_DestinationNotWritable(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "source.txt") + require.NoError(t, os.WriteFile(srcPath, []byte("data"), 0644)) + + // Try to write to a directory path (should fail) + err := copyFile(srcPath, tmpDir) + assert.Error(t, err) +} + +func TestCopyDir(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + // Create files in source + require.NoError(t, os.WriteFile(filepath.Join(srcDir, "a.txt"), []byte("aaa"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(srcDir, "b.txt"), []byte("bbb"), 0644)) + + // Create a subdirectory + subDir := filepath.Join(srcDir, "sub") + require.NoError(t, os.Mkdir(subDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(subDir, "c.txt"), []byte("ccc"), 0644)) + + err := copyDir(srcDir, dstDir) + assert.NoError(t, err) + + // Verify files were copied + gotA, err := os.ReadFile(filepath.Join(dstDir, "a.txt")) + assert.NoError(t, err) + assert.Equal(t, []byte("aaa"), gotA) + + gotB, err := os.ReadFile(filepath.Join(dstDir, "b.txt")) + assert.NoError(t, err) + assert.Equal(t, []byte("bbb"), gotB) + + gotC, err := os.ReadFile(filepath.Join(dstDir, "sub", "c.txt")) + assert.NoError(t, err) + assert.Equal(t, []byte("ccc"), gotC) +} + +func TestCopyDir_SourceNotExist(t *testing.T) { + dstDir := t.TempDir() + err := copyDir("/nonexistent/path", dstDir) + assert.Error(t, err) +} + +func TestCopySubfolders(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + // Create subdirectories with files in source (only subdirs should be copied, not root files) + sub1 := filepath.Join(srcDir, "sub1") + sub2 := filepath.Join(srcDir, "sub2") + require.NoError(t, os.Mkdir(sub1, 0755)) + require.NoError(t, os.Mkdir(sub2, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(sub1, "file1.txt"), []byte("f1"), 0644)) + 
require.NoError(t, os.WriteFile(filepath.Join(sub2, "file2.txt"), []byte("f2"), 0644)) + + // Root-level file (should NOT be copied by copySubfolders) + require.NoError(t, os.WriteFile(filepath.Join(srcDir, "root.txt"), []byte("root"), 0644)) + + err := copySubfolders(srcDir, dstDir) + assert.NoError(t, err) + + // Subdirectory files should be present + got1, err := os.ReadFile(filepath.Join(dstDir, "sub1", "file1.txt")) + assert.NoError(t, err) + assert.Equal(t, []byte("f1"), got1) + + got2, err := os.ReadFile(filepath.Join(dstDir, "sub2", "file2.txt")) + assert.NoError(t, err) + assert.Equal(t, []byte("f2"), got2) + + // Root-level file should NOT be copied + _, err = os.Stat(filepath.Join(dstDir, "root.txt")) + assert.True(t, os.IsNotExist(err), "root-level files should not be copied by copySubfolders") +} + +func TestCopySubfolders_SourceNotExist(t *testing.T) { + dstDir := t.TempDir() + err := copySubfolders("/nonexistent/path", dstDir) + assert.Error(t, err) +} + +// TestGenerateDocs_LeafCommand verifies that a leaf command creates a .md file. +func TestGenerateDocs_LeafCommand(t *testing.T) { + tmpDir := t.TempDir() + + cmd := &cobra.Command{ + Use: "scan", + Short: "Scan for secrets", + Run: func(cmd *cobra.Command, args []string) {}, + } + + err := generateDocs(cmd, tmpDir, 1, false) + require.NoError(t, err) + + _, err = os.Stat(filepath.Join(tmpDir, "scan.md")) + assert.NoError(t, err, "leaf command should create a .md file") +} + +// TestGenerateDocs_ParentCommand verifies that a parent command creates index.md in a subdirectory. 
+func TestGenerateDocs_ParentCommand(t *testing.T) { + tmpDir := t.TempDir() + + parent := &cobra.Command{Use: "gitlab", Short: "GitLab commands"} + child := &cobra.Command{ + Use: "scan", + Short: "Scan CI/CD", + Run: func(cmd *cobra.Command, args []string) {}, + } + parent.AddCommand(child) + + err := generateDocs(parent, tmpDir, 0, false) + require.NoError(t, err) + + _, err = os.Stat(filepath.Join(tmpDir, "gitlab", "index.md")) + assert.NoError(t, err, "parent command should create index.md in subdirectory") + + _, err = os.Stat(filepath.Join(tmpDir, "gitlab", "scan.md")) + assert.NoError(t, err, "child command should create scan.md") +} + +// TestGenerateDocs_GithubPages verifies link rewriting for GitHub Pages. +func TestGenerateDocs_GithubPages(t *testing.T) { + tmpDir := t.TempDir() + + cmd := &cobra.Command{ + Use: "scan", + Short: "Scan for secrets", + Run: func(cmd *cobra.Command, args []string) {}, + } + + err := generateDocs(cmd, tmpDir, 1, true) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "scan.md")) + require.NoError(t, err) + assert.NotEmpty(t, content, "generated docs should not be empty") +} + +// TestGenerateDocs_OutputDirNotWritable verifies that an error is returned when the dir is not writable. 
+func TestGenerateDocs_OutputDirNotWritable(t *testing.T) { + if os.Getuid() == 0 { + t.Skip("skipping: running as root, permission restrictions don't apply") + } + if runtime.GOOS == "windows" { + t.Skip("skipping: read-only directory restriction test is Unix-specific") + } + tmpDir := t.TempDir() + // Create a read-only subdirectory so MkdirAll will fail for children + readonlyDir := filepath.Join(tmpDir, "readonly") + require.NoError(t, os.MkdirAll(readonlyDir, 0500)) + + parent := &cobra.Command{Use: "sub", Short: "sub with children"} + child := &cobra.Command{ + Use: "leaf", + Short: "leaf cmd", + Run: func(cmd *cobra.Command, args []string) {}, + } + parent.AddCommand(child) + + // generateDocs for a parent creates subdirectory - this should fail in the read-only dir + err := generateDocs(parent, readonlyDir, 0, false) + assert.Error(t, err, "should return error when output dir is not writable") +} + +// TestInlineSVGIntoGettingStarted_MissingFile verifies an error is returned when the markdown is missing. +func TestInlineSVGIntoGettingStarted_MissingFile(t *testing.T) { + tmpDir := t.TempDir() + err := inlineSVGIntoGettingStarted(tmpDir) + assert.Error(t, err, "should return error when getting_started.md does not exist") +} + +// TestInlineSVGIntoGettingStarted_NoPlaceholder verifies early return when placeholder is absent. 
+func TestInlineSVGIntoGettingStarted_NoPlaceholder(t *testing.T) { + tmpDir := t.TempDir() + + introDir := filepath.Join(tmpDir, "introduction") + require.NoError(t, os.MkdirAll(introDir, 0755)) + mdPath := filepath.Join(introDir, "getting_started.md") + require.NoError(t, os.WriteFile(mdPath, []byte("# Getting Started\nNo placeholder here."), 0644)) + + err := inlineSVGIntoGettingStarted(tmpDir) + assert.NoError(t, err, "no placeholder should return nil without error") + + // File should be unchanged + content, err := os.ReadFile(mdPath) + require.NoError(t, err) + assert.Contains(t, string(content), "No placeholder here.") +} + +// TestInlineSVGIntoGettingStarted_MissingSVG verifies an error is returned when the SVG file is missing. +func TestInlineSVGIntoGettingStarted_MissingSVG(t *testing.T) { + tmpDir := t.TempDir() + + introDir := filepath.Join(tmpDir, "introduction") + require.NoError(t, os.MkdirAll(introDir, 0755)) + mdPath := filepath.Join(introDir, "getting_started.md") + placeholder := "" + require.NoError(t, os.WriteFile(mdPath, []byte("# Getting Started\n"+placeholder), 0644)) + + // Change to tmpDir so the SVG file path ("docs/pipeleek-anim.svg") is relative to it + origDir, _ := os.Getwd() + require.NoError(t, os.Chdir(tmpDir)) + defer func() { _ = os.Chdir(origDir) }() + + err := inlineSVGIntoGettingStarted(tmpDir) + assert.Error(t, err, "should return error when SVG file does not exist") +} + +// TestInlineSVGIntoGettingStarted_ReplacesPlaceholder verifies SVG content is inlined. +func TestInlineSVGIntoGettingStarted_ReplacesPlaceholder(t *testing.T) { + tmpDir := t.TempDir() + + introDir := filepath.Join(tmpDir, "introduction") + require.NoError(t, os.MkdirAll(introDir, 0755)) + placeholder := "" + mdContent := "# Getting Started\n" + placeholder + "\nEnd of file." 
+ mdPath := filepath.Join(introDir, "getting_started.md") + require.NoError(t, os.WriteFile(mdPath, []byte(mdContent), 0644)) + + // Create a mock SVG at the expected relative path "docs/pipeleek-anim.svg" + docsDir := filepath.Join(tmpDir, "docs") + require.NoError(t, os.MkdirAll(docsDir, 0755)) + svgContent := `` + require.NoError(t, os.WriteFile(filepath.Join(docsDir, "pipeleek-anim.svg"), []byte(svgContent), 0644)) + + origDir, _ := os.Getwd() + require.NoError(t, os.Chdir(tmpDir)) + defer func() { _ = os.Chdir(origDir) }() + + err := inlineSVGIntoGettingStarted(tmpDir) + require.NoError(t, err) + + result, err := os.ReadFile(mdPath) + require.NoError(t, err) + assert.Contains(t, string(result), svgContent, "SVG content should be inlined") + assert.NotContains(t, string(result), placeholder, "placeholder should be replaced") +} diff --git a/pkg/format/format_test.go b/pkg/format/format_test.go index 0af92f36..0855c689 100644 --- a/pkg/format/format_test.go +++ b/pkg/format/format_test.go @@ -9,6 +9,7 @@ import ( ) func TestCalculateZipFileSize(t *testing.T) { + t.Parallel() tests := []struct { name string setup func() []byte @@ -69,6 +70,7 @@ func TestCalculateZipFileSize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() data := tt.setup() result := CalculateZipFileSize(data) if result != tt.expected { @@ -79,6 +81,7 @@ func TestCalculateZipFileSize(t *testing.T) { } func TestParseHumanSize(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -119,6 +122,7 @@ func TestParseHumanSize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result, err := ParseHumanSize(tt.input) if tt.expectError { assert.Error(t, err) diff --git a/pkg/format/path_test.go b/pkg/format/path_test.go index b9f1935b..5cb0d1f5 100644 --- a/pkg/format/path_test.go +++ b/pkg/format/path_test.go @@ -7,6 +7,7 @@ import ( ) func TestIsDirectory(t *testing.T) { + t.Parallel() tmpDir 
:= t.TempDir() testFile := filepath.Join(tmpDir, "testfile.txt") @@ -48,6 +49,7 @@ func TestIsDirectory(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := IsDirectory(tt.path) if result != tt.expected { t.Errorf("IsDirectory(%q) = %v, want %v", tt.path, result, tt.expected) diff --git a/pkg/format/string_test.go b/pkg/format/string_test.go index b1b11d9f..94561f1b 100644 --- a/pkg/format/string_test.go +++ b/pkg/format/string_test.go @@ -6,6 +6,8 @@ import ( ) func TestContainsI(t *testing.T) { + t.Parallel() + tests := []struct { name string a string @@ -64,6 +66,7 @@ func TestContainsI(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := ContainsI(tt.a, tt.b) if result != tt.expected { t.Errorf("ContainsI(%q, %q) = %v, want %v", tt.a, tt.b, result, tt.expected) @@ -73,6 +76,7 @@ func TestContainsI(t *testing.T) { } func TestGetPlatformAgnosticNewline(t *testing.T) { + t.Parallel() result := GetPlatformAgnosticNewline() if runtime.GOOS == "windows" { @@ -87,6 +91,7 @@ func TestGetPlatformAgnosticNewline(t *testing.T) { } func TestRandomStringN(t *testing.T) { + t.Parallel() tests := []struct { name string length int diff --git a/pkg/format/time_test.go b/pkg/format/time_test.go index 10582eb5..4c23aa2a 100644 --- a/pkg/format/time_test.go +++ b/pkg/format/time_test.go @@ -3,40 +3,57 @@ package format import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestParseISO8601(t *testing.T) { + t.Parallel() + tests := []struct { - name string - input string - shouldErr bool + name string + input string + expected time.Time }{ { - name: "valid RFC3339 format", - input: "2023-01-15T10:30:00Z", - shouldErr: false, + name: "valid RFC3339 UTC", + input: "2023-01-15T10:30:00Z", + expected: time.Date(2023, 1, 15, 10, 30, 0, 0, time.UTC), + }, + { + name: "valid RFC3339 with positive timezone offset", + input: "2023-01-15T10:30:00+01:00", + expected: 
func() time.Time { + loc := time.FixedZone("", 3600) + return time.Date(2023, 1, 15, 10, 30, 0, 0, loc) + }(), + }, + { + name: "valid RFC3339 with negative timezone offset", + input: "2023-01-15T10:30:00-05:00", + expected: func() time.Time { + loc := time.FixedZone("", -18000) + return time.Date(2023, 1, 15, 10, 30, 0, 0, loc) + }(), }, { - name: "valid RFC3339 with timezone", - input: "2023-01-15T10:30:00+01:00", - shouldErr: false, + name: "start of epoch", + input: "1970-01-01T00:00:00Z", + expected: time.Unix(0, 0).UTC(), }, { - name: "valid RFC3339 with milliseconds", - input: "2023-01-15T10:30:00.123Z", - shouldErr: false, + name: "end of year", + input: "2023-12-31T23:59:59Z", + expected: time.Date(2023, 12, 31, 23, 59, 59, 0, time.UTC), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if !tt.shouldErr { - result := ParseISO8601(tt.input) - expected, _ := time.Parse(time.RFC3339, tt.input) - if !result.Equal(expected) { - t.Errorf("ParseISO8601(%q) = %v, want %v", tt.input, result, expected) - } - } + t.Parallel() + result := ParseISO8601(tt.input) + assert.True(t, result.Equal(tt.expected), + "ParseISO8601(%q) = %v, want %v", tt.input, result, tt.expected) }) } } diff --git a/pkg/github/renovate/enum/enum.go b/pkg/github/renovate/enum/enum.go index bcb59f14..6e49ddd4 100644 --- a/pkg/github/renovate/enum/enum.go +++ b/pkg/github/renovate/enum/enum.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/CompassSecurity/pipeleek/pkg/format" + "github.com/CompassSecurity/pipeleek/pkg/httpclient" pkgrenovate "github.com/CompassSecurity/pipeleek/pkg/renovate" "github.com/google/go-github/v69/github" "github.com/rs/zerolog/log" @@ -39,7 +40,7 @@ func RunEnumerate(client *github.Client, opts EnumOptions) { ctx := context.Background() if opts.ExtendRenovateConfigService != "" { - err := pkgrenovate.ValidateRenovateConfigService(opts.ExtendRenovateConfigService) + err := 
pkgrenovate.ValidateRenovateConfigService(opts.ExtendRenovateConfigService, httpclient.GetPipeleekHTTPClient("", nil, nil)) if err != nil { log.Fatal().Stack().Err(err).Msg("Invalid extendRenovateConfigService URL") } @@ -315,7 +316,7 @@ func identifyRenovateBotWorkflow(ctx context.Context, client *github.Client, rep if opts.ExtendRenovateConfigService != "" { // Replace any occurrence of "local>" with "github>" this is best effort configFileContent = strings.ReplaceAll(configFileContent, "local>", "github>") - configFileContent = pkgrenovate.ExtendRenovateConfig(configFileContent, opts.ExtendRenovateConfigService, repo.GetHTMLURL()) + configFileContent = pkgrenovate.ExtendRenovateConfig(configFileContent, opts.ExtendRenovateConfigService, repo.GetHTMLURL(), httpclient.GetPipeleekHTTPClient("", nil, nil)) } } @@ -330,7 +331,7 @@ func identifyRenovateBotWorkflow(ctx context.Context, client *github.Client, rep selfHostedConfigFile := false if configFile != nil { - opts.SelfHostedOptions = pkgrenovate.FetchCurrentSelfHostedOptions(opts.SelfHostedOptions) + opts.SelfHostedOptions = pkgrenovate.FetchCurrentSelfHostedOptions(opts.SelfHostedOptions, httpclient.GetPipeleekHTTPClient("", nil, nil)) selfHostedConfigFile = pkgrenovate.IsSelfHostedConfig(configFileContent, opts.SelfHostedOptions) } diff --git a/pkg/gitlab/renovate/autodiscovery/autodiscovery_test.go b/pkg/gitlab/renovate/autodiscovery/autodiscovery_test.go index 9a74d2bb..c984dbbf 100644 --- a/pkg/gitlab/renovate/autodiscovery/autodiscovery_test.go +++ b/pkg/gitlab/renovate/autodiscovery/autodiscovery_test.go @@ -473,3 +473,22 @@ func TestContentQuality(t *testing.T) { } }) } + +func TestRunGenerate_WithUsername(t *testing.T) { + createdFiles := make(map[string]fileInfo) + memberAdded := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" && strings.Contains(r.URL.Path, "/members") { + memberAdded = true + 
w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":456,"access_level":30}`)) + return + } + createMockGitLabServer(createdFiles).Config.Handler.ServeHTTP(w, r) + })) + defer server.Close() + + RunGenerate(server.URL, "test-token", "test-repo", "renovate-bot", false) + + assert.True(t, memberAdded, "invite should be called when username is provided") +} diff --git a/pkg/gitlab/renovate/enum/enum.go b/pkg/gitlab/renovate/enum/enum.go index dea61a11..7b2e0a30 100644 --- a/pkg/gitlab/renovate/enum/enum.go +++ b/pkg/gitlab/renovate/enum/enum.go @@ -4,8 +4,6 @@ import ( b64 "encoding/base64" "encoding/json" "fmt" - "io" - "net/url" "os" "path/filepath" "regexp" @@ -54,7 +52,9 @@ func RunEnumerate(opts EnumOptions) { log.Info().Str("service", opts.ExtendRenovateConfigService).Msg("Using renovate config extension service") } - validateOrderBy(opts.OrderBy) + if err := validateOrderBy(opts.OrderBy); err != nil { + log.Fatal().Stack().Err(err).Msg("Invalid order-by value") + } if opts.Repository != "" { scanSingleProject(git, opts.Repository, opts) @@ -164,7 +164,7 @@ func identifyRenovateBotJob(git *gitlab.Client, project *gitlab.Project, opts En if configFile != nil { filename = configFile.FileName } - dumpConfigFileContents(project, ciCdYml, configFileContent, filename) + dumpConfigFileContents(project, ciCdYml, configFileContent, filename, "renovate-enum-out") } selfHostedConfigFile := false @@ -303,31 +303,7 @@ func detectRenovateConfigFile(git *gitlab.Client, project *gitlab.Project) (*git } func fetchCurrentSelfHostedOptions(opts EnumOptions) []string { - if len(opts.SelfHostedOptions) > 0 { - return opts.SelfHostedOptions - } - - log.Debug().Msg("Fetching current self-hosted configuration from GitHub") - - client := httpclient.GetPipeleekHTTPClient("", nil, nil) - res, err := client.Get("https://raw.githubusercontent.com/renovatebot/renovate/refs/heads/main/docs/usage/self-hosted-configuration.md") - if err != nil { - 
log.Fatal().Stack().Err(err).Msg("Failed fetching self-hosted configuration documentation") - return []string{} - } - defer func() { _ = res.Body.Close() }() - if res.StatusCode != 200 { - log.Fatal().Int("status", res.StatusCode).Msg("Failed fetching self-hosted configuration documentation") - return []string{} - } - data, err := io.ReadAll(res.Body) - if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed reading self-hosted configuration documentation") - return []string{} - } - - opts.SelfHostedOptions = extractSelfHostedOptions(data) - return opts.SelfHostedOptions + return renovateutil.FetchCurrentSelfHostedOptions(opts.SelfHostedOptions, httpclient.GetPipeleekHTTPClient("", nil, nil)) } func extractSelfHostedOptions(data []byte) []string { @@ -354,36 +330,15 @@ func isSelfHostedConfig(config string, opts EnumOptions) bool { } func extendRenovateConfig(renovateConfig string, project *gitlab.Project, opts EnumOptions) string { - return renovateutil.ExtendRenovateConfig(renovateConfig, opts.ExtendRenovateConfigService, project.WebURL) + return renovateutil.ExtendRenovateConfig(renovateConfig, opts.ExtendRenovateConfigService, project.WebURL, httpclient.GetPipeleekHTTPClient("", nil, nil)) } func validateRenovateConfigService(serviceUrl string) error { - client := httpclient.GetPipeleekHTTPClient("", nil, nil) - - u, err := url.Parse(serviceUrl) - if err != nil { - log.Error().Stack().Err(err).Msg("Failed to parse renovate config service URL") - return err - } - u = u.JoinPath("health") - - resp, err := client.Get(u.String()) - - if err != nil { - log.Error().Stack().Err(err).Msg("Renovate config service healthcheck failed") - return err - } - - if resp.StatusCode != 200 { - log.Error().Int("status", resp.StatusCode).Str("endpoint", u.String()).Msg("Renovate config service healthcheck failed") - return fmt.Errorf("renovate config service healthcheck failed: %d", resp.StatusCode) - } - - return nil + return 
renovateutil.ValidateRenovateConfigService(serviceUrl, httpclient.GetPipeleekHTTPClient("", nil, nil)) } -func dumpConfigFileContents(project *gitlab.Project, ciCdYml string, renovateConfigFile string, renovateConfigFileName string) { - projectDir := filepath.Join("renovate-enum-out", project.PathWithNamespace) +func dumpConfigFileContents(project *gitlab.Project, ciCdYml string, renovateConfigFile string, renovateConfigFileName string, outDir string) { + projectDir := filepath.Join(outDir, project.PathWithNamespace) if err := os.MkdirAll(projectDir, 0700); err != nil { log.Fatal().Err(err).Str("dir", projectDir).Msg("Failed to create project directory") } else { @@ -407,11 +362,12 @@ func dumpConfigFileContents(project *gitlab.Project, ciCdYml string, renovateCon } } -func validateOrderBy(orderBy string) { +func validateOrderBy(orderBy string) error { allowedOrderBy := map[string]struct{}{ "id": {}, "name": {}, "path": {}, "created_at": {}, "updated_at": {}, "star_count": {}, "last_activity_at": {}, "similarity": {}, } if _, ok := allowedOrderBy[orderBy]; !ok { - log.Fatal().Str("orderBy", orderBy).Msg("Invalid value for --order-by. Allowed: id, name, path, created_at, updated_at, star_count, last_activity_at, similarity") + return fmt.Errorf("invalid value for --order-by: %q. 
Allowed: id, name, path, created_at, updated_at, star_count, last_activity_at, similarity", orderBy) } + return nil } diff --git a/pkg/gitlab/renovate/enum/enum_test.go b/pkg/gitlab/renovate/enum/enum_test.go index 7e3961f4..11d7998c 100644 --- a/pkg/gitlab/renovate/enum/enum_test.go +++ b/pkg/gitlab/renovate/enum/enum_test.go @@ -1,9 +1,13 @@ package renovate import ( + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gitlab "gitlab.com/gitlab-org/api/client-go" ) func TestExtractSelfHostedOptions(t *testing.T) { @@ -114,9 +118,9 @@ Controls where Renovate installs binaries.`) func TestValidateOrderBy(t *testing.T) { tests := []struct { - name string - orderBy string - shouldFail bool + name string + orderBy string + expectError bool }{ {"accepts id", "id", false}, {"accepts name", "name", false}, @@ -126,14 +130,19 @@ func TestValidateOrderBy(t *testing.T) { {"accepts star_count", "star_count", false}, {"accepts last_activity_at", "last_activity_at", false}, {"accepts similarity", "similarity", false}, + {"rejects invalid value", "random", true}, + {"rejects empty string", "", true}, + {"rejects uppercase variant", "Name", true}, + {"rejects SQL injection attempt", "id; DROP TABLE", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if !tt.shouldFail { - assert.NotPanics(t, func() { - validateOrderBy(tt.orderBy) - }, "Valid orderBy values should not panic") + err := validateOrderBy(tt.orderBy) + if tt.expectError { + assert.Error(t, err, "expected error for orderBy=%q", tt.orderBy) + } else { + assert.NoError(t, err, "expected no error for orderBy=%q", tt.orderBy) } }) } @@ -147,9 +156,7 @@ func TestValidOrderByValues(t *testing.T) { for _, value := range validValues { t.Run("validates_"+value, func(t *testing.T) { - assert.NotPanics(t, func() { - validateOrderBy(value) - }, "orderBy=%s should be valid", value) + assert.NoError(t, validateOrderBy(value), "orderBy=%s 
should be valid", value) }) } } @@ -288,3 +295,68 @@ func TestIsSelfHostedConfig(t *testing.T) { }) } } + +func TestDumpConfigFileContents_CreatesFiles(t *testing.T) { + tmpDir := t.TempDir() + project := &gitlab.Project{PathWithNamespace: "myorg/myproject"} + + dumpConfigFileContents(project, "# CI/CD YAML content", `{"extends":["config:base"]}`, "renovate.json", tmpDir) + + // Verify CI/CD YAML was written + ciCdPath := filepath.Join(tmpDir, "myorg", "myproject", "gitlab-ci.yml") + data, err := os.ReadFile(ciCdPath) + require.NoError(t, err) + assert.Equal(t, "# CI/CD YAML content", string(data)) + + // Verify renovate config was written + configPath := filepath.Join(tmpDir, "myorg", "myproject", "renovate.json") + data, err = os.ReadFile(configPath) + require.NoError(t, err) + assert.Equal(t, `{"extends":["config:base"]}`, string(data)) +} + +func TestDumpConfigFileContents_DefaultsFilenameToRenovateJSON(t *testing.T) { + tmpDir := t.TempDir() + project := &gitlab.Project{PathWithNamespace: "org/repo"} + + // Empty filename should default to renovate.json + dumpConfigFileContents(project, "", `{"key":"val"}`, "", tmpDir) + + configPath := filepath.Join(tmpDir, "org", "repo", "renovate.json") + data, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Equal(t, `{"key":"val"}`, string(data)) +} + +func TestDumpConfigFileContents_SkipsEmptyContent(t *testing.T) { + tmpDir := t.TempDir() + project := &gitlab.Project{PathWithNamespace: "org/repo"} + + // Both cicd and config are empty: files should NOT be created + dumpConfigFileContents(project, "", "", "renovate.json", tmpDir) + + ciCdPath := filepath.Join(tmpDir, "org", "repo", "gitlab-ci.yml") + _, err := os.Stat(ciCdPath) + assert.True(t, os.IsNotExist(err), "gitlab-ci.yml should not be created when ciCdYml is empty") + + configPath := filepath.Join(tmpDir, "org", "repo", "renovate.json") + _, err = os.Stat(configPath) + assert.True(t, os.IsNotExist(err), "renovate.json should not be created when 
content is empty") +} + +func TestDumpConfigFileContents_OnlyCICD(t *testing.T) { + tmpDir := t.TempDir() + project := &gitlab.Project{PathWithNamespace: "org/repo"} + + dumpConfigFileContents(project, "ci: content", "", "", tmpDir) + + ciCdPath := filepath.Join(tmpDir, "org", "repo", "gitlab-ci.yml") + data, err := os.ReadFile(ciCdPath) + require.NoError(t, err) + assert.Equal(t, "ci: content", string(data)) + + // renovate.json should NOT be created + configPath := filepath.Join(tmpDir, "org", "repo", "renovate.json") + _, err = os.Stat(configPath) + assert.True(t, os.IsNotExist(err)) +} diff --git a/pkg/gitlab/runners/list/list_test.go b/pkg/gitlab/runners/list/list_test.go index 26eb2c86..d15c0b5f 100644 --- a/pkg/gitlab/runners/list/list_test.go +++ b/pkg/gitlab/runners/list/list_test.go @@ -1,6 +1,10 @@ package runners import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -352,3 +356,102 @@ func TestCountRunnersBySource(t *testing.T) { }) } } + +func TestListProjectRunners(t *testing.T) { + // Mock server for projects and runners + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch { + case strings.Contains(r.URL.Path, "/api/v4/projects") && !strings.Contains(r.URL.Path, "/runners"): + // Return a single project page + projects := []*gitlab.Project{{ID: 1, Name: "test-project", PathWithNamespace: "org/test-project"}} + _ = json.NewEncoder(w).Encode(projects) + + case strings.Contains(r.URL.Path, "/projects/") && strings.Contains(r.URL.Path, "/runners"): + // Return runners for project + runners := []*gitlab.Runner{{ID: 100, Name: "runner-1"}, {ID: 101, Name: "runner-2"}} + _ = json.NewEncoder(w).Encode(runners) + + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + require.NoError(t, 
err) + + result := listProjectRunners(client) + assert.Len(t, result, 2, "Should return 2 runners") + + _, ok := result[100] + assert.True(t, ok, "Runner 100 should be in the result") + _, ok = result[101] + assert.True(t, ok, "Runner 101 should be in the result") +} + +func TestListProjectRunners_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if strings.Contains(r.URL.Path, "/api/v4/projects") { + // Return empty project list + _ = json.NewEncoder(w).Encode([]*gitlab.Project{}) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + require.NoError(t, err) + + result := listProjectRunners(client) + assert.Empty(t, result, "Should return empty map for no projects") +} + +func TestListGroupRunners(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch { + case r.URL.Path == "/api/v4/groups": + groups := []*gitlab.Group{{ID: 10, Name: "test-group"}} + _ = json.NewEncoder(w).Encode(groups) + + case strings.Contains(r.URL.Path, "/api/v4/groups/") && strings.Contains(r.URL.Path, "/runners"): + runners := []*gitlab.Runner{{ID: 200, Name: "group-runner-1"}} + _ = json.NewEncoder(w).Encode(runners) + + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + require.NoError(t, err) + + result := listGroupRunners(client) + assert.Len(t, result, 1, "Should return 1 group runner") + _, ok := result[200] + assert.True(t, ok, "Runner 200 should be in the result") +} + +func TestListGroupRunners_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", 
"application/json") + if r.URL.Path == "/api/v4/groups" { + _ = json.NewEncoder(w).Encode([]*gitlab.Group{}) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + require.NoError(t, err) + + result := listGroupRunners(client) + assert.Empty(t, result, "Should return empty map for no groups") +} diff --git a/pkg/gitlab/scan/pipeline_test.go b/pkg/gitlab/scan/pipeline_test.go new file mode 100644 index 00000000..91c9c727 --- /dev/null +++ b/pkg/gitlab/scan/pipeline_test.go @@ -0,0 +1,60 @@ +package scan + +import ( + "net/http" + "net/http/httptest" + "testing" + + gitlab "gitlab.com/gitlab-org/api/client-go" +) + +func TestGetJobUrl(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + project := &gitlab.Project{PathWithNamespace: "myorg/myproject"} + job := &gitlab.Job{ID: 42} + + url := getJobUrl(client, project, job) + + // Should contain the host and job path + if url == "" { + t.Fatal("expected non-empty URL") + } + + expected := "myorg/myproject/-/jobs/42" + if len(url) < len(expected) { + t.Fatalf("expected URL to contain %q, got %q", expected, url) + } + + found := false + for i := 0; i <= len(url)-len(expected); i++ { + if url[i:i+len(expected)] == expected { + found = true + break + } + } + if !found { + t.Fatalf("expected URL to contain %q, got %q", expected, url) + } +} + +func TestGetQueueStatus_NilQueue(t *testing.T) { + // Save original queue state + original := globQueue + defer func() { globQueue = original }() + + // When queue is nil, should return 0 + globQueue = nil + status := GetQueueStatus() + if status != 0 { + t.Fatalf("expected 0 when queue is nil, got %d", status) + } +} diff --git 
a/pkg/gitlab/scan/queue_test.go b/pkg/gitlab/scan/queue_test.go index 8a8017e8..72f136f0 100644 --- a/pkg/gitlab/scan/queue_test.go +++ b/pkg/gitlab/scan/queue_test.go @@ -2,7 +2,11 @@ package scan import ( "bytes" + "compress/gzip" "encoding/json" + "net/http" + "net/http/httptest" + "os" "strings" "sync" "testing" @@ -11,6 +15,8 @@ import ( "github.com/nsqio/go-diskqueue" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" gitlab "gitlab.com/gitlab-org/api/client-go" ) @@ -105,3 +111,157 @@ func TestEnqueueItem_Marshaling(t *testing.T) { } wg.Wait() } + +func TestDownloadEnvArtifact_404Response(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + result := DownloadEnvArtifact("session-cookie", srv.URL, "owner/repo", 42) + if len(result) != 0 { + t.Fatalf("expected empty result on 404, got %d bytes", len(result)) + } +} + +func TestDownloadEnvArtifact_PlainTextResponse(t *testing.T) { + envContent := []byte("MY_VAR=secret_value\nOTHER=other_value\n") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the expected query parameter + if r.URL.Query().Get("file_type") != "dotenv" { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(envContent) + })) + defer srv.Close() + + result := DownloadEnvArtifact("session-cookie", srv.URL, "owner/repo", 42) + // Plain text content is detected as unknown file type by filetype.Match, + // which triggers the "unexpected" error branch, returning empty bytes. 
+ assert.Empty(t, result, "plain text response should return empty bytes due to unknown filetype") +} + +func TestDownloadEnvArtifact_GzipResponse(t *testing.T) { + envContent := "MY_VAR=secret_value\nOTHER=other_value\n" + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, err := gz.Write([]byte(envContent)) + if err != nil { + t.Fatalf("failed to create gzip: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("failed to close gzip: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(buf.Bytes()) + })) + defer srv.Close() + + result := DownloadEnvArtifact("session-cookie", srv.URL, "owner/repo", 42) + if string(result) != envContent { + t.Fatalf("expected decompressed content %q, got %q", envContent, string(result)) + } +} + +func TestDownloadEnvArtifact_URLBuildFailure(t *testing.T) { + // A URL that will cause join to fail or produce an unreachable host + result := DownloadEnvArtifact("cookie", "://invalid-url", "owner/repo", 1) + if len(result) != 0 { + t.Fatalf("expected empty result for bad URL, got %d bytes", len(result)) + } +} + +func TestSetupQueue_DefaultTempDir(t *testing.T) { + opts := &ScanOptions{QueueFolder: ""} + q, queueFile := setupQueue(opts) + defer func() { _ = q.Close(); _ = os.Remove(queueFile) }() + + assert.NotNil(t, q, "queue should not be nil") + assert.NotEmpty(t, queueFile, "queueFile path should not be empty") + + // Queue should be writable (depth 0 initially) + assert.Equal(t, int64(0), q.Depth()) +} + +func TestSetupQueue_CustomRelativeDir(t *testing.T) { + tmpDir := t.TempDir() + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("failed to get cwd: %v", err) + } + require.NoError(t, os.Chdir(tmpDir)) + defer func() { _ = os.Chdir(origDir) }() + + opts := &ScanOptions{QueueFolder: "custom-queue"} + q, queueFile := setupQueue(opts) + defer func() { _ = 
q.Close(); _ = os.Remove(queueFile) }() + + assert.NotNil(t, q) + assert.NotEmpty(t, queueFile) +} + +func TestSetupQueue_AbsoluteDir(t *testing.T) { + tmpDir := t.TempDir() + opts := &ScanOptions{QueueFolder: tmpDir} + q, queueFile := setupQueue(opts) + defer func() { _ = q.Close(); _ = os.Remove(queueFile) }() + + assert.NotNil(t, q) + assert.NotEmpty(t, queueFile) +} + +func TestGetQueueStatus_NilGlobQueue(t *testing.T) { + origQueue := globQueue + defer func() { globQueue = origQueue }() + + globQueue = nil + assert.Equal(t, 0, GetQueueStatus(), "nil globQueue should return 0") +} + +func TestGetQueueStatus_WithQueue(t *testing.T) { + origQueue := globQueue + defer func() { globQueue = origQueue }() + + tmpDir := t.TempDir() + opts := &ScanOptions{QueueFolder: tmpDir} + q, queueFile := setupQueue(opts) + defer func() { _ = q.Close(); _ = os.Remove(queueFile) }() + + globQueue = q + // Queue is empty, depth is 0 + assert.Equal(t, 0, GetQueueStatus()) +} + +func TestAnalyzeQueueItem_UnknownType(t *testing.T) { + item := QueueItem{ + Type: QueueItemType("unknown"), + Meta: QueueMeta{ProjectId: 1, JobId: 2}, + } + itemBytes, _ := json.Marshal(item) + + var wg sync.WaitGroup + wg.Add(1) + + // Should not panic and should call wg.Done() via defer + analyzeQueueItem(itemBytes, nil, &ScanOptions{}, &wg) + // wg.Wait() will complete because analyzeQueueItem calls wg.Done() via defer + wg.Wait() +} + +func TestAnalyzeQueueItem_InvalidJSON(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + + withCapturedLogs(t, zerolog.ErrorLevel, func(buf *bytes.Buffer) { + analyzeQueueItem([]byte("not-valid-json"), nil, &ScanOptions{}, &wg) + wg.Wait() + assert.Contains(t, buf.String(), "Failed unmarshalling queue item") + }) +} diff --git a/pkg/gitlab/scan/scanner_test.go b/pkg/gitlab/scan/scanner_test.go new file mode 100644 index 00000000..5aec6172 --- /dev/null +++ b/pkg/gitlab/scan/scanner_test.go @@ -0,0 +1,119 @@ +package scan + +import ( + "testing" + "time" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInitializeOptions_Valid(t *testing.T) { + opts, err := InitializeOptions( + "https://gitlab.example.com", + "glpat-token", + "", + "search-query", + "", + "", + "/tmp/queue", + "500MB", + true, false, false, true, + 100, 4, + []string{"high"}, + 30*time.Second, + ) + require.NoError(t, err) + assert.Equal(t, "https://gitlab.example.com", opts.GitlabUrl) + assert.Equal(t, "glpat-token", opts.GitlabApiToken) + assert.Equal(t, "search-query", opts.ProjectSearchQuery) + assert.True(t, opts.Artifacts) + assert.False(t, opts.Owned) + assert.Equal(t, 100, opts.JobLimit) + assert.Equal(t, 4, opts.MaxScanGoRoutines) + assert.Equal(t, []string{"high"}, opts.ConfidenceFilter) + assert.Equal(t, 30*time.Second, opts.HitTimeout) + assert.Equal(t, "/tmp/queue", opts.QueueFolder) + assert.True(t, opts.TruffleHogVerification) +} + +func TestInitializeOptions_InvalidURL(t *testing.T) { + _, err := InitializeOptions( + "not-a-valid-url", + "token", "", "", "", "", "/tmp/q", "100MB", + false, false, false, false, 0, 1, nil, 5*time.Second, + ) + assert.Error(t, err) +} + +func TestInitializeOptions_InvalidArtifactSize(t *testing.T) { + _, err := InitializeOptions( + "https://gitlab.example.com", + "token", "", "", "", "", "/tmp/q", "notasize", + false, false, false, false, 0, 1, nil, 5*time.Second, + ) + assert.Error(t, err) +} + +func TestInitializeOptions_ArtifactSizeVariants(t *testing.T) { + tests := []struct { + name string + sizeStr string + wantErr bool + }{ + {"MB size", "100MB", false}, + {"GB size", "1GB", false}, + {"KB size", "500KB", false}, + {"zero bytes", "0", false}, + {"invalid text", "lots", true}, + {"empty string", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := InitializeOptions( + "https://gitlab.example.com", + "token", "", "", "", "", "/tmp/q", tt.sizeStr, + false, false, false, false, 0, 1, nil, 5*time.Second, + ) 
+ if tt.wantErr { + assert.Error(t, err, "expected error for size=%q", tt.sizeStr) + } else { + assert.NoError(t, err, "expected no error for size=%q", tt.sizeStr) + } + }) + } +} + +func TestInitializeOptions_MembersAndOwnedFlags(t *testing.T) { + opts, err := InitializeOptions( + "https://gitlab.example.com", + "token", "", "", "", "", "/tmp/q", "10MB", + false, true, true, false, 0, 1, nil, 5*time.Second, + ) + require.NoError(t, err) + assert.True(t, opts.Owned) + assert.True(t, opts.Member) +} + +func TestInitializeOptions_RepositoryAndNamespace(t *testing.T) { + opts, err := InitializeOptions( + "https://gitlab.example.com", + "token", "", "", "org/repo", "mygroup", "/tmp/q", "10MB", + false, false, false, false, 0, 1, nil, 5*time.Second, + ) + require.NoError(t, err) + assert.Equal(t, "org/repo", opts.Repository) + assert.Equal(t, "mygroup", opts.Namespace) +} + +func TestNewScanner_ReturnsScanner(t *testing.T) { + opts := &ScanOptions{ + GitlabUrl: "https://gitlab.example.com", + GitlabApiToken: "token", + } + s := NewScanner(opts) + assert.NotNil(t, s) + // Before any queue is set, GetQueueStatus must return 0 + assert.Equal(t, 0, s.GetQueueStatus()) +} diff --git a/pkg/gitlab/util/util.go b/pkg/gitlab/util/util.go index 58acb038..65b7d48b 100644 --- a/pkg/gitlab/util/util.go +++ b/pkg/gitlab/util/util.go @@ -11,6 +11,7 @@ import ( "github.com/CompassSecurity/pipeleek/pkg/httpclient" "github.com/PuerkitoBio/goquery" + "github.com/hashicorp/go-retryablehttp" "github.com/headzoo/surf" "github.com/rs/zerolog/log" gitlab "gitlab.com/gitlab-org/api/client-go" @@ -105,40 +106,45 @@ func DetermineVersion(gitlabUrl string, apiToken string) *gitlab.Metadata { return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} } return metadata - } else { - u, err := url.Parse(gitlabUrl) - if err != nil { - log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") - return &gitlab.Metadata{Version: "none", 
Revision: "none", Enterprise: false} - } - u.Path = path.Join(u.Path, "/help") - - client := httpclient.GetPipeleekHTTPClient("", nil, nil) - response, err := client.Get(u.String()) + } - if err != nil { - log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") - return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} - } + return fetchVersionFromHTML(gitlabUrl, httpclient.GetPipeleekHTTPClient("", nil, nil)) +} - responseData, err := io.ReadAll(response.Body) - if err != nil { - log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") - return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} - } +// fetchVersionFromHTML fetches the GitLab version by scraping the /help page HTML. +// Accepts a retryable HTTP client to allow injection for testing. +func fetchVersionFromHTML(gitlabUrl string, client *retryablehttp.Client) *gitlab.Metadata { + u, err := url.Parse(gitlabUrl) + if err != nil { + log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") + return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} + } + u.Path = path.Join(u.Path, "/help") - extractLineR := regexp.MustCompile(`instance_version":"\d*.\d*.\d*"`) - fullLine := extractLineR.Find(responseData) - versionR := regexp.MustCompile(`\d+.\d+.\d+`) - versionNumber := versionR.Find(fullLine) + response, err := client.Get(u.String()) - if len(versionNumber) > 3 { - return &gitlab.Metadata{Version: string(versionNumber), Revision: "none", Enterprise: false} - } + if err != nil { + log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") + return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} + } - log.Error().Msg("Failed determining GitLab version via Website") + responseData, err := io.ReadAll(response.Body) + if err != nil { + log.Error().Stack().Err(err).Msg("Failed determining GitLab version via Website") return 
&gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} } + + extractLineR := regexp.MustCompile(`instance_version":"\d*.\d*.\d*"`) + fullLine := extractLineR.Find(responseData) + versionR := regexp.MustCompile(`\d+.\d+.\d+`) + versionNumber := versionR.Find(fullLine) + + if len(versionNumber) > 3 { + return &gitlab.Metadata{Version: string(versionNumber), Revision: "none", Enterprise: false} + } + + log.Error().Msg("Failed determining GitLab version via Website") + return &gitlab.Metadata{Version: "none", Revision: "none", Enterprise: false} } func RegisterNewAccount(targetUrl string, username string, password string, email string) { diff --git a/pkg/gitlab/util/util_test.go b/pkg/gitlab/util/util_test.go index c2fd2c3d..f0f3fa28 100644 --- a/pkg/gitlab/util/util_test.go +++ b/pkg/gitlab/util/util_test.go @@ -2,11 +2,13 @@ package util import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" "testing" + "github.com/CompassSecurity/pipeleek/pkg/httpclient" gitlab "gitlab.com/gitlab-org/api/client-go" ) @@ -143,3 +145,201 @@ func TestFetchCICDYml_OtherError(t *testing.T) { t.Fatalf("expected error to contain syntax error, got %q", err.Error()) } } + +// TestIterateProjects ensures pagination calls the callback for each project. +func TestIterateProjects(t *testing.T) { + // Build two pages of projects; page 2 has NextPage=0 to terminate pagination. 
+ page1 := []*gitlab.Project{ + {ID: 1, Name: "project-one"}, + {ID: 2, Name: "project-two"}, + } + page2 := []*gitlab.Project{ + {ID: 3, Name: "project-three"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "/api/v4/projects") { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + pageParam := r.URL.Query().Get("page") + if pageParam == "2" { + // Last page – no X-Next-Page header + _ = json.NewEncoder(w).Encode(page2) + } else { + w.Header().Set("X-Next-Page", "2") + _ = json.NewEncoder(w).Encode(page1) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + var seen []int64 + opts := &gitlab.ListProjectsOptions{} + err = IterateProjects(client, opts, func(p *gitlab.Project) error { + seen = append(seen, int64(p.ID)) + return nil + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(seen) != 3 { + t.Fatalf("expected 3 projects, got %d", len(seen)) + } +} + +// TestIterateProjects_CallbackError ensures iteration stops on callback error. 
+func TestIterateProjects_CallbackError(t *testing.T) { + projects := []*gitlab.Project{{ID: 1, Name: "p1"}, {ID: 2, Name: "p2"}} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(projects) + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + callCount := 0 + opts := &gitlab.ListProjectsOptions{} + err = IterateProjects(client, opts, func(p *gitlab.Project) error { + callCount++ + return fmt.Errorf("stop iteration") + }) + + if err == nil { + t.Fatal("expected callback error to propagate") + } + if callCount != 1 { + t.Fatalf("expected callback called once before error, got %d", callCount) + } +} + +// TestIterateGroupProjects ensures pagination calls the callback for each group project. +func TestIterateGroupProjects(t *testing.T) { + page1 := []*gitlab.Project{ + {ID: 10, Name: "group-project-one"}, + {ID: 11, Name: "group-project-two"}, + } + page2 := []*gitlab.Project{ + {ID: 12, Name: "group-project-three"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "/api/v4/groups/") { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + pageParam := r.URL.Query().Get("page") + if pageParam == "2" { + _ = json.NewEncoder(w).Encode(page2) + } else { + w.Header().Set("X-Next-Page", "2") + _ = json.NewEncoder(w).Encode(page1) + } + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + var seen []int64 + opts := &gitlab.ListGroupProjectsOptions{} + err = IterateGroupProjects(client, "my-group", opts, func(p *gitlab.Project) error { + seen = append(seen, int64(p.ID)) + return nil + }) + + 
if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(seen) != 3 { + t.Fatalf("expected 3 projects, got %d", len(seen)) + } +} + +// TestIterateGroupProjects_CallbackError ensures iteration stops on callback error. +func TestIterateGroupProjects_CallbackError(t *testing.T) { + projects := []*gitlab.Project{{ID: 10, Name: "gp1"}} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(projects) + })) + defer srv.Close() + + client, err := gitlab.NewClient("token", gitlab.WithBaseURL(srv.URL)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + opts := &gitlab.ListGroupProjectsOptions{} + err = IterateGroupProjects(client, "my-group", opts, func(p *gitlab.Project) error { + return fmt.Errorf("stop iteration") + }) + + if err == nil { + t.Fatal("expected callback error to propagate") + } +} + +// TestFetchVersionFromHTML verifies that fetchVersionFromHTML correctly parses the version +// from a mock /help HTML page response. +func TestFetchVersionFromHTML_ParsesVersion(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`<html><script>window.gon = {"instance_version":"17.2.1"};</script></html>`)) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + meta := fetchVersionFromHTML(srv.URL, client) + if meta.Version != "17.2.1" { + t.Fatalf("expected version 17.2.1, got %s", meta.Version) + } +} + +// TestFetchVersionFromHTML_NoVersion verifies fallback when version is not found. 
+func TestFetchVersionFromHTML_NoVersion(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`No version here`)) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + meta := fetchVersionFromHTML(srv.URL, client) + if meta.Version != "none" { + t.Fatalf("expected 'none' version, got %s", meta.Version) + } +} + +// TestFetchVersionFromHTML_BadURL verifies fallback when URL cannot be parsed. +func TestFetchVersionFromHTML_BadURL(t *testing.T) { + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + meta := fetchVersionFromHTML("://bad-url", client) + if meta.Version != "none" { + t.Fatalf("expected 'none' version, got %s", meta.Version) + } +} + +// TestFetchVersionFromHTML_Unreachable verifies fallback when HTTP request fails. +func TestFetchVersionFromHTML_Unreachable(t *testing.T) { + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + meta := fetchVersionFromHTML("http://127.0.0.1:0", client) + if meta.Version != "none" { + t.Fatalf("expected 'none' version, got %s", meta.Version) + } +} diff --git a/pkg/logging/hit_test.go b/pkg/logging/hit_test.go index 4141d54a..b65efda7 100644 --- a/pkg/logging/hit_test.go +++ b/pkg/logging/hit_test.go @@ -3,10 +3,12 @@ package logging import ( "bytes" "encoding/json" + "sync" "testing" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" ) func TestHit(t *testing.T) { @@ -283,3 +285,133 @@ func TestHitLevelWriter_ConcurrentAccess(t *testing.T) { // No panic = mutex protected correctly } + +func TestHitEvent_Bool(t *testing.T) { + var buf bytes.Buffer + hitWriter := NewHitLevelWriter(&buf) + logger := zerolog.New(hitWriter).With().Logger() + log.Logger = logger + globalHitWriter = hitWriter + + Hit().Bool("isSecret", true).Msg("Test bool field") + + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), 
&logEntry) + if err != nil { + t.Fatalf("Failed to parse log output: %v", err) + } + + if logEntry["level"] != "hit" { + t.Errorf("Expected level 'hit', got '%v'", logEntry["level"]) + } + + if val, ok := logEntry["isSecret"].(bool); !ok || !val { + t.Errorf("Expected isSecret=true, got '%v'", logEntry["isSecret"]) + } +} + +func TestHitEvent_Err(t *testing.T) { + var buf bytes.Buffer + hitWriter := NewHitLevelWriter(&buf) + logger := zerolog.New(hitWriter).With().Logger() + log.Logger = logger + globalHitWriter = hitWriter + + Hit().Err(nil).Msg("Test nil error") + + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + if err != nil { + t.Fatalf("Failed to parse log output: %v", err) + } + + if logEntry["level"] != "hit" { + t.Errorf("Expected level 'hit', got '%v'", logEntry["level"]) + } +} + +func TestHitLevelWriter_SetOutput(t *testing.T) { + buf1 := &bytes.Buffer{} + buf2 := &bytes.Buffer{} + + writer := NewHitLevelWriter(buf1) + writer.SetOutput(buf2) + + _, err := writer.Write([]byte("test output\n")) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + if buf2.String() != "test output\n" { + t.Errorf("Expected output to go to buf2, got: %q", buf2.String()) + } + if buf1.Len() != 0 { + t.Error("Expected buf1 to be empty after SetOutput") + } +} + +func TestSetGlobalHitWriter(t *testing.T) { + // Save original + original := globalHitWriter + defer func() { globalHitWriter = original }() + + buf := &bytes.Buffer{} + writer := NewHitLevelWriter(buf) + SetGlobalHitWriter(writer) + + if globalHitWriter != writer { + t.Error("Expected globalHitWriter to be the new writer") + } +} + +// TestHit_NilGlobalWriter verifies that Hit() initializes globalHitWriter via +// setupGlobalHitWriter when it is nil, and still returns a valid HitEvent. +func TestHit_NilGlobalWriter(t *testing.T) { + // Save original state for cleanup. 
globalHitWriterOnce is not saved/restored + // because sync.Once contains sync.noCopy and cannot be copied. Restoring + // globalHitWriter is sufficient since Hit() only calls setupGlobalHitWriter + // when globalHitWriter == nil. + origWriter := globalHitWriter + origLogger := log.Logger + defer func() { + globalHitWriter = origWriter + log.Logger = origLogger + }() + + // Force the initialization path by resetting both the writer and the sync.Once. + globalHitWriter = nil + globalHitWriterOnce = sync.Once{} + + event := Hit() + assert.NotNil(t, event, "Hit() should return a non-nil event even when globalHitWriter was nil") + assert.NotNil(t, globalHitWriter, "globalHitWriter should be initialized after Hit() call") +} + +// TestSetupGlobalHitWriter_IdempotentOnce verifies that setupGlobalHitWriter only +// initializes globalHitWriter once even when called multiple times concurrently. +func TestSetupGlobalHitWriter_IdempotentOnce(t *testing.T) { + // See TestHit_NilGlobalWriter for why globalHitWriterOnce is not saved/restored. + origWriter := globalHitWriter + origLogger := log.Logger + defer func() { + globalHitWriter = origWriter + log.Logger = origLogger + }() + + globalHitWriter = nil + globalHitWriterOnce = sync.Once{} + + // Call setupGlobalHitWriter from multiple goroutines concurrently. 
+ done := make(chan struct{}, 5) + for i := 0; i < 5; i++ { + go func() { + setupGlobalHitWriter() + done <- struct{}{} + }() + } + for i := 0; i < 5; i++ { + <-done + } + + assert.NotNil(t, globalHitWriter, "globalHitWriter should be initialized after setupGlobalHitWriter") +} diff --git a/pkg/renovate/common.go b/pkg/renovate/common.go index e95d3a3b..dd8a5cc2 100644 --- a/pkg/renovate/common.go +++ b/pkg/renovate/common.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/CompassSecurity/pipeleek/pkg/format" - "github.com/CompassSecurity/pipeleek/pkg/httpclient" + "github.com/hashicorp/go-retryablehttp" "github.com/rs/zerolog/log" "github.com/yosuke-furukawa/json5/encoding/json5" ) @@ -114,14 +114,14 @@ func DetectAutodiscoveryFilters(cicdConf, configFileContent string) (bool, strin } // FetchCurrentSelfHostedOptions retrieves the list of self-hosted Renovate configuration options. -func FetchCurrentSelfHostedOptions(cachedOptions []string) []string { +// Accepts a retryable HTTP client to allow injection for testing. +func FetchCurrentSelfHostedOptions(cachedOptions []string, client *retryablehttp.Client) []string { if len(cachedOptions) > 0 { return cachedOptions } log.Debug().Msg("Fetching current self-hosted configuration from GitHub") - client := httpclient.GetPipeleekHTTPClient("", nil, nil) res, err := client.Get("https://raw.githubusercontent.com/renovatebot/renovate/refs/heads/main/docs/usage/self-hosted-configuration.md") if err != nil { log.Error().Stack().Err(err).Msg("Failed fetching self-hosted configuration documentation") @@ -166,9 +166,8 @@ func IsSelfHostedConfig(config string, selfHostedOptions []string) bool { // ExtendRenovateConfig extends a Renovate configuration by sending it to a resolver service. // The config is normalized to valid JSON before sending (removes JSON5 comments/trailing commas). 
-func ExtendRenovateConfig(renovateConfig string, serviceURL string, projectURL string) string { - client := httpclient.GetPipeleekHTTPClient("", nil, nil) - +// Accepts a retryable HTTP client to allow injection for testing. +func ExtendRenovateConfig(renovateConfig string, serviceURL string, projectURL string, client *retryablehttp.Client) string { u, err := url.Parse(serviceURL) if err != nil { log.Error().Stack().Err(err).Str("project", projectURL).Msg("Failed to parse renovate config service URL") @@ -229,9 +228,8 @@ func normalizeRenovateConfig(config string) string { } // ValidateRenovateConfigService checks if the Renovate config resolver service is available. -func ValidateRenovateConfigService(serviceUrl string) error { - client := httpclient.GetPipeleekHTTPClient("", nil, nil) - +// Accepts a retryable HTTP client to allow injection for testing. +func ValidateRenovateConfigService(serviceUrl string, client *retryablehttp.Client) error { u, err := url.Parse(serviceUrl) if err != nil { log.Error().Stack().Err(err).Msg("Failed to parse renovate config service URL") diff --git a/pkg/renovate/common_additional_test.go b/pkg/renovate/common_additional_test.go new file mode 100644 index 00000000..31750f5a --- /dev/null +++ b/pkg/renovate/common_additional_test.go @@ -0,0 +1,179 @@ +package renovate + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/CompassSecurity/pipeleek/pkg/httpclient" + "github.com/stretchr/testify/assert" +) + +func TestNormalizeRenovateConfig_ValidJSON(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "plain JSON", + input: `{"key": "value", "num": 42}`, + expected: `{"key":"value","num":42}`, + }, + { + name: "JSON with whitespace", + input: "{\n \"extends\": [\n \"config:base\"\n ]\n}", + expected: `{"extends":["config:base"]}`, + }, + { + name: "empty object", + input: `{}`, + expected: `{}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + result := normalizeRenovateConfig(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestNormalizeRenovateConfig_JSON5(t *testing.T) { + // JSON5 with trailing commas should be normalized + input := `{ + "extends": ["config:base"], + "prConcurrentLimit": 0, + }` + result := normalizeRenovateConfig(input) + // Should be valid JSON (no trailing commas) + assert.Contains(t, result, `"extends"`) + assert.Contains(t, result, `"prConcurrentLimit"`) +} + +func TestNormalizeRenovateConfig_InvalidInput(t *testing.T) { + // Completely invalid input should be returned unchanged + invalid := `this is not json at all !!!` + result := normalizeRenovateConfig(invalid) + assert.Equal(t, invalid, result) +} + +func TestTryParseJSON_StringValue(t *testing.T) { + val, ok := tryParseJSON(`{"key": "value"}`, "key") + assert.True(t, ok) + assert.Equal(t, "value", val) +} + +func TestTryParseJSON_ArrayValue(t *testing.T) { + val, ok := tryParseJSON(`{"items": ["a","b","c"]}`, "items") + assert.True(t, ok) + assert.Equal(t, `["a","b","c"]`, val) +} + +func TestTryParseJSON_ObjectValue(t *testing.T) { + val, ok := tryParseJSON(`{"nested": {"x": 1}}`, "nested") + assert.True(t, ok) + assert.Contains(t, val, `"x"`) +} + +func TestTryParseJSON_NumberValue(t *testing.T) { + val, ok := tryParseJSON(`{"count": 42}`, "count") + assert.True(t, ok) + assert.Equal(t, "42", val) +} + +func TestTryParseJSON_MissingKey(t *testing.T) { + _, ok := tryParseJSON(`{"other": "value"}`, "missing") + assert.False(t, ok) +} + +func TestTryParseJSON_InvalidJSON(t *testing.T) { + _, ok := tryParseJSON(`not json`, "key") + assert.False(t, ok) +} + +func TestFetchCurrentSelfHostedOptions_Cached(t *testing.T) { + // When cache is non-empty, it should be returned directly without HTTP call + cached := []string{"option1", "option2", "option3"} + // Use a real client - it should never be invoked since the cache returns early + result := FetchCurrentSelfHostedOptions(cached, 
httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Equal(t, cached, result) +} + +func TestExtendRenovateConfig_ServiceUnavailable(t *testing.T) { + // Use a mock server that returns 404 (not 5xx) to avoid triggering retries + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + originalConfig := `{"extends": ["config:base"]}` + result := ExtendRenovateConfig(originalConfig, srv.URL, "owner/repo", httpclient.GetPipeleekHTTPClient("", nil, nil)) + // When service returns non-200, original config should be returned + assert.Equal(t, originalConfig, result) +} + +func TestExtendRenovateConfig_ServiceReturnsExtended(t *testing.T) { + extendedConfig := `{"extends": ["config:base"], "extra": "added"}` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/resolve", r.URL.Path) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(extendedConfig)) + })) + defer srv.Close() + + originalConfig := `{"extends": ["config:base"]}` + result := ExtendRenovateConfig(originalConfig, srv.URL, "owner/repo", httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Equal(t, extendedConfig, result) +} + +func TestExtendRenovateConfig_ServiceReturnsError(t *testing.T) { + // Use 400 (client error, non-retryable) to test non-200 fallback + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "bad request"}`)) + })) + defer srv.Close() + + originalConfig := `{"extends": ["config:base"]}` + result := ExtendRenovateConfig(originalConfig, srv.URL, "owner/repo", httpclient.GetPipeleekHTTPClient("", nil, nil)) + // On non-200, original config should be returned + assert.Equal(t, originalConfig, result) +} + +func 
TestExtendRenovateConfig_InvalidServiceURL(t *testing.T) { + originalConfig := `{"extends": ["config:base"]}` + result := ExtendRenovateConfig(originalConfig, "://invalid-url", "owner/repo", httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Equal(t, originalConfig, result) +} + +func TestValidateRenovateConfigService_Healthy(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/health", r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status": "ok"}`)) + })) + defer srv.Close() + + err := ValidateRenovateConfigService(srv.URL, httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.NoError(t, err) +} + +func TestValidateRenovateConfigService_Unhealthy(t *testing.T) { + // Use 404 (not 5xx) to avoid triggering the retry mechanism + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + err := ValidateRenovateConfigService(srv.URL, httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Error(t, err) +} + +func TestValidateRenovateConfigService_InvalidURL(t *testing.T) { + err := ValidateRenovateConfigService("://bad-url", httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Error(t, err) +} diff --git a/pkg/renovate/common_test.go b/pkg/renovate/common_test.go index 5cf8c515..c41e872b 100644 --- a/pkg/renovate/common_test.go +++ b/pkg/renovate/common_test.go @@ -1,8 +1,11 @@ package renovate import ( + "net/http" + "net/http/httptest" "testing" + "github.com/CompassSecurity/pipeleek/pkg/httpclient" "github.com/stretchr/testify/assert" ) @@ -396,3 +399,131 @@ func TestRenovateConfigFiles(t *testing.T) { assert.Equal(t, files1, files2) }) } + +// TestFetchCurrentSelfHostedOptions_ReturnsCache verifies that non-empty cached options +// are returned immediately without making an HTTP request. 
+func TestFetchCurrentSelfHostedOptions_ReturnsCache(t *testing.T) { + cached := []string{"platform", "endpoint"} + // Use a real client - it will never be invoked since the cache returns early + result := FetchCurrentSelfHostedOptions(cached, httpclient.GetPipeleekHTTPClient("", nil, nil)) + assert.Equal(t, cached, result) +} + +// TestFetchCurrentSelfHostedOptions_ParsesResponse verifies that options are extracted +// from a mock HTTP server response. +func TestFetchCurrentSelfHostedOptions_ParsesResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("## platform\n## endpoint\n## binarySource\n")) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + client.HTTPClient.Transport = &redirectTransport{targetURL: srv.URL} + + result := FetchCurrentSelfHostedOptions([]string{}, client) + assert.NotEmpty(t, result) +} + +// TestFetchCurrentSelfHostedOptions_Non200 verifies that an empty list is returned +// when the server responds with a non-200 status. +func TestFetchCurrentSelfHostedOptions_Non200(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Use 404: the retryablehttp client does not retry 4xx client errors by default + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + client.HTTPClient.Transport = &redirectTransport{targetURL: srv.URL} + + result := FetchCurrentSelfHostedOptions([]string{}, client) + assert.Empty(t, result) +} + +// TestExtendRenovateConfig_Success verifies that a successful response replaces the config. 
+func TestExtendRenovateConfig_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/resolve", r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"extends":["config:base","security:openssf-scorecard"]}`)) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + result := ExtendRenovateConfig(`{"extends":["config:base"]}`, srv.URL, "https://gitlab.example.com/org/repo", client) + assert.Equal(t, `{"extends":["config:base","security:openssf-scorecard"]}`, result) +} + +// TestExtendRenovateConfig_BadURL verifies that the original config is returned on URL parse error. +func TestExtendRenovateConfig_BadURL(t *testing.T) { + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + orig := `{"extends":["config:base"]}` + result := ExtendRenovateConfig(orig, "://bad-url", "https://project.example.com", client) + assert.Equal(t, orig, result) +} + +// TestExtendRenovateConfig_RequestError verifies that the original config is returned on error. +func TestExtendRenovateConfig_RequestError(t *testing.T) { + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + client.RetryMax = 0 + orig := `{"extends":["config:base"]}` + result := ExtendRenovateConfig(orig, "http://127.0.0.1:0", "https://project.example.com", client) + assert.Equal(t, orig, result) +} + +// TestValidateRenovateConfigService_Success verifies that a healthy service returns nil error. +func TestValidateRenovateConfigService_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/health", r.URL.Path) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + err := ValidateRenovateConfigService(srv.URL, client) + assert.NoError(t, err) +} + +// TestValidateRenovateConfigService_Non200 verifies that a non-200 response returns an error. 
+func TestValidateRenovateConfigService_Non200(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Use 404: the retryablehttp client does not retry 4xx client errors by default
+		w.WriteHeader(http.StatusNotFound)
+	}))
+	defer srv.Close()
+
+	client := httpclient.GetPipeleekHTTPClient("", nil, nil)
+	err := ValidateRenovateConfigService(srv.URL, client)
+	assert.Error(t, err)
+}
+
+// TestValidateRenovateConfigService_BadURL verifies that an unparseable URL returns an error.
+func TestValidateRenovateConfigService_BadURL(t *testing.T) {
+	client := httpclient.GetPipeleekHTTPClient("", nil, nil)
+	err := ValidateRenovateConfigService("://bad-url", client)
+	assert.Error(t, err)
+}
+
+// TestValidateRenovateConfigService_Unreachable verifies that an unreachable host returns an error.
+func TestValidateRenovateConfigService_Unreachable(t *testing.T) {
+	client := httpclient.GetPipeleekHTTPClient("", nil, nil)
+	client.RetryMax = 0
+	err := ValidateRenovateConfigService("http://127.0.0.1:0", client)
+	assert.Error(t, err)
+}
+
+// redirectTransport is a test helper that redirects all requests to a fixed target URL,
+// preserving the request path (the query string is not forwarded).
+type redirectTransport struct { + targetURL string +} + +func (t *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { + parsed, err := http.NewRequest(req.Method, t.targetURL+req.URL.Path, req.Body) + if err != nil { + return nil, err + } + parsed.Header = req.Header + return http.DefaultTransport.RoundTrip(parsed) +} diff --git a/pkg/renovate/privesc_test.go b/pkg/renovate/privesc_test.go new file mode 100644 index 00000000..db0e077c --- /dev/null +++ b/pkg/renovate/privesc_test.go @@ -0,0 +1,181 @@ +package renovate + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBranchMonitor(t *testing.T) { + tests := []struct { + name string + pattern string + expectError bool + }{ + { + name: "valid renovate pattern", + pattern: `^renovate/`, + expectError: false, + }, + { + name: "valid complex pattern", + pattern: `renovate/(npm|pip|github-actions).*`, + expectError: false, + }, + { + name: "empty pattern", + pattern: "", + expectError: false, + }, + { + name: "invalid regex pattern", + pattern: `[invalid`, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + monitor, err := NewBranchMonitor(tt.pattern) + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, monitor) + } else { + assert.NoError(t, err) + require.NotNil(t, monitor) + assert.NotNil(t, monitor.originalBranches) + assert.NotNil(t, monitor.regex) + } + }) + } +} + +func TestCheckBranch_FirstScan(t *testing.T) { + monitor, err := NewBranchMonitor(`^renovate/`) + require.NoError(t, err) + + // During first scan, all branches are recorded but not flagged + result := monitor.CheckBranch("renovate/npm-dep", true) + assert.False(t, result, "first scan should never return true") + assert.True(t, monitor.originalBranches["renovate/npm-dep"], "branch should be recorded") + + result = monitor.CheckBranch("main", true) + assert.False(t, result, "first scan 
should never return true for any branch") + assert.True(t, monitor.originalBranches["main"], "branch should be recorded") +} + +func TestCheckBranch_SubsequentScan_NewRenovateBranch(t *testing.T) { + monitor, err := NewBranchMonitor(`^renovate/`) + require.NoError(t, err) + + // First scan: record existing branches + monitor.CheckBranch("main", true) + monitor.CheckBranch("feature/old", true) + + // Second scan: new renovate branch appears + result := monitor.CheckBranch("renovate/npm-jest-5.x", false) + assert.True(t, result, "new renovate branch should be detected") +} + +func TestCheckBranch_SubsequentScan_ExistingBranch(t *testing.T) { + monitor, err := NewBranchMonitor(`^renovate/`) + require.NoError(t, err) + + // First scan: record existing branches including a renovate one + monitor.CheckBranch("renovate/existing", true) + monitor.CheckBranch("main", true) + + // Second scan: existing renovate branch should not be flagged again + result := monitor.CheckBranch("renovate/existing", false) + assert.False(t, result, "existing branch should not be flagged") +} + +func TestCheckBranch_SubsequentScan_NonMatchingBranch(t *testing.T) { + monitor, err := NewBranchMonitor(`^renovate/`) + require.NoError(t, err) + + // First scan + monitor.CheckBranch("main", true) + + // Second scan: new branch that doesn't match pattern + result := monitor.CheckBranch("feature/new-feature", false) + assert.False(t, result, "non-matching new branch should not be flagged") +} + +func TestCheckBranch_MultiplePatterns(t *testing.T) { + monitor, err := NewBranchMonitor(`(^renovate/|^deps/update)`) + require.NoError(t, err) + + // First scan + monitor.CheckBranch("main", true) + + // Second scan: branches matching different parts of the pattern + assert.True(t, monitor.CheckBranch("renovate/lodash", false)) + assert.True(t, monitor.CheckBranch("deps/update-all", false)) + assert.False(t, monitor.CheckBranch("feature/thing", false)) +} + +func TestGetMonitoringInterval(t *testing.T) { + 
interval := GetMonitoringInterval() + assert.Equal(t, 1*time.Second, interval) +} + +func TestGetRetryInterval(t *testing.T) { + interval := GetRetryInterval() + assert.Equal(t, 5*time.Second, interval) +} + +func TestLogExploitInstructions(t *testing.T) { + // LogExploitInstructions only logs, ensure it doesn't panic + assert.NotPanics(t, func() { + LogExploitInstructions("renovate/npm-jest-5.x", "main") + }) +} + +func TestValidateRepositoryName(t *testing.T) { + tests := []struct { + name string + repoName string + expected bool + }{ + { + name: "valid owner/repo format", + repoName: "myorg/myrepo", + expected: true, + }, + { + name: "valid with deeper path", + repoName: "myorg/myrepo/subrepo", + expected: true, + }, + { + name: "single word", + repoName: "myrepo", + expected: true, + }, + { + name: "empty string", + repoName: "", + expected: false, + }, + { + name: "just slash", + repoName: "/", + expected: false, + }, + { + name: "leading slash", + repoName: "/myorg/myrepo", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidateRepositoryName(tt.repoName) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/scanner/engine/engine.go b/pkg/scanner/engine/engine.go index c9dfcbd7..a162fd63 100644 --- a/pkg/scanner/engine/engine.go +++ b/pkg/scanner/engine/engine.go @@ -115,23 +115,29 @@ func DetectHitsWithTimeout(text []byte, maxThreads int, enableTruffleHogVerifica } func deduplicateFindings(totalFindings []types.Finding) []types.Finding { + deduplicationMutex.Lock() + defer deduplicationMutex.Unlock() + var deduped []types.Finding + deduped, findingsDeduplicationList = deduplicateFindingsWithState(totalFindings, findingsDeduplicationList) + return deduped +} + +// deduplicateFindingsWithState is a pure deduplication function that accepts and returns the seen-hash state. +// This enables testing without relying on the package-level global. 
+func deduplicateFindingsWithState(totalFindings []types.Finding, seenHashes []string) ([]types.Finding, []string) { dedupedFindings := []types.Finding{} for _, finding := range totalFindings { hash, _ := rxhash.HashStruct(finding) - deduplicationMutex.Lock() - if !slices.Contains(findingsDeduplicationList, hash) { + if !slices.Contains(seenHashes, hash) { dedupedFindings = append(dedupedFindings, finding) - findingsDeduplicationList = append(findingsDeduplicationList, hash) + seenHashes = append(seenHashes, hash) } - if len(findingsDeduplicationList) > 500 { - findingsDeduplicationList[0] = "" - findingsDeduplicationList = findingsDeduplicationList[1:] + if len(seenHashes) > 500 { + seenHashes = seenHashes[1:] } - deduplicationMutex.Unlock() } - - return dedupedFindings + return dedupedFindings, seenHashes } func extractHitWithSurroundingText(text []byte, hitIndex []int, additionalBytes int) string { diff --git a/pkg/scanner/engine/engine_test.go b/pkg/scanner/engine/engine_test.go index 1b0b800d..ccb58343 100644 --- a/pkg/scanner/engine/engine_test.go +++ b/pkg/scanner/engine/engine_test.go @@ -1,6 +1,7 @@ package engine import ( + "fmt" "testing" "time" @@ -195,3 +196,56 @@ func TestCleanHitLine(t *testing.T) { }) } } + +// TestDeduplicateFindingsWithState_NoDependencyOnGlobal verifies that the pure function +// operates without relying on package-level global state. 
+func TestDeduplicateFindingsWithState_NoDependencyOnGlobal(t *testing.T) { + finding := types.Finding{ + Pattern: types.PatternElement{ + Pattern: types.PatternPattern{Name: "Pattern A", Confidence: "high"}, + }, + Text: "unique_secret_abc123", + } + + // First call: unique finding should be included + deduped, newState := deduplicateFindingsWithState([]types.Finding{finding}, nil) + if len(deduped) != 1 { + t.Fatalf("first call: expected 1 finding, got %d", len(deduped)) + } + if len(newState) != 1 { + t.Fatalf("expected state to have 1 entry, got %d", len(newState)) + } + + // Second call with same finding using the returned state: should be deduplicated + deduped2, _ := deduplicateFindingsWithState([]types.Finding{finding}, newState) + if len(deduped2) != 0 { + t.Fatalf("second call: expected 0 findings (duplicate), got %d", len(deduped2)) + } +} + +// TestDeduplicateFindingsWithState_TrimsAtLimit verifies that the seen-hash list is +// trimmed when it exceeds 500 entries (the previously untested branch). 
+func TestDeduplicateFindingsWithState_TrimsAtLimit(t *testing.T) { + // Build a state that already has 500 entries + seenHashes := make([]string, 500) + for i := range seenHashes { + seenHashes[i] = fmt.Sprintf("hash-%04d", i) + } + + // Add a new unique finding: the state must grow to 501 and then be trimmed + newFinding := types.Finding{ + Pattern: types.PatternElement{ + Pattern: types.PatternPattern{Name: "NewPattern", Confidence: "medium"}, + }, + Text: "brand_new_secret_xyz", + } + + deduped, newState := deduplicateFindingsWithState([]types.Finding{newFinding}, seenHashes) + if len(deduped) != 1 { + t.Fatalf("expected 1 unique finding, got %d", len(deduped)) + } + // After trim, length should be exactly 500 (grew to 501, first element removed) + if len(newState) != 500 { + t.Fatalf("expected state len 500 after trim, got %d", len(newState)) + } +} diff --git a/pkg/scanner/rules/rules.go b/pkg/scanner/rules/rules.go index 1e6bbfdf..7e0569bc 100644 --- a/pkg/scanner/rules/rules.go +++ b/pkg/scanner/rules/rules.go @@ -9,6 +9,7 @@ import ( "github.com/CompassSecurity/pipeleek/pkg/httpclient" "github.com/CompassSecurity/pipeleek/pkg/scanner/types" + "github.com/hashicorp/go-retryablehttp" "github.com/rs/zerolog/log" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/engine/defaults" @@ -24,7 +25,7 @@ var truffelhogRules []detectors.Detector func DownloadRules() { if _, err := os.Stat(ruleFileName); errors.Is(err, os.ErrNotExist) { log.Debug().Msg("No rules file found, downloading") - err := downloadFile(ruleFile, ruleFileName) + err := downloadFile(ruleFile, ruleFileName, httpclient.GetPipeleekHTTPClient("", nil, nil)) if err != nil { log.Fatal().Stack().Err(err).Msg("Failed downloading rules file") os.Exit(1) @@ -32,7 +33,7 @@ func DownloadRules() { } } -func downloadFile(url string, filepath string) error { +func downloadFile(url string, filepath string, client *retryablehttp.Client) 
error { // #nosec G304 - Creating file for rules download at controlled internal temp path out, err := os.Create(filepath) if err != nil { @@ -40,7 +41,6 @@ func downloadFile(url string, filepath string) error { } defer func() { _ = out.Close() }() - client := httpclient.GetPipeleekHTTPClient("", nil, nil) resp, err := client.Get(url) if err != nil { return err diff --git a/pkg/scanner/rules/rules_test.go b/pkg/scanner/rules/rules_test.go index 221d53a8..6060c147 100644 --- a/pkg/scanner/rules/rules_test.go +++ b/pkg/scanner/rules/rules_test.go @@ -1,10 +1,16 @@ package rules import ( + "net/http" + "net/http/httptest" "os" + "path/filepath" "testing" + "github.com/CompassSecurity/pipeleek/pkg/httpclient" "github.com/CompassSecurity/pipeleek/pkg/scanner/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAppendPipeleekRules(t *testing.T) { @@ -268,3 +274,47 @@ func TestGetTruffleHogRules_AfterInit(t *testing.T) { t.Error("Expected non-empty TruffleHog rules") } } + +func TestDownloadFile_Success(t *testing.T) { + content := "rules file content from mock" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer srv.Close() + + tmpDir := t.TempDir() + destFile := filepath.Join(tmpDir, "rules.yml") + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + err := downloadFile(srv.URL, destFile, client) + require.NoError(t, err) + + data, err := os.ReadFile(destFile) + require.NoError(t, err) + assert.Equal(t, content, string(data)) +} + +func TestDownloadFile_HTTPError(t *testing.T) { + tmpDir := t.TempDir() + destFile := filepath.Join(tmpDir, "rules.yml") + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + client.RetryMax = 0 + + err := downloadFile("http://127.0.0.1:0", destFile, client) + assert.Error(t, err) +} + +func TestDownloadFile_BadOutputPath(t *testing.T) { + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data")) + })) + defer srv.Close() + + client := httpclient.GetPipeleekHTTPClient("", nil, nil) + // Attempting to write to a path inside a non-existent directory should fail + err := downloadFile(srv.URL, "/nonexistent-dir/rules.yml", client) + assert.Error(t, err) +}