Skip to content

Commit a48cf3a

Browse files
committed
merge with main and adjust endpoint string to URL
1 parent 8b81690 commit a48cf3a

File tree

10 files changed

+136
-95
lines changed

10 files changed

+136
-95
lines changed

cmd/src/login.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"flag"
66
"fmt"
77
"io"
8+
"net/url"
89
"os"
910

1011
"github.com/sourcegraph/src-cli/internal/api"
@@ -49,25 +50,25 @@ Examples:
4950
return err
5051
}
5152

53+
var loginEndpointURL *url.URL
5254
if flagSet.NArg() >= 1 {
5355
arg := flagSet.Arg(0)
54-
parsed, err := parseEndpoint(arg)
56+
u, err := parseEndpoint(arg)
5557
if err != nil {
5658
return cmderrors.Usage(fmt.Sprintf("invalid endpoint URL: %s", arg))
5759
}
58-
if parsed.String() != cfg.endpointURL.String() {
59-
return cmderrors.Usage(fmt.Sprintf("The configured endpoint is %s, not %s", cfg.endpointURL, parsed))
60-
}
60+
loginEndpointURL = u
6161
}
6262

6363
client := cfg.apiClient(apiFlags, io.Discard)
6464

6565
return loginCmd(context.Background(), loginParams{
66-
cfg: cfg,
67-
client: client,
68-
out: os.Stdout,
69-
apiFlags: apiFlags,
70-
oauthClient: oauth.NewClient(oauth.DefaultClientID),
66+
cfg: cfg,
67+
client: client,
68+
out: os.Stdout,
69+
apiFlags: apiFlags,
70+
oauthClient: oauth.NewClient(oauth.DefaultClientID),
71+
loginEndpointURL: loginEndpointURL,
7172
})
7273
}
7374

@@ -79,11 +80,12 @@ Examples:
7980
}
8081

8182
type loginParams struct {
82-
cfg *config
83-
client api.Client
84-
out io.Writer
85-
apiFlags *api.Flags
86-
oauthClient oauth.Client
83+
cfg *config
84+
client api.Client
85+
out io.Writer
86+
apiFlags *api.Flags
87+
oauthClient oauth.Client
88+
loginEndpointURL *url.URL
8789
}
8890

8991
type loginFlow func(context.Context, loginParams) error
@@ -93,6 +95,7 @@ type loginFlowKind int
9395
const (
9496
loginFlowOAuth loginFlowKind = iota
9597
loginFlowMissingAuth
98+
loginFlowEndpointConflict
9699
loginFlowValidate
97100
)
98101

@@ -108,6 +111,9 @@ func loginCmd(ctx context.Context, p loginParams) error {
108111

109112
// selectLoginFlow decides what login flow to run based on configured AuthMode.
110113
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
114+
if p.loginEndpointURL != nil && p.loginEndpointURL.String() != p.cfg.endpointURL.String() {
115+
return loginFlowEndpointConflict, runEndpointConflictLogin
116+
}
111117
switch p.cfg.AuthMode() {
112118
case AuthModeOAuth:
113119
return loginFlowOAuth, runOAuthLogin
@@ -122,7 +128,7 @@ func printLoginProblem(out io.Writer, problem string) {
122128
fmt.Fprintf(out, "❌ Problem: %s\n", problem)
123129
}
124130

125-
func loginAccessTokenMessage(endpoint string) string {
131+
func loginAccessTokenMessage(endpointURL *url.URL) string {
126132
return fmt.Sprintf("\n"+`🛠 To fix: Create an access token by going to %s/user/settings/tokens, then set the following environment variables in your terminal:
127133
128134
export SRC_ENDPOINT=%s
@@ -131,5 +137,5 @@ func loginAccessTokenMessage(endpoint string) string {
131137
To verify that it's working, run the login command again.
132138
133139
Alternatively, you can try logging in interactively by running: src login %s
134-
`, endpoint, endpoint, endpoint)
140+
`, endpointURL, endpointURL, endpointURL)
135141
}

cmd/src/login_oauth.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"net/url"
78
"os/exec"
89
"runtime"
910
"time"
@@ -16,15 +17,14 @@ import (
1617
var loadStoredOAuthToken = oauth.LoadToken
1718

1819
func runOAuthLogin(ctx context.Context, p loginParams) error {
19-
endpoint := p.cfg.endpointURL.String()
20-
client, err := oauthLoginClient(ctx, p, endpoint)
20+
client, err := oauthLoginClient(ctx, p)
2121
if err != nil {
2222
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
23-
fmt.Fprintln(p.out, loginAccessTokenMessage(endpoint))
23+
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
2424
return cmderrors.ExitCode1
2525
}
2626

27-
if err := validateCurrentUser(ctx, client, p.out, endpoint); err != nil {
27+
if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil {
2828
return err
2929
}
3030

@@ -37,13 +37,13 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
3737
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
3838
// and use it if one is present.
3939
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
40-
func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) {
40+
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
4141
// if we have a stored token, used it. Otherwise run the device flow
42-
if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil {
43-
return newOAuthAPIClient(p, endpoint, token), nil
42+
if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil {
43+
return newOAuthAPIClient(p, token), nil
4444
}
4545

46-
token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient)
46+
token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient)
4747
if err != nil {
4848
return nil, err
4949
}
@@ -53,10 +53,10 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
5353
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
5454
}
5555

56-
return newOAuthAPIClient(p, endpoint, token), nil
56+
return newOAuthAPIClient(p, token), nil
5757
}
5858

59-
func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client {
59+
func newOAuthAPIClient(p loginParams, token *oauth.Token) api.Client {
6060
return api.NewClient(api.ClientOpts{
6161
EndpointURL: p.cfg.endpointURL,
6262
AdditionalHeaders: p.cfg.additionalHeaders,
@@ -68,8 +68,8 @@ func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.C
6868
})
6969
}
7070

71-
func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) {
72-
authResp, err := client.Start(ctx, endpoint, nil)
71+
func runOAuthDeviceFlow(ctx context.Context, endpointURL *url.URL, out io.Writer, client oauth.Client) (*oauth.Token, error) {
72+
authResp, err := client.Start(ctx, endpointURL, nil)
7373
if err != nil {
7474
return nil, err
7575
}
@@ -93,12 +93,12 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli
9393
interval = 5 * time.Second
9494
}
9595

96-
resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn)
96+
resp, err := client.Poll(ctx, endpointURL, authResp.DeviceCode, interval, authResp.ExpiresIn)
9797
if err != nil {
9898
return nil, err
9999
}
100100

101-
token := resp.Token(endpoint)
101+
token := resp.Token(endpointURL)
102102
token.ClientID = client.ClientID()
103103
return token, nil
104104
}

cmd/src/login_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestLogin(t *testing.T) {
9797
}))
9898
defer s.Close()
9999

100-
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
100+
restoreStoredOAuthLoader(t, func(_ context.Context, _ *url.URL) (*oauth.Token, error) {
101101
return &oauth.Token{
102102
Endpoint: s.URL,
103103
ClientID: oauth.DefaultClientID,
@@ -142,18 +142,18 @@ func (f fakeOAuthClient) ClientID() string {
142142
return oauth.DefaultClientID
143143
}
144144

145-
func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfiguration, error) {
145+
func (f fakeOAuthClient) Discover(context.Context, *url.URL) (*oauth.OIDCConfiguration, error) {
146146
return nil, fmt.Errorf("unexpected call to Discover")
147147
}
148148

149-
func (f fakeOAuthClient) Start(context.Context, string, []string) (*oauth.DeviceAuthResponse, error) {
149+
func (f fakeOAuthClient) Start(context.Context, *url.URL, []string) (*oauth.DeviceAuthResponse, error) {
150150
if f.startCalled != nil {
151151
*f.startCalled = true
152152
}
153153
return nil, f.startErr
154154
}
155155

156-
func (f fakeOAuthClient) Poll(context.Context, string, string, time.Duration, int) (*oauth.TokenResponse, error) {
156+
func (f fakeOAuthClient) Poll(context.Context, *url.URL, string, time.Duration, int) (*oauth.TokenResponse, error) {
157157
return nil, fmt.Errorf("unexpected call to Poll")
158158
}
159159

@@ -191,7 +191,7 @@ func TestSelectLoginFlow(t *testing.T) {
191191
})
192192
}
193193

194-
func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, string) (*oauth.Token, error)) {
194+
func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, *url.URL) (*oauth.Token, error)) {
195195
t.Helper()
196196

197197
prev := loadStoredOAuthToken

cmd/src/login_validate.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,32 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"net/url"
78
"strings"
89

910
"github.com/sourcegraph/src-cli/internal/api"
1011
"github.com/sourcegraph/src-cli/internal/cmderrors"
1112
)
1213

1314
func runMissingAuthLogin(_ context.Context, p loginParams) error {
14-
endpoint := p.cfg.endpointURL.String()
15-
1615
fmt.Fprintln(p.out)
1716
printLoginProblem(p.out, "No access token is configured.")
18-
fmt.Fprintln(p.out, loginAccessTokenMessage(endpoint))
17+
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
18+
return cmderrors.ExitCode1
19+
}
20+
21+
func runEndpointConflictLogin(_ context.Context, p loginParams) error {
22+
fmt.Fprintln(p.out)
23+
printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.endpointURL, p.loginEndpointURL))
24+
fmt.Fprintln(p.out, loginAccessTokenMessage(p.loginEndpointURL))
1925
return cmderrors.ExitCode1
2026
}
2127

2228
func runValidatedLogin(ctx context.Context, p loginParams) error {
23-
return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL.String())
29+
return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL)
2430
}
2531

26-
func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpoint string) error {
32+
func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointURL *url.URL) error {
2733
query := `query CurrentUser { currentUser { username } }`
2834
var result struct {
2935
CurrentUser *struct{ Username string }
@@ -32,21 +38,21 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer,
3238
if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") {
3339
printLoginProblem(out, "Invalid access token.")
3440
} else {
35-
printLoginProblem(out, fmt.Sprintf("Error communicating with %s: %s", endpoint, err))
41+
printLoginProblem(out, fmt.Sprintf("Error communicating with %s: %s", endpointURL, err))
3642
}
37-
fmt.Fprintln(out, loginAccessTokenMessage(endpoint))
43+
fmt.Fprintln(out, loginAccessTokenMessage(endpointURL))
3844
fmt.Fprintln(out, " (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)")
3945
return cmderrors.ExitCode1
4046
}
4147

4248
if result.CurrentUser == nil {
4349
// This should never happen; we verified there is an access token, so there should always be
4450
// a user.
45-
printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpoint))
51+
printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointURL))
4652
return cmderrors.ExitCode1
4753
}
4854
fmt.Fprintln(out)
49-
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpoint)
55+
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointURL)
5056
fmt.Fprintln(out)
5157
return nil
5258
}

cmd/src/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ type config struct {
136136
proxyURL *url.URL
137137
proxyPath string
138138
configFilePath string
139-
endpointURL *url.URL
139+
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
140140
}
141141

142142
// configFromFile holds the config as read from the config file,
@@ -176,7 +176,7 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client {
176176

177177
// Only use OAuth if we do not have SRC_ACCESS_TOKEN set
178178
if c.accessToken == "" {
179-
if t, err := oauth.LoadToken(context.Background(), c.endpointURL.String()); err == nil {
179+
if t, err := oauth.LoadToken(context.Background(), c.endpointURL); err == nil {
180180
opts.OAuthToken = t
181181
}
182182
}

0 commit comments

Comments
 (0)