diff --git a/cmd/src/auth.go b/cmd/src/auth.go new file mode 100644 index 0000000000..e2147787d6 --- /dev/null +++ b/cmd/src/auth.go @@ -0,0 +1,37 @@ +package main + +import ( + "flag" + "fmt" +) + +var authCommands commander + +func init() { + usage := `'src auth' provides authentication-related helper commands. + +Usage: + + src auth command [command options] + +The commands are: + + token prints the current authentication token + +Use "src auth [command] -h" for more information about a command. +` + + flagSet := flag.NewFlagSet("auth", flag.ExitOnError) + handler := func(args []string) error { + authCommands.run(flagSet, "src auth", usage, args) + return nil + } + + commands = append(commands, &command{ + flagSet: flagSet, + handler: handler, + usageFunc: func() { + fmt.Println(usage) + }, + }) +} diff --git a/cmd/src/auth_token.go b/cmd/src/auth_token.go new file mode 100644 index 0000000000..3b78499593 --- /dev/null +++ b/cmd/src/auth_token.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "flag" + "fmt" + + "github.com/sourcegraph/sourcegraph/lib/errors" + + "github.com/sourcegraph/src-cli/internal/oauth" +) + +var ( + loadOAuthToken = oauth.LoadToken + newOAuthTokenRefresher = func(token *oauth.Token) oauthTokenRefresher { + return oauth.NewTokenRefresher(token) + } +) + +type oauthTokenRefresher interface { + GetToken(ctx context.Context) (oauth.Token, error) +} + +func init() { + flagSet := flag.NewFlagSet("token", flag.ExitOnError) + usageFunc := func() { + fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src auth token':\n") + flagSet.PrintDefaults() + } + + handler := func(args []string) error { + if err := flagSet.Parse(args); err != nil { + return err + } + + token, err := resolveAuthToken(context.Background(), cfg) + if err != nil { + return err + } + + fmt.Println(token) + return nil + } + + authCommands = append(authCommands, &command{ + flagSet: flagSet, + handler: handler, + usageFunc: usageFunc, + }) +} + +func resolveAuthToken(ctx context.Context, cfg *config) (string, error) { + if cfg.accessToken != "" { + return cfg.accessToken, nil + } + + oauthToken, err := loadOAuthToken(ctx, cfg.endpointURL) + if err != nil { + return "", errors.Wrap(err, "error loading OAuth token; set SRC_ACCESS_TOKEN or run `src login`") + } + + token, err := newOAuthTokenRefresher(oauthToken).GetToken(ctx) + if err != nil { + return "", errors.Wrap(err, "refreshing OAuth token") + } + + return token.AccessToken, nil +} diff --git a/cmd/src/auth_token_test.go b/cmd/src/auth_token_test.go new file mode 100644 index 0000000000..8471be37f2 --- /dev/null +++ b/cmd/src/auth_token_test.go @@ -0,0 +1,128 @@ +package main + +import ( + "context" + "fmt" + "net/url" + "testing" + + "github.com/sourcegraph/src-cli/internal/oauth" +) + +func TestResolveAuthToken(t *testing.T) { + t.Run("uses configured access token before keyring", func(t *testing.T) { + reset := stubAuthTokenDependencies(t) + defer reset() + + newRefresherCalled := false + newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher { + newRefresherCalled = true + return fakeOAuthTokenRefresher{} + } + + token, err := resolveAuthToken(context.Background(), &config{ + accessToken: "access-token", + endpointURL: mustParseURL(t, "https://example.com"), + }) + if err != nil { + t.Fatal(err) + } + if token != "access-token" { + t.Fatalf("token = %q, want %q", token, "access-token") + } + if newRefresherCalled { + t.Fatal("expected OAuth token refresher not to be created") + } + }) + + t.Run("uses stored oauth token", func(t *testing.T) { + reset := stubAuthTokenDependencies(t) + defer reset() + + loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) { + return &oauth.Token{ + AccessToken: "oauth-token", + }, nil + } + + newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher { + return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}} + } + + token, err := resolveAuthToken(context.Background(), &config{ + endpointURL: mustParseURL(t, "https://example.com"), + }) + if err != nil { + t.Fatal(err) + } + if token != "oauth-token" { + t.Fatalf("token = %q, want %q", token, "oauth-token") + } + }) + + t.Run("refreshes expiring oauth token", func(t *testing.T) { + reset := stubAuthTokenDependencies(t) + defer reset() + + loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) { + return &oauth.Token{AccessToken: "old-token"}, nil + } + + newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher { + return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}} + } + + token, err := resolveAuthToken(context.Background(), &config{ + endpointURL: mustParseURL(t, "https://example.com"), + }) + if err != nil { + t.Fatal(err) + } + if token != "new-token" { + t.Fatalf("token = %q, want %q", token, "new-token") + } + }) + + t.Run("returns refresh error when shared refresh logic fails", func(t *testing.T) { + reset := stubAuthTokenDependencies(t) + defer reset() + + loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) { + return &oauth.Token{AccessToken: "old-token"}, nil + } + newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher { + return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")} + } + + _, err := resolveAuthToken(context.Background(), &config{ + endpointURL: mustParseURL(t, "https://example.com"), + }) + if err == nil { + t.Fatal("expected error") + } + }) +} + +func stubAuthTokenDependencies(t *testing.T) func() { + t.Helper() + + prevLoad := loadOAuthToken + prevNewRefresher := newOAuthTokenRefresher + + return func() { + loadOAuthToken = prevLoad + newOAuthTokenRefresher = prevNewRefresher + } +} + +type fakeOAuthTokenRefresher struct { + token oauth.Token + err error +} + +func (r fakeOAuthTokenRefresher) GetToken(context.Context) (oauth.Token, error) { + if r.err != nil { + return oauth.Token{}, r.err + } + return r.token, nil +} diff --git a/cmd/src/main.go b/cmd/src/main.go index fa308072b7..0e42c2f465 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -50,6 +50,7 @@ The options are: The commands are: + auth authentication helper commands api interacts with the Sourcegraph GraphQL API batch manages batch changes code-intel manages code intelligence data diff --git a/internal/api/api.go b/internal/api/api.go index a824fae1a5..73a0416097 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -110,10 +110,7 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { } if opts.AccessToken == "" && opts.OAuthToken != nil { - transport = &oauth.Transport{ - Base: transport, - Token: opts.OAuthToken, - } + transport = oauth.NewTransport(transport, opts.OAuthToken) } return transport diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go index 854dd61138..adbc16d546 100644 --- a/internal/oauth/http_transport.go +++ b/internal/oauth/http_transport.go @@ -13,14 +13,24 @@ var _ http.Transport var _ http.RoundTripper = (*Transport)(nil) +const defaultRefreshWindow = 5 * time.Minute + type Transport struct { - Base http.RoundTripper - //Token is a OAuth token (which has a refresh token) that should be used during roundtrip to automatically - //refresh the OAuth access token once the current one has expired or is soon to expire - Token *Token + Base http.RoundTripper + refresher *TokenRefresher +} - //mu is a mutex that should be acquired whenever token used - mu sync.Mutex +type TokenRefresher struct { + token *Token + mu sync.Mutex +} + +func NewTokenRefresher(token *Token) *TokenRefresher { + return &TokenRefresher{token: token} +} + +func NewTransport(base http.RoundTripper, token *Token) *Transport { + return &Transport{Base: base, refresher: NewTokenRefresher(token)} } // storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during @@ -30,8 +40,7 @@ var storeRefreshedTokenFn = StoreToken // RoundTrip implements http.RoundTripper. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() - - token, err := t.getToken(ctx) + token, err := t.refresher.GetToken(ctx) if err != nil { return nil, err } @@ -45,36 +54,40 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return http.DefaultTransport.RoundTrip(req2) } -// getToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning. +// GetToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning. // Once the token is refreshed, the in-memory token is updated and a best effort is made to store the token. // // If storing the token fails, no error is returned. An error is only returned if refreshing the token // fails. -func (t *Transport) getToken(ctx context.Context) (Token, error) { - t.mu.Lock() - defer t.mu.Unlock() +func (r *TokenRefresher) GetToken(ctx context.Context) (Token, error) { + r.mu.Lock() + defer r.mu.Unlock() - prevToken := t.Token - token, err := maybeRefresh(ctx, t.Token) + prevToken := r.token + token, err := maybeRefreshToken(ctx, r.token) if err != nil { return Token{}, err } - t.Token = token + r.token = token if token != prevToken { // Try to save the token. // If we fail let the request continue with the in-memory token _ = storeRefreshedTokenFn(ctx, token) } - return *t.Token, nil + return *r.token, nil } -// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s -// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original -// token is returned. -func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { +// maybeRefreshToken conditionally refreshes the token. If the token has expired or is +// expiring within the default refresh window, it will be refreshed and the updated token returned. +// Otherwise, no refresh occurs and the original token is returned. +func maybeRefreshToken(ctx context.Context, token *Token) (*Token, error) { + if token == nil { + return nil, errors.New("token is nil") + } + // token has NOT expired and is NOT about to expire in 30s - if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) { + if !(token.HasExpired() || token.ExpiringIn(defaultRefreshWindow)) { return token, nil } client := NewClient(token.ClientID) diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go index 4dac832d05..47cb2b0dac 100644 --- a/internal/oauth/http_transport_test.go +++ b/internal/oauth/http_transport_test.go @@ -28,10 +28,14 @@ func newRefreshServer(t *testing.T, accessToken string) *httptest.Server { }) } -func TestMaybeRefresh(t *testing.T) { +func TestTokenRefresherGetToken(t *testing.T) { server := newRefreshServer(t, "new-token") defer server.Close() + originalStoreFn := storeRefreshedTokenFn + storeRefreshedTokenFn = func(context.Context, *Token) error { return nil } + defer func() { storeRefreshedTokenFn = originalStoreFn }() + tests := []struct { name string token *Token @@ -71,14 +75,15 @@ func TestMaybeRefresh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := maybeRefresh(context.Background(), tt.token) + refresher := NewTokenRefresher(tt.token) + got, err := refresher.GetToken(context.Background()) if err != nil { - t.Fatalf("maybeRefresh() error = %v", err) + t.Fatalf("GetToken() error = %v", err) } if got.AccessToken != tt.wantAccess { t.Errorf("AccessToken = %q, want %q", got.AccessToken, tt.wantAccess) } - if tt.wantSame && got != tt.token { + if tt.wantSame && refresher.token != tt.token { t.Errorf("token pointer changed for unexpired token") } }) @@ -145,13 +150,10 @@ func TestTransportRoundTrip(t *testing.T) { } var capturedAuth string - tr := &Transport{ - Base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - capturedAuth = req.Header.Get("Authorization") - return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil - }), - Token: tt.token, - } + tr := NewTransport(roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedAuth = req.Header.Get("Authorization") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), tt.token) _, err := tr.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil)) if err != nil {