From e402723921f10d8330df3e487bde906ef883aa11 Mon Sep 17 00:00:00 2001 From: Benjosh95 Date: Wed, 9 Jul 2025 09:49:10 +0200 Subject: [PATCH] add refreshing access token functionality to get-access-token command --- .../auth/get-access-token/get_access_token.go | 11 +-- internal/pkg/auth/auth.go | 69 +++++++++++++++++++ 2 files changed, 71 insertions(+), 9 deletions(-) diff --git a/internal/cmd/auth/get-access-token/get_access_token.go b/internal/cmd/auth/get-access-token/get_access_token.go index 0b13fd2fa..25576427a 100644 --- a/internal/cmd/auth/get-access-token/get_access_token.go +++ b/internal/cmd/auth/get-access-token/get_access_token.go @@ -29,19 +29,12 @@ func NewCmd(params *params.CmdParams) *cobra.Command { return &cliErr.SessionExpiredError{} } - accessToken, err := auth.GetAccessToken() + // Try to get a valid access token, refreshing if necessary + accessToken, err := auth.RefreshAccessToken(params.Printer) if err != nil { return err } - accessTokenExpired, err := auth.TokenExpired(accessToken) - if err != nil { - return err - } - if accessTokenExpired { - return &cliErr.AccessTokenExpiredError{} - } - params.Printer.Outputf("%s\n", accessToken) return nil }, diff --git a/internal/pkg/auth/auth.go b/internal/pkg/auth/auth.go index 89a39ac29..634813f24 100644 --- a/internal/pkg/auth/auth.go +++ b/internal/pkg/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "fmt" + "net/http" "os" "strconv" "time" @@ -132,3 +133,71 @@ func getEmailFromToken(token string) (string, error) { return claims.Email, nil } + +// RefreshAccessToken refreshes the access token if it's expired for the user token flow. +// It returns the new access token or an error if the refresh fails. +func RefreshAccessToken(p *print.Printer) (string, error) { + flow, err := GetAuthFlow() + if err != nil { + return "", fmt.Errorf("get authentication flow: %w", err) + } + if flow != AUTH_FLOW_USER_TOKEN { + return "", fmt.Errorf("token refresh is only supported for user token flow, current flow: %s", flow) + } + + // Load tokens from storage + authFields := map[authFieldKey]string{ + ACCESS_TOKEN: "", + REFRESH_TOKEN: "", + IDP_TOKEN_ENDPOINT: "", + } + err = GetAuthFieldMap(authFields) + if err != nil { + return "", fmt.Errorf("get tokens from auth storage: %w", err) + } + + accessToken := authFields[ACCESS_TOKEN] + refreshToken := authFields[REFRESH_TOKEN] + tokenEndpoint := authFields[IDP_TOKEN_ENDPOINT] + + if accessToken == "" { + return "", fmt.Errorf("access token not set") + } + if refreshToken == "" { + return "", fmt.Errorf("refresh token not set") + } + if tokenEndpoint == "" { + return "", fmt.Errorf("token endpoint not set") + } + + // Check if access token is expired + accessTokenExpired, err := TokenExpired(accessToken) + if err != nil { + return "", fmt.Errorf("check if access token has expired: %w", err) + } + if !accessTokenExpired { + // Token is still valid, return it + return accessToken, nil + } + + p.Debug(print.DebugLevel, "access token expired, refreshing...") + + // Create a temporary userTokenFlow to reuse the refresh logic + utf := &userTokenFlow{ + printer: p, + client: &http.Client{}, + authFlow: flow, + accessToken: accessToken, + refreshToken: refreshToken, + tokenEndpoint: tokenEndpoint, + } + + // Refresh the tokens + err = refreshTokens(utf) + if err != nil { + return "", fmt.Errorf("refresh access token: %w", err) + } + + // Return the new access token + return utf.accessToken, nil +}