@@ -3,6 +3,7 @@ package oauth
33import (
44 "context"
55 "net/http"
6+ "sync"
67 "time"
78)
89
@@ -13,6 +14,8 @@ var _ http.RoundTripper = (*Transport)(nil)
1314type Transport struct {
1415 Base http.RoundTripper
1516 Token * Token
17+
18+ mu sync.Mutex
1619}
1720
1821// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during
@@ -22,16 +25,10 @@ var storeRefreshedTokenFn = StoreToken
2225// RoundTrip implements http.RoundTripper.
2326func (t * Transport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
2427 ctx := req .Context ()
25- prevToken := t .Token
26- token , err := maybeRefresh (ctx , t .Token )
27- if err != nil {
28+
29+ if err := t .refreshToken (ctx ); err != nil {
2830 return nil , err
2931 }
30- t .Token = token
31- if token != prevToken {
32- // try to save the token if we fail let the request continue with in memory token
33- _ = storeRefreshedTokenFn (ctx , token )
34- }
3532
3633 req2 := req .Clone (req .Context ())
3734 req2 .Header .Set ("Authorization" , "Bearer " + t .Token .AccessToken )
@@ -42,8 +39,32 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
4239 return http .DefaultTransport .RoundTrip (req2 )
4340}
4441
42+ // refreshToken checks if the token has expired or expiring soon and refreshes it. Once the token is
43+ // refreshed, the in-memory token is updated and a best effort is made to store the token.
44+ // If storing the token fails, no error is returned.
45+ func (t * Transport ) refreshToken (ctx context.Context ) error {
46+ t .mu .Lock ()
47+ defer t .mu .Unlock ()
48+
49+ prevToken := t .Token
50+ token , err := maybeRefresh (ctx , t .Token )
51+ if err != nil {
52+ return err
53+ }
54+ t .Token = token
55+ if token != prevToken {
56+ // try to save the token if we fail let the request continue with in memory token
57+ _ = storeRefreshedTokenFn (ctx , token )
58+ }
59+
60+ return nil
61+ }
62+
63+ // maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s
64+ // it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original
65+ // token is returned.
4566func maybeRefresh (ctx context.Context , token * Token ) (* Token , error ) {
46- // token has NOT expired or NOT about to expire in 30s
67+ // token has NOT expired and is NOT about to expire in 30s
4768 if ! (token .HasExpired () || token .ExpiringIn (time .Duration (30 )* time .Second )) {
4869 return token , nil
4970 }
@@ -59,6 +80,7 @@ func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
5980 return next , nil
6081}
6182
83+ // IsOAuthTransport checks wether the underlying type of the given RoundTripper is a OAuthTransport
6284func IsOAuthTransport (trp http.RoundTripper ) bool {
6385 _ , ok := trp .(* Transport )
6486 return ok
0 commit comments