Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions internal/cmd/auth/get-access-token/get_access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,12 @@ func NewCmd(params *params.CmdParams) *cobra.Command {
return &cliErr.SessionExpiredError{}
}

accessToken, err := auth.GetAccessToken()
// Try to get a valid access token, refreshing if necessary
accessToken, err := auth.RefreshAccessToken(params.Printer)
if err != nil {
return err
}

accessTokenExpired, err := auth.TokenExpired(accessToken)
if err != nil {
return err
}
if accessTokenExpired {
return &cliErr.AccessTokenExpiredError{}
}

params.Printer.Outputf("%s\n", accessToken)
return nil
},
Expand Down
69 changes: 69 additions & 0 deletions internal/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"fmt"
"net/http"
"os"
"strconv"
"time"
Expand Down Expand Up @@ -132,3 +133,71 @@ func getEmailFromToken(token string) (string, error) {

return claims.Email, nil
}

// RefreshAccessToken refreshes the access token if it's expired for the user token flow.
// It returns the new access token or an error if the refresh fails.
func RefreshAccessToken(p *print.Printer) (string, error) {
flow, err := GetAuthFlow()
if err != nil {
return "", fmt.Errorf("get authentication flow: %w", err)
}
if flow != AUTH_FLOW_USER_TOKEN {
return "", fmt.Errorf("token refresh is only supported for user token flow, current flow: %s", flow)
}

// Load tokens from storage
authFields := map[authFieldKey]string{
ACCESS_TOKEN: "",
REFRESH_TOKEN: "",
IDP_TOKEN_ENDPOINT: "",
}
err = GetAuthFieldMap(authFields)
if err != nil {
return "", fmt.Errorf("get tokens from auth storage: %w", err)
}

accessToken := authFields[ACCESS_TOKEN]
refreshToken := authFields[REFRESH_TOKEN]
tokenEndpoint := authFields[IDP_TOKEN_ENDPOINT]

if accessToken == "" {
return "", fmt.Errorf("access token not set")
}
if refreshToken == "" {
return "", fmt.Errorf("refresh token not set")
}
if tokenEndpoint == "" {
return "", fmt.Errorf("token endpoint not set")
}

// Check if access token is expired
accessTokenExpired, err := TokenExpired(accessToken)
if err != nil {
return "", fmt.Errorf("check if access token has expired: %w", err)
}
if !accessTokenExpired {
// Token is still valid, return it
return accessToken, nil
}

p.Debug(print.DebugLevel, "access token expired, refreshing...")

// Create a temporary userTokenFlow to reuse the refresh logic
utf := &userTokenFlow{
printer: p,
client: &http.Client{},
authFlow: flow,
accessToken: accessToken,
refreshToken: refreshToken,
tokenEndpoint: tokenEndpoint,
}

// Refresh the tokens
err = refreshTokens(utf)
if err != nil {
return "", fmt.Errorf("refresh access token: %w", err)
}

// Return the new access token
return utf.accessToken, nil
}
Loading