Skip to content

Commit e2d7bd2

Browse files
committed
refactor: oauth
- add mutex to guard concurrent changes of token - pull refreshing of token out into `refreshToken` - additional comments
1 parent 258f67a commit e2d7bd2

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

internal/oauth/http_transport.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package oauth
33
import (
44
"context"
55
"net/http"
6+
"sync"
67
"time"
78
)
89

@@ -13,6 +14,8 @@ var _ http.RoundTripper = (*Transport)(nil)
1314
type Transport struct {
1415
Base http.RoundTripper
1516
Token *Token
17+
18+
mu sync.Mutex
1619
}
1720

1821
// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during
@@ -22,16 +25,10 @@ var storeRefreshedTokenFn = StoreToken
2225
// RoundTrip implements http.RoundTripper.
2326
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
2427
ctx := req.Context()
25-
prevToken := t.Token
26-
token, err := maybeRefresh(ctx, t.Token)
27-
if err != nil {
28+
29+
if err := t.refreshToken(ctx); err != nil {
2830
return nil, err
2931
}
30-
t.Token = token
31-
if token != prevToken {
32-
// try to save the token if we fail let the request continue with in memory token
33-
_ = storeRefreshedTokenFn(ctx, token)
34-
}
3532

3633
req2 := req.Clone(req.Context())
3734
req2.Header.Set("Authorization", "Bearer "+t.Token.AccessToken)
@@ -42,8 +39,32 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
4239
return http.DefaultTransport.RoundTrip(req2)
4340
}
4441

42+
// refreshToken checks if the token has expired or expiring soon and refreshes it. Once the token is
43+
// refreshed, the in-memory token is updated and a best effort is made to store the token.
44+
// If storing the token fails, no error is returned.
45+
func (t *Transport) refreshToken(ctx context.Context) error {
46+
t.mu.Lock()
47+
defer t.mu.Unlock()
48+
49+
prevToken := t.Token
50+
token, err := maybeRefresh(ctx, t.Token)
51+
if err != nil {
52+
return err
53+
}
54+
t.Token = token
55+
if token != prevToken {
56+
// try to save the token if we fail let the request continue with in memory token
57+
_ = storeRefreshedTokenFn(ctx, token)
58+
}
59+
60+
return nil
61+
}
62+
63+
// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s
64+
// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original
65+
// token is returned.
4566
func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
46-
// token has NOT expired or NOT about to expire in 30s
67+
// token has NOT expired and is NOT about to expire in 30s
4768
if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) {
4869
return token, nil
4970
}
@@ -59,6 +80,7 @@ func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
5980
return next, nil
6081
}
6182

83+
// IsOAuthTransport checks wether the underlying type of the given RoundTripper is a OAuthTransport
6284
func IsOAuthTransport(trp http.RoundTripper) bool {
6385
_, ok := trp.(*Transport)
6486
return ok

0 commit comments

Comments
 (0)