diff --git a/cmd/src/login.go b/cmd/src/login.go index 5a73ef4cc8..c4989605f4 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -6,10 +6,6 @@ import ( "fmt" "io" "os" - "os/exec" - "runtime" - "strings" - "time" "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" @@ -69,13 +65,13 @@ Examples: client := cfg.apiClient(apiFlags, io.Discard) return loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: client, - endpoint: endpoint, - out: os.Stdout, - useOAuth: *useOAuth, - apiFlags: apiFlags, - deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), + cfg: cfg, + client: client, + endpoint: endpoint, + out: os.Stdout, + useOAuth: *useOAuth, + apiFlags: apiFlags, + oauthClient: oauth.NewClient(oauth.DefaultClientID), }) } @@ -87,161 +83,80 @@ Examples: } type loginParams struct { - cfg *config - client api.Client - endpoint string - out io.Writer - useOAuth bool - apiFlags *api.Flags - deviceFlowClient oauth.Client + cfg *config + client api.Client + endpoint string + out io.Writer + useOAuth bool + apiFlags *api.Flags + oauthClient oauth.Client } -func loginCmd(ctx context.Context, p loginParams) error { - endpointArg := cleanEndpoint(p.endpoint) - cfg := p.cfg - client := p.client - out := p.out - - printProblem := func(problem string) { - fmt.Fprintf(out, "❌ Problem: %s\n", problem) - } - - createAccessTokenMessage := fmt.Sprintf("\n"+`🛠 To fix: Create an access token by going to %s/user/settings/tokens, then set the following environment variables in your terminal: +type loginFlow func(context.Context, loginParams) error - export SRC_ENDPOINT=%s - export SRC_ACCESS_TOKEN=(your access token) +type loginFlowKind int - To verify that it's working, run the login command again. - - Alternatively, you can try logging in using OAuth by running: src login --oauth %s -`, endpointArg, endpointArg, endpointArg) +const ( + loginFlowOAuth loginFlowKind = iota + loginFlowMissingAuth + loginFlowEndpointConflict + loginFlowValidate +) - if cfg.ConfigFilePath != "" { - fmt.Fprintln(out) - fmt.Fprintf(out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", cfg.ConfigFilePath) - } +var loadStoredOAuthToken = oauth.LoadToken - noToken := cfg.AccessToken == "" - endpointConflict := endpointArg != cfg.Endpoint - if !p.useOAuth && (noToken || endpointConflict) { - fmt.Fprintln(out) - switch { - case noToken: - printProblem("No access token is configured.") - case endpointConflict: - printProblem(fmt.Sprintf("The configured endpoint is %s, not %s.", cfg.Endpoint, endpointArg)) - } - fmt.Fprintln(out, createAccessTokenMessage) - return cmderrors.ExitCode1 +func loginCmd(ctx context.Context, p loginParams) error { + if p.cfg.ConfigFilePath != "" { + fmt.Fprintln(p.out) + fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath) } - if p.useOAuth { - token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) - if err != nil { - printProblem(fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) - fmt.Fprintln(out, createAccessTokenMessage) - return cmderrors.ExitCode1 - } - - if err := oauth.StoreToken(ctx, token); err != nil { - fmt.Fprintln(out) - fmt.Fprintf(out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) - } + _, flow := selectLoginFlow(ctx, p) + return flow(ctx, p) +} - client = api.NewClient(api.ClientOpts{ - Endpoint: cfg.Endpoint, - AdditionalHeaders: cfg.AdditionalHeaders, - Flags: p.apiFlags, - Out: out, - ProxyURL: cfg.ProxyURL, - ProxyPath: cfg.ProxyPath, - OAuthToken: token, - }) - } +// selectLoginFlow decides what login flow to run based on flags and config. +func selectLoginFlow(ctx context.Context, p loginParams) (loginFlowKind, loginFlow) { + endpointArg := cleanEndpoint(p.endpoint) - // See if the user is already authenticated. - query := `query CurrentUser { currentUser { username } }` - var result struct { - CurrentUser *struct{ Username string } - } - if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil { - if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") { - printProblem("Invalid access token.") - } else { - printProblem(fmt.Sprintf("Error communicating with %s: %s", endpointArg, err)) - } - fmt.Fprintln(out, createAccessTokenMessage) - fmt.Fprintln(out, " (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)") - return cmderrors.ExitCode1 + if p.useOAuth { + return loginFlowOAuth, runOAuthLogin } - - if result.CurrentUser == nil { - // This should never happen; we verified there is an access token, so there should always be - // a user. - printProblem(fmt.Sprintf("Unable to determine user on %s.", endpointArg)) - return cmderrors.ExitCode1 + if !hasEffectiveAuth(ctx, p.cfg, endpointArg) { + return loginFlowMissingAuth, runMissingAuthLogin } - fmt.Fprintln(out) - fmt.Fprintf(out, "✔️ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg) - - if p.useOAuth { - fmt.Fprintln(out) - fmt.Fprintf(out, "Authenticated with OAuth credentials") + if endpointArg != p.cfg.Endpoint { + return loginFlowEndpointConflict, runEndpointConflictLogin } - - fmt.Fprintln(out) - return nil + return loginFlowValidate, runValidatedLogin } -func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { - authResp, err := client.Start(ctx, endpoint, nil) - if err != nil { - return nil, err +// hasEffectiveAuth determines whether we have auth credentials to continue. It first checks for a resolved Access Token in +// config, then it checks for a stored OAuth token. +func hasEffectiveAuth(ctx context.Context, cfg *config, resolvedEndpoint string) bool { + if cfg.AccessToken != "" { + return true } - authURL := authResp.VerificationURIComplete - msg := fmt.Sprintf("If your browser did not open automatically, visit %s.", authURL) - if authURL == "" { - authURL = authResp.VerificationURI - msg = fmt.Sprintf("If your browser did not open automatically, visit %s and enter the user code %s", authURL, authResp.DeviceCode) + if _, err := loadStoredOAuthToken(ctx, resolvedEndpoint); err == nil { + return true } - _ = openInBrowser(authURL) - fmt.Fprintln(out) - fmt.Fprint(out, msg) - fmt.Fprintln(out) - fmt.Fprint(out, "Waiting for authorization...") - defer fmt.Fprintf(out, "DONE\n\n") + return false +} - interval := time.Duration(authResp.Interval) * time.Second - if interval <= 0 { - interval = 5 * time.Second - } +func printLoginProblem(out io.Writer, problem string) { + fmt.Fprintf(out, "❌ Problem: %s\n", problem) +} - resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) - if err != nil { - return nil, err - } +func loginAccessTokenMessage(endpoint string) string { + return fmt.Sprintf("\n"+`🛠 To fix: Create an access token by going to %s/user/settings/tokens, then set the following environment variables in your terminal: - token := resp.Token(endpoint) - token.ClientID = client.ClientID() - return token, nil -} + export SRC_ENDPOINT=%s + export SRC_ACCESS_TOKEN=(your access token) -func openInBrowser(url string) error { - if url == "" { - return nil - } + To verify that it's working, run the login command again. - var cmd *exec.Cmd - switch runtime.GOOS { - case "darwin": - cmd = exec.Command("open", url) - case "windows": - // "start" is a cmd.exe built-in; the empty string is the window title. - cmd = exec.Command("cmd", "/c", "start", "", url) - default: - cmd = exec.Command("xdg-open", url) - } - return cmd.Run() + Alternatively, you can try logging in using OAuth by running: src login --oauth %s +`, endpoint, endpoint, endpoint) } diff --git a/cmd/src/login_oauth.go b/cmd/src/login_oauth.go new file mode 100644 index 0000000000..df3f912f62 --- /dev/null +++ b/cmd/src/login_oauth.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "fmt" + "io" + "os/exec" + "runtime" + "time" + + "github.com/sourcegraph/src-cli/internal/api" + "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/oauth" +) + +func runOAuthLogin(ctx context.Context, p loginParams) error { + endpointArg := cleanEndpoint(p.endpoint) + client, err := oauthLoginClient(ctx, p, endpointArg) + if err != nil { + printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) + fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + return cmderrors.ExitCode1 + } + + if err := validateCurrentUser(ctx, client, p.out, endpointArg); err != nil { + return err + } + + fmt.Fprintln(p.out) + fmt.Fprint(p.out, "✔︎ Authenticated with OAuth credentials") + fmt.Fprintln(p.out) + return nil +} + +func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) { + token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient) + if err != nil { + return nil, err + } + + if err := oauth.StoreToken(ctx, token); err != nil { + fmt.Fprintln(p.out) + fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) + } + + return api.NewClient(api.ClientOpts{ + Endpoint: p.cfg.Endpoint, + AdditionalHeaders: p.cfg.AdditionalHeaders, + Flags: p.apiFlags, + Out: p.out, + ProxyURL: p.cfg.ProxyURL, + ProxyPath: p.cfg.ProxyPath, + OAuthToken: token, + }), nil +} + +func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { + authResp, err := client.Start(ctx, endpoint, nil) + if err != nil { + return nil, err + } + + authURL := authResp.VerificationURIComplete + msg := fmt.Sprintf("If your browser did not open automatically, visit %s.", authURL) + if authURL == "" { + authURL = authResp.VerificationURI + msg = fmt.Sprintf("If your browser did not open automatically, visit %s and enter the user code %s", authURL, authResp.DeviceCode) + } + _ = openInBrowser(authURL) + fmt.Fprintln(out) + fmt.Fprint(out, msg) + + fmt.Fprintln(out) + fmt.Fprint(out, "Waiting for authorization... ") + defer fmt.Fprintf(out, "DONE\n\n") + + interval := time.Duration(authResp.Interval) * time.Second + if interval <= 0 { + interval = 5 * time.Second + } + + resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + if err != nil { + return nil, err + } + + token := resp.Token(endpoint) + token.ClientID = client.ClientID() + return token, nil +} + +func openInBrowser(url string) error { + if url == "" { + return nil + } + + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "windows": + // "start" is a cmd.exe built-in; the empty string is the window title. + cmd = exec.Command("cmd", "/c", "start", "", url) + default: + cmd = exec.Command("xdg-open", url) + } + return cmd.Run() +} diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index ab7a15056a..a34288576f 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -18,13 +18,17 @@ func TestLogin(t *testing.T) { check := func(t *testing.T, cfg *config, endpointArg string) (output string, err error) { t.Helper() + restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) { + return nil, fmt.Errorf("not found") + }) + var out bytes.Buffer err = loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: cfg.apiClient(nil, io.Discard), - endpoint: endpointArg, - out: &out, - deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), + cfg: cfg, + client: cfg.apiClient(nil, io.Discard), + endpoint: endpointArg, + out: &out, + oauthClient: oauth.NewClient(oauth.DefaultClientID), }) return strings.TrimSpace(out.String()), err } @@ -63,7 +67,6 @@ func TestLogin(t *testing.T) { }) t.Run("invalid access token", func(t *testing.T) { - // Dummy HTTP server to return HTTP 401/403. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusUnauthorized) })) @@ -82,7 +85,6 @@ func TestLogin(t *testing.T) { }) t.Run("valid", func(t *testing.T) { - // Dummy HTTP server to return JSON response with currentUser. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`) })) @@ -93,10 +95,86 @@ func TestLogin(t *testing.T) { if err != nil { t.Fatal(err) } - wantOut := "✔️ Authenticated as alice on $ENDPOINT" + wantOut := "✔︎ Authenticated as alice on $ENDPOINT" wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } }) } + +func TestSelectLoginFlow(t *testing.T) { + restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) { + return nil, fmt.Errorf("not found") + }) + + t.Run("uses oauth flow when oauth flag is set", func(t *testing.T) { + params := loginParams{ + cfg: &config{Endpoint: "https://example.com"}, + endpoint: "https://example.com", + useOAuth: true, + } + + if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowOAuth { + t.Fatalf("flow = %v, want %v", got, loginFlowOAuth) + } + }) + + t.Run("uses missing auth flow when auth is unavailable", func(t *testing.T) { + params := loginParams{ + cfg: &config{Endpoint: "https://example.com"}, + endpoint: "https://sourcegraph.example.com", + } + + if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowMissingAuth { + t.Fatalf("flow = %v, want %v", got, loginFlowMissingAuth) + } + }) + + t.Run("uses endpoint conflict flow when auth exists for a different endpoint", func(t *testing.T) { + params := loginParams{ + cfg: &config{Endpoint: "https://example.com", AccessToken: "x"}, + endpoint: "https://sourcegraph.example.com", + } + + if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowEndpointConflict { + t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict) + } + }) + + t.Run("uses validation flow when auth exists for the selected endpoint", func(t *testing.T) { + params := loginParams{ + cfg: &config{Endpoint: "https://example.com", AccessToken: "x"}, + endpoint: "https://example.com", + } + + if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate { + t.Fatalf("flow = %v, want %v", got, loginFlowValidate) + } + }) + + t.Run("treats stored oauth as effective auth", func(t *testing.T) { + restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) { + return &oauth.Token{AccessToken: "oauth-token"}, nil + }) + + params := loginParams{ + cfg: &config{Endpoint: "https://example.com"}, + endpoint: "https://example.com", + } + + if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate { + t.Fatalf("flow = %v, want %v", got, loginFlowValidate) + } + }) +} + +func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, string) (*oauth.Token, error)) { + t.Helper() + + prev := loadStoredOAuthToken + loadStoredOAuthToken = loader + t.Cleanup(func() { + loadStoredOAuthToken = prev + }) +} diff --git a/cmd/src/login_validate.go b/cmd/src/login_validate.go new file mode 100644 index 0000000000..1a0ddfa050 --- /dev/null +++ b/cmd/src/login_validate.go @@ -0,0 +1,61 @@ +package main + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/sourcegraph/src-cli/internal/api" + "github.com/sourcegraph/src-cli/internal/cmderrors" +) + +func runMissingAuthLogin(_ context.Context, p loginParams) error { + endpointArg := cleanEndpoint(p.endpoint) + + fmt.Fprintln(p.out) + printLoginProblem(p.out, "No access token is configured.") + fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + return cmderrors.ExitCode1 +} + +func runEndpointConflictLogin(_ context.Context, p loginParams) error { + endpointArg := cleanEndpoint(p.endpoint) + + fmt.Fprintln(p.out) + printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.Endpoint, endpointArg)) + fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + return cmderrors.ExitCode1 +} + +func runValidatedLogin(ctx context.Context, p loginParams) error { + return validateCurrentUser(ctx, p.client, p.out, cleanEndpoint(p.endpoint)) +} + +func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointArg string) error { + query := `query CurrentUser { currentUser { username } }` + var result struct { + CurrentUser *struct{ Username string } + } + if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil { + if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") { + printLoginProblem(out, "Invalid access token.") + } else { + printLoginProblem(out, fmt.Sprintf("Error communicating with %s: %s", endpointArg, err)) + } + fmt.Fprintln(out, loginAccessTokenMessage(endpointArg)) + fmt.Fprintln(out, " (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)") + return cmderrors.ExitCode1 + } + + if result.CurrentUser == nil { + // This should never happen; we verified there is an access token, so there should always be + // a user. + printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointArg)) + return cmderrors.ExitCode1 + } + fmt.Fprintln(out) + fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg) + fmt.Fprintln(out) + return nil +}