Skip to content

Commit 4c3f356

Browse files
authored
feat: dynamically start oauth flow if no SRC_ACCESS_TOKEN (#1271)
* add AuthMode to config - AuthModeAccessToken and AuthModeOAuth - And check for CI and required AuthModeAccessToken * decide loginFlow based on config AuthMode * remove oauth flag * check for OAuth token precense before starting a new OAuth flow * remove unused ctx
1 parent b5a0552 commit 4c3f356

File tree

7 files changed

+181
-97
lines changed

7 files changed

+181
-97
lines changed

cmd/src/login.go

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,9 @@ Examples:
2929
3030
$ src login https://sourcegraph.com
3131
32-
Use OAuth device flow to authenticate:
32+
If no access token is configured, 'src login' uses OAuth device flow automatically:
3333
34-
$ src login --oauth https://sourcegraph.com
35-
36-
37-
Override the default client id used during device flow when authenticating:
38-
39-
$ src login --oauth https://sourcegraph.com
34+
$ src login https://sourcegraph.com
4035
`
4136

4237
flagSet := flag.NewFlagSet("login", flag.ExitOnError)
@@ -47,7 +42,6 @@ Examples:
4742

4843
var (
4944
apiFlags = api.NewFlags(flagSet)
50-
useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively")
5145
)
5246

5347
handler := func(args []string) error {
@@ -69,7 +63,6 @@ Examples:
6963
client: client,
7064
endpoint: endpoint,
7165
out: os.Stdout,
72-
useOAuth: *useOAuth,
7366
apiFlags: apiFlags,
7467
oauthClient: oauth.NewClient(oauth.DefaultClientID),
7568
})
@@ -87,7 +80,6 @@ type loginParams struct {
8780
client api.Client
8881
endpoint string
8982
out io.Writer
90-
useOAuth bool
9183
apiFlags *api.Flags
9284
oauthClient oauth.Client
9385
}
@@ -103,46 +95,31 @@ const (
10395
loginFlowValidate
10496
)
10597

106-
var loadStoredOAuthToken = oauth.LoadToken
107-
10898
func loginCmd(ctx context.Context, p loginParams) error {
10999
if p.cfg.ConfigFilePath != "" {
110100
fmt.Fprintln(p.out)
111101
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath)
112102
}
113103

114-
_, flow := selectLoginFlow(ctx, p)
104+
_, flow := selectLoginFlow(p)
115105
return flow(ctx, p)
116106
}
117107

118-
// selectLoginFlow decides what login flow to run based on flags and config.
119-
func selectLoginFlow(ctx context.Context, p loginParams) (loginFlowKind, loginFlow) {
108+
// selectLoginFlow decides what login flow to run based on configured AuthMode.
109+
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
120110
endpointArg := cleanEndpoint(p.endpoint)
121111

122-
if p.useOAuth {
112+
switch p.cfg.AuthMode() {
113+
case AuthModeOAuth:
123114
return loginFlowOAuth, runOAuthLogin
124-
}
125-
if !hasEffectiveAuth(ctx, p.cfg, endpointArg) {
115+
case AuthModeAccessToken:
116+
if endpointArg != p.cfg.Endpoint {
117+
return loginFlowEndpointConflict, runEndpointConflictLogin
118+
}
119+
return loginFlowValidate, runValidatedLogin
120+
default:
126121
return loginFlowMissingAuth, runMissingAuthLogin
127122
}
128-
if endpointArg != p.cfg.Endpoint {
129-
return loginFlowEndpointConflict, runEndpointConflictLogin
130-
}
131-
return loginFlowValidate, runValidatedLogin
132-
}
133-
134-
// hasEffectiveAuth determines whether we have auth credentials to continue. It first checks for a resolved Access Token in
135-
// config, then it checks for a stored OAuth token.
136-
func hasEffectiveAuth(ctx context.Context, cfg *config, resolvedEndpoint string) bool {
137-
if cfg.AccessToken != "" {
138-
return true
139-
}
140-
141-
if _, err := loadStoredOAuthToken(ctx, resolvedEndpoint); err == nil {
142-
return true
143-
}
144-
145-
return false
146123
}
147124

148125
func printLoginProblem(out io.Writer, problem string) {
@@ -157,6 +134,6 @@ func loginAccessTokenMessage(endpoint string) string {
157134
158135
To verify that it's working, run the login command again.
159136
160-
Alternatively, you can try logging in using OAuth by running: src login --oauth %s
137+
Alternatively, you can try logging in interactively by running: src login %s
161138
`, endpoint, endpoint, endpoint)
162139
}

cmd/src/login_oauth.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/sourcegraph/src-cli/internal/oauth"
1414
)
1515

16+
var loadStoredOAuthToken = oauth.LoadToken
17+
1618
func runOAuthLogin(ctx context.Context, p loginParams) error {
1719
endpointArg := cleanEndpoint(p.endpoint)
1820
client, err := oauthLoginClient(ctx, p, endpointArg)
@@ -32,7 +34,15 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
3234
return nil
3335
}
3436

37+
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
38+
// and use it if one is present.
39+
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
3540
func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) {
41+
// if we have a stored token, used it. Otherwise run the device flow
42+
if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil {
43+
return newOAuthAPIClient(p, endpoint, token), nil
44+
}
45+
3646
token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient)
3747
if err != nil {
3848
return nil, err
@@ -43,15 +53,19 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
4353
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
4454
}
4555

56+
return newOAuthAPIClient(p, endpoint, token), nil
57+
}
58+
59+
func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client {
4660
return api.NewClient(api.ClientOpts{
47-
Endpoint: p.cfg.Endpoint,
61+
Endpoint: endpoint,
4862
AdditionalHeaders: p.cfg.AdditionalHeaders,
4963
Flags: p.apiFlags,
5064
Out: p.out,
5165
ProxyURL: p.cfg.ProxyURL,
5266
ProxyPath: p.cfg.ProxyPath,
5367
OAuthToken: token,
54-
}), nil
68+
})
5569
}
5670

5771
func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) {

cmd/src/login_test.go

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/httptest"
1010
"strings"
1111
"testing"
12+
"time"
1213

1314
"github.com/sourcegraph/src-cli/internal/cmderrors"
1415
"github.com/sourcegraph/src-cli/internal/oauth"
@@ -18,51 +19,47 @@ func TestLogin(t *testing.T) {
1819
check := func(t *testing.T, cfg *config, endpointArg string) (output string, err error) {
1920
t.Helper()
2021

21-
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
22-
return nil, fmt.Errorf("not found")
23-
})
24-
2522
var out bytes.Buffer
2623
err = loginCmd(context.Background(), loginParams{
2724
cfg: cfg,
2825
client: cfg.apiClient(nil, io.Discard),
2926
endpoint: endpointArg,
3027
out: &out,
31-
oauthClient: oauth.NewClient(oauth.DefaultClientID),
28+
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
3229
})
3330
return strings.TrimSpace(out.String()), err
3431
}
3532

3633
t.Run("different endpoint in config vs. arg", func(t *testing.T) {
3734
out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com")
38-
if err != cmderrors.ExitCode1 {
35+
if err == nil {
3936
t.Fatal(err)
4037
}
41-
wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com"
42-
if out != wantOut {
43-
t.Errorf("got output %q, want %q", out, wantOut)
38+
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
39+
t.Errorf("got output %q, want oauth failure output", out)
4440
}
4541
})
4642

47-
t.Run("no access token", func(t *testing.T) {
43+
t.Run("no access token triggers oauth flow", func(t *testing.T) {
4844
out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com")
49-
if err != cmderrors.ExitCode1 {
45+
if err == nil {
5046
t.Fatal(err)
5147
}
52-
wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com"
53-
if out != wantOut {
54-
t.Errorf("got output %q, want %q", out, wantOut)
48+
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
49+
t.Errorf("got output %q, want oauth failure output", out)
5550
}
5651
})
5752

5853
t.Run("warning when using config file", func(t *testing.T) {
5954
out, err := check(t, &config{Endpoint: "https://example.com", ConfigFilePath: "f"}, "https://example.com")
60-
if err != cmderrors.ExitCode1 {
55+
if err == nil {
6156
t.Fatal(err)
6257
}
63-
wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://example.com"
64-
if out != wantOut {
65-
t.Errorf("got output %q, want %q", out, wantOut)
58+
if !strings.Contains(out, "Configuring src with a JSON file is deprecated") {
59+
t.Errorf("got output %q, want deprecation warning", out)
60+
}
61+
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
62+
t.Errorf("got output %q, want oauth failure output", out)
6663
}
6764
})
6865

@@ -77,7 +74,7 @@ func TestLogin(t *testing.T) {
7774
if err != cmderrors.ExitCode1 {
7875
t.Fatal(err)
7976
}
80-
wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)"
77+
wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in interactively by running: src login $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)"
8178
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint)
8279
if out != wantOut {
8380
t.Errorf("got output %q, want %q", out, wantOut)
@@ -101,33 +98,86 @@ func TestLogin(t *testing.T) {
10198
t.Errorf("got output %q, want %q", out, wantOut)
10299
}
103100
})
104-
}
105101

106-
func TestSelectLoginFlow(t *testing.T) {
107-
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
108-
return nil, fmt.Errorf("not found")
109-
})
102+
t.Run("reuses stored oauth token before device flow", func(t *testing.T) {
103+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
105+
}))
106+
defer s.Close()
110107

111-
t.Run("uses oauth flow when oauth flag is set", func(t *testing.T) {
112-
params := loginParams{
113-
cfg: &config{Endpoint: "https://example.com"},
114-
endpoint: "https://example.com",
115-
useOAuth: true,
116-
}
108+
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
109+
return &oauth.Token{
110+
Endpoint: s.URL,
111+
ClientID: oauth.DefaultClientID,
112+
AccessToken: "oauth-token",
113+
ExpiresAt: time.Now().Add(time.Hour),
114+
}, nil
115+
})
117116

118-
if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowOAuth {
119-
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
117+
startCalled := false
118+
var out bytes.Buffer
119+
err := loginCmd(context.Background(), loginParams{
120+
cfg: &config{Endpoint: s.URL},
121+
client: (&config{Endpoint: s.URL}).apiClient(nil, io.Discard),
122+
endpoint: s.URL,
123+
out: &out,
124+
oauthClient: fakeOAuthClient{
125+
startErr: fmt.Errorf("unexpected call to Start"),
126+
startCalled: &startCalled,
127+
},
128+
})
129+
if err != nil {
130+
t.Fatal(err)
131+
}
132+
if startCalled {
133+
t.Fatal("expected stored oauth token to avoid device flow")
134+
}
135+
gotOut := strings.TrimSpace(out.String())
136+
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials"
137+
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
138+
if gotOut != wantOut {
139+
t.Errorf("got output %q, want %q", gotOut, wantOut)
120140
}
121141
})
142+
}
143+
144+
type fakeOAuthClient struct {
145+
startErr error
146+
startCalled *bool
147+
}
148+
149+
func (f fakeOAuthClient) ClientID() string {
150+
return oauth.DefaultClientID
151+
}
152+
153+
func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfiguration, error) {
154+
return nil, fmt.Errorf("unexpected call to Discover")
155+
}
156+
157+
func (f fakeOAuthClient) Start(context.Context, string, []string) (*oauth.DeviceAuthResponse, error) {
158+
if f.startCalled != nil {
159+
*f.startCalled = true
160+
}
161+
return nil, f.startErr
162+
}
163+
164+
func (f fakeOAuthClient) Poll(context.Context, string, string, time.Duration, int) (*oauth.TokenResponse, error) {
165+
return nil, fmt.Errorf("unexpected call to Poll")
166+
}
167+
168+
func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenResponse, error) {
169+
return nil, fmt.Errorf("unexpected call to Refresh")
170+
}
122171

123-
t.Run("uses missing auth flow when auth is unavailable", func(t *testing.T) {
172+
func TestSelectLoginFlow(t *testing.T) {
173+
t.Run("uses oauth flow when no access token is configured", func(t *testing.T) {
124174
params := loginParams{
125175
cfg: &config{Endpoint: "https://example.com"},
126176
endpoint: "https://sourcegraph.example.com",
127177
}
128178

129-
if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowMissingAuth {
130-
t.Fatalf("flow = %v, want %v", got, loginFlowMissingAuth)
179+
if got, _ := selectLoginFlow(params); got != loginFlowOAuth {
180+
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
131181
}
132182
})
133183

@@ -137,7 +187,7 @@ func TestSelectLoginFlow(t *testing.T) {
137187
endpoint: "https://sourcegraph.example.com",
138188
}
139189

140-
if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowEndpointConflict {
190+
if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict {
141191
t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict)
142192
}
143193
})
@@ -148,22 +198,7 @@ func TestSelectLoginFlow(t *testing.T) {
148198
endpoint: "https://example.com",
149199
}
150200

151-
if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate {
152-
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
153-
}
154-
})
155-
156-
t.Run("treats stored oauth as effective auth", func(t *testing.T) {
157-
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
158-
return &oauth.Token{AccessToken: "oauth-token"}, nil
159-
})
160-
161-
params := loginParams{
162-
cfg: &config{Endpoint: "https://example.com"},
163-
endpoint: "https://example.com",
164-
}
165-
166-
if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate {
201+
if got, _ := selectLoginFlow(params); got != loginFlowValidate {
167202
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
168203
}
169204
})

0 commit comments

Comments
 (0)