Skip to content

Commit b2d6321

Browse files
committed
Refactor auth flows to support storage contexts
Update authentication flows to support multiple storage contexts, enabling context-aware token management and refresh. Key changes: - Add *WithContext() variants for auth functions - Update user login flow to accept storage context parameter - Store access token expiry (JWT exp claim) instead of session expiry - Update token refresh to write tokens back to correct context - Add getAccessTokenExpiresAtUnix() to parse JWT exp claim - Update tests to use new context-aware functions This enables proper token refresh and bidirectional sync for both CLI and API authentication contexts.
1 parent 1703828 commit b2d6321

File tree

6 files changed

+173
-30
lines changed

6 files changed

+173
-30
lines changed

internal/pkg/auth/auth.go

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package auth
22

33
import (
4+
"bytes"
45
"fmt"
6+
"io"
57
"net/http"
68
"os"
79
"strconv"
@@ -25,7 +27,10 @@ type tokenClaims struct {
2527
//
2628
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
2729
// If the environment variable STACKIT_ACCESS_TOKEN is set this token is used instead.
28-
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
30+
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, context StorageContext, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
31+
// Set the storage printer so debug messages use the correct verbosity
32+
SetStoragePrinter(p)
33+
2934
// Get access token from env and use this if present
3035
accessToken := os.Getenv(envAccessTokenName)
3136
if accessToken != "" {
@@ -70,7 +75,7 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
7075
case AUTH_FLOW_USER_TOKEN:
7176
p.Debug(print.DebugLevel, "authenticating using user token")
7277
if userSessionExpired {
73-
err = reauthorizeUserRoutine(p, true)
78+
err = reauthorizeUserRoutine(p, StorageContextCLI, true)
7479
if err != nil {
7580
return nil, fmt.Errorf("user login: %w", err)
7681
}
@@ -84,7 +89,11 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
8489
}
8590

8691
func UserSessionExpired() (bool, error) {
87-
sessionExpiresAtString, err := GetAuthField(SESSION_EXPIRES_AT_UNIX)
92+
return UserSessionExpiredWithContext(StorageContextCLI)
93+
}
94+
95+
func UserSessionExpiredWithContext(context StorageContext) (bool, error) {
96+
sessionExpiresAtString, err := GetAuthFieldWithContext(context, SESSION_EXPIRES_AT_UNIX)
8897
if err != nil {
8998
return false, fmt.Errorf("get %s: %w", SESSION_EXPIRES_AT_UNIX, err)
9099
}
@@ -98,7 +107,11 @@ func UserSessionExpired() (bool, error) {
98107
}
99108

100109
func GetAccessToken() (string, error) {
101-
accessToken, err := GetAuthField(ACCESS_TOKEN)
110+
return GetAccessTokenWithContext(StorageContextCLI)
111+
}
112+
113+
func GetAccessTokenWithContext(context StorageContext) (string, error) {
114+
accessToken, err := GetAuthFieldWithContext(context, ACCESS_TOKEN)
102115
if err != nil {
103116
return "", fmt.Errorf("get %s: %w", ACCESS_TOKEN, err)
104117
}
@@ -134,18 +147,47 @@ func getEmailFromToken(token string) (string, error) {
134147
return claims.Email, nil
135148
}
136149

150+
func getAccessTokenExpiresAtUnix(accessToken string) (string, error) {
151+
// Parse the access token to get its expiration time
152+
parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{})
153+
if err != nil {
154+
return "", fmt.Errorf("parse access token: %w", err)
155+
}
156+
157+
claims, ok := parsedAccessToken.Claims.(*jwt.RegisteredClaims)
158+
if !ok {
159+
return "", fmt.Errorf("get claims from parsed token: unknown claims type")
160+
}
161+
162+
if claims.ExpiresAt == nil {
163+
return "", fmt.Errorf("access token has no expiration claim")
164+
}
165+
166+
return strconv.FormatInt(claims.ExpiresAt.Unix(), 10), nil
167+
}
168+
137169
// GetValidAccessToken returns a valid access token for the current authentication flow.
138170
// For user token flows, it refreshes the token if necessary.
139171
// For service account flows, it returns the current access token.
140172
func GetValidAccessToken(p *print.Printer) (string, error) {
141-
flow, err := GetAuthFlow()
173+
return GetValidAccessTokenWithContext(p, StorageContextCLI)
174+
}
175+
176+
// GetValidAccessTokenWithContext returns a valid access token for the specified storage context.
177+
// For user token flows, it refreshes the token if necessary.
178+
// For service account flows, it returns the current access token.
179+
func GetValidAccessTokenWithContext(p *print.Printer, context StorageContext) (string, error) {
180+
// Set the storage printer so debug messages use the correct verbosity
181+
SetStoragePrinter(p)
182+
183+
flow, err := GetAuthFlowWithContext(context)
142184
if err != nil {
143185
return "", fmt.Errorf("get authentication flow: %w", err)
144186
}
145187

146188
// For service account flows, just return the current token
147189
if flow == AUTH_FLOW_SERVICE_ACCOUNT_TOKEN || flow == AUTH_FLOW_SERVICE_ACCOUNT_KEY {
148-
return GetAccessToken()
190+
return GetAccessTokenWithContext(context)
149191
}
150192

151193
if flow != AUTH_FLOW_USER_TOKEN {
@@ -158,7 +200,7 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
158200
REFRESH_TOKEN: "",
159201
IDP_TOKEN_ENDPOINT: "",
160202
}
161-
err = GetAuthFieldMap(authFields)
203+
err = GetAuthFieldMapWithContext(context, authFields)
162204
if err != nil {
163205
return "", fmt.Errorf("get tokens from auth storage: %w", err)
164206
}
@@ -193,6 +235,7 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
193235
utf := &userTokenFlow{
194236
printer: p,
195237
client: &http.Client{},
238+
context: context,
196239
authFlow: flow,
197240
accessToken: accessToken,
198241
refreshToken: refreshToken,
@@ -208,3 +251,53 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
208251
// Return the new access token
209252
return utf.accessToken, nil
210253
}
254+
255+
// debugHTTPRequest logs the raw HTTP request details for debugging purposes
256+
func debugHTTPRequest(p *print.Printer, req *http.Request) {
257+
if p == nil || req == nil {
258+
return
259+
}
260+
261+
p.Debug(print.DebugLevel, "=== HTTP REQUEST ===")
262+
p.Debug(print.DebugLevel, "Method: %s", req.Method)
263+
p.Debug(print.DebugLevel, "URL: %s", req.URL.String())
264+
p.Debug(print.DebugLevel, "Headers:")
265+
for name, values := range req.Header {
266+
for _, value := range values {
267+
p.Debug(print.DebugLevel, " %s: %s", name, value)
268+
}
269+
}
270+
p.Debug(print.DebugLevel, "===================")
271+
}
272+
273+
// debugHTTPResponse logs the raw HTTP response details for debugging purposes
274+
func debugHTTPResponse(p *print.Printer, resp *http.Response) {
275+
if p == nil || resp == nil {
276+
return
277+
}
278+
279+
p.Debug(print.DebugLevel, "=== HTTP RESPONSE ===")
280+
p.Debug(print.DebugLevel, "Status: %s", resp.Status)
281+
p.Debug(print.DebugLevel, "Status Code: %d", resp.StatusCode)
282+
p.Debug(print.DebugLevel, "Headers:")
283+
for name, values := range resp.Header {
284+
for _, value := range values {
285+
p.Debug(print.DebugLevel, " %s: %s", name, value)
286+
}
287+
}
288+
289+
// Read and log body (need to restore it for later use)
290+
if resp.Body != nil {
291+
bodyBytes, err := io.ReadAll(resp.Body)
292+
if err != nil {
293+
p.Debug(print.ErrorLevel, "Error reading response body: %v", err)
294+
} else {
295+
// Restore the body for later use
296+
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
297+
298+
// Show raw body without sanitization
299+
p.Debug(print.DebugLevel, "Body: %s", string(bodyBytes))
300+
}
301+
}
302+
p.Debug(print.DebugLevel, "====================")
303+
}

internal/pkg/auth/auth_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func TestAuthenticationConfig(t *testing.T) {
188188
}
189189

190190
reauthorizeUserCalled := false
191-
reauthenticateUser := func(_ *print.Printer, _ bool) error {
191+
reauthenticateUser := func(_ *print.Printer, _ StorageContext, _ bool) error {
192192
if reauthorizeUserCalled {
193193
t.Errorf("user reauthorized more than once")
194194
}

internal/pkg/auth/user_login.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ type apiClient interface {
5050
}
5151

5252
// AuthorizeUser implements the PKCE OAuth2 flow.
53-
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
53+
func AuthorizeUser(p *print.Printer, context StorageContext, isReauthentication bool) error {
54+
// Set the storage printer so debug messages use the correct verbosity
55+
SetStoragePrinter(p)
56+
5457
idpWellKnownConfigURL, err := getIDPWellKnownConfigURL()
5558
if err != nil {
5659
return fmt.Errorf("get IDP well-known configuration: %w", err)
@@ -65,7 +68,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
6568

6669
p.Debug(print.DebugLevel, "get IDP well-known configuration from %s", idpWellKnownConfigURL)
6770
httpClient := &http.Client{}
68-
idpWellKnownConfig, err := parseWellKnownConfiguration(httpClient, idpWellKnownConfigURL)
71+
idpWellKnownConfig, err := parseWellKnownConfiguration(p, httpClient, idpWellKnownConfigURL, context)
6972
if err != nil {
7073
return fmt.Errorf("parse IDP well-known configuration: %w", err)
7174
}
@@ -159,29 +162,30 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
159162
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")
160163

161164
// Trade the authorization code and the code verifier for access and refresh tokens
162-
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
165+
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(p, idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
163166
if err != nil {
164167
errServer = fmt.Errorf("retrieve tokens: %w", err)
165168
return
166169
}
167170

168171
p.Debug(print.DebugLevel, "received response from the authentication server")
169172

170-
sessionExpiresAtUnix, err := getStartingSessionExpiresAtUnix()
173+
// Get access token expiration from the token itself (not session time limit)
174+
sessionExpiresAtUnix, err := getAccessTokenExpiresAtUnix(accessToken)
171175
if err != nil {
172-
errServer = fmt.Errorf("compute session expiration timestamp: %w", err)
176+
errServer = fmt.Errorf("get access token expiration: %w", err)
173177
return
174178
}
175179

176180
sessionExpiresAtUnixInt, err := strconv.Atoi(sessionExpiresAtUnix)
177181
if err != nil {
178-
p.Debug(print.ErrorLevel, "parse session expiration value \"%s\": %s", sessionExpiresAtUnix, err)
182+
p.Debug(print.ErrorLevel, "parse access token expiration value \"%s\": %s", sessionExpiresAtUnix, err)
179183
} else {
180184
sessionExpiresAt := time.Unix(int64(sessionExpiresAtUnixInt), 0)
181-
p.Debug(print.DebugLevel, "session expires at %s", sessionExpiresAt)
185+
p.Debug(print.DebugLevel, "access token expires at %s", sessionExpiresAt)
182186
}
183187

184-
err = SetAuthFlow(AUTH_FLOW_USER_TOKEN)
188+
err = SetAuthFlowWithContext(context, AUTH_FLOW_USER_TOKEN)
185189
if err != nil {
186190
errServer = fmt.Errorf("set auth flow type: %w", err)
187191
return
@@ -195,7 +199,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
195199

196200
p.Debug(print.DebugLevel, "user %s logged in successfully", email)
197201

198-
err = LoginUser(email, accessToken, refreshToken, sessionExpiresAtUnix)
202+
err = LoginUserWithContext(context, email, accessToken, refreshToken, sessionExpiresAtUnix)
199203
if err != nil {
200204
errServer = fmt.Errorf("set in auth storage: %w", err)
201205
return
@@ -211,7 +215,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
211215
mux.HandleFunc(loginSuccessPath, func(w http.ResponseWriter, _ *http.Request) {
212216
defer cleanup(server)
213217

214-
email, err := GetAuthField(USER_EMAIL)
218+
email, err := GetAuthFieldWithContext(context, USER_EMAIL)
215219
if err != nil {
216220
errServer = fmt.Errorf("read user email: %w", err)
217221
}
@@ -265,7 +269,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
265269
}
266270

267271
// getUserAccessAndRefreshTokens trades the authorization code retrieved from the first OAuth2 leg for an access token and a refresh token
268-
func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
272+
func getUserAccessAndRefreshTokens(p *print.Printer, idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
269273
// Set form-encoded data for the POST to the access token endpoint
270274
data := fmt.Sprintf(
271275
"grant_type=authorization_code&client_id=%s"+
@@ -278,6 +282,10 @@ func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID
278282
// Create the request and execute it
279283
req, _ := http.NewRequest("POST", idpWellKnownConfig.TokenEndpoint, payload)
280284
req.Header.Add("content-type", "application/x-www-form-urlencoded")
285+
286+
// Debug log the request
287+
debugHTTPRequest(p, req)
288+
281289
httpClient := &http.Client{}
282290
res, err := httpClient.Do(req)
283291
if err != nil {
@@ -291,6 +299,10 @@ func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID
291299
err = fmt.Errorf("close response body: %w", closeErr)
292300
}
293301
}()
302+
303+
// Debug log the response
304+
debugHTTPResponse(p, res)
305+
294306
body, err := io.ReadAll(res.Body)
295307
if err != nil {
296308
return "", "", fmt.Errorf("read response body: %w", err)
@@ -350,8 +362,12 @@ func openBrowser(pageUrl string) error {
350362

351363
// parseWellKnownConfiguration gets the well-known OpenID configuration from the provided URL and returns it as a JSON
352364
// the method also stores the IDP token endpoint in the authentication storage
353-
func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string) (wellKnownConfig *wellKnownConfig, err error) {
365+
func parseWellKnownConfiguration(p *print.Printer, httpClient apiClient, wellKnownConfigURL string, context StorageContext) (wellKnownConfig *wellKnownConfig, err error) {
354366
req, _ := http.NewRequest("GET", wellKnownConfigURL, http.NoBody)
367+
368+
// Debug log the request
369+
debugHTTPRequest(p, req)
370+
355371
res, err := httpClient.Do(req)
356372
if err != nil {
357373
return nil, fmt.Errorf("make the request: %w", err)
@@ -364,6 +380,10 @@ func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string
364380
err = fmt.Errorf("close response body: %w", closeErr)
365381
}
366382
}()
383+
384+
// Debug log the response
385+
debugHTTPResponse(p, res)
386+
367387
body, err := io.ReadAll(res.Body)
368388
if err != nil {
369389
return nil, fmt.Errorf("read response body: %w", err)
@@ -386,7 +406,7 @@ func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string
386406
return nil, fmt.Errorf("found no token endpoint")
387407
}
388408

389-
err = SetAuthField(IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
409+
err = SetAuthFieldWithContext(context, IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
390410
if err != nil {
391411
return nil, fmt.Errorf("set token endpoint in the authentication storage: %w", err)
392412
}

internal/pkg/auth/user_login_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99

1010
"github.com/google/go-cmp/cmp"
11+
"github.com/stackitcloud/stackit-cli/internal/pkg/print"
1112
"github.com/zalando/go-keyring"
1213
)
1314

@@ -93,7 +94,9 @@ func TestParseWellKnownConfig(t *testing.T) {
9394
tt.getResponse,
9495
}
9596

96-
got, err := parseWellKnownConfiguration(&testClient, "")
97+
p := print.NewPrinter()
98+
99+
got, err := parseWellKnownConfiguration(p, &testClient, "", StorageContextCLI)
97100

98101
if tt.isValid && err != nil {
99102
t.Fatalf("expected no error, got %v", err)

0 commit comments

Comments
 (0)