diff --git a/pkg/auth/monitored_token_source.go b/pkg/auth/monitored_token_source.go index 5f3ed4c431..b0d075302c 100644 --- a/pkg/auth/monitored_token_source.go +++ b/pkg/auth/monitored_token_source.go @@ -11,6 +11,7 @@ import ( "net" "os" "strconv" + "strings" "sync" "syscall" "time" @@ -173,10 +174,11 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} { return mts.stopped } -// Token retrieves a token, retrying with exponential backoff on transient network -// errors (DNS failures, TCP errors). On non-transient errors (OAuth 4xx, TLS failures) -// it marks the workload as unauthenticated and returns immediately. Context cancellation -// (workload removal) stops the retry without marking the workload as unauthenticated. +// Token retrieves a token, retrying with exponential backoff on transient errors +// (see isTransientNetworkError for the full list). On non-transient errors +// (OAuth 4xx, TLS failures) it marks the workload as unauthenticated and returns +// immediately. Context cancellation (workload removal) stops the retry without +// marking the workload as unauthenticated. // // Concurrent callers are deduplicated via singleflight so that only one retry // loop runs at a time during transient failures. @@ -322,20 +324,39 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) { return false, wait } -// isTransientNetworkError reports whether err represents a transient network condition -// (DNS failure, TCP transport error, timeout) that is likely to resolve when the network -// recovers — for example, after a VPN reconnects. +// isTransientNetworkError reports whether err represents a transient condition +// (DNS failure, TCP transport error, timeout, OAuth server 5xx, unparsable +// token response) that is likely to resolve on its own. // -// OAuth2 HTTP-level auth failures (invalid_grant, 401, 400) and TLS errors +// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors // (certificate verification, handshake failure) are NOT considered transient and // return false so the workload is marked unauthenticated immediately. func isTransientNetworkError(err error) bool { if err == nil || - errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || - errors.As(err, new(*oauth2.RetrieveError)) { + errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } + // OAuth HTTP-level errors: 5xx (Bad Gateway, Service Unavailable, Gateway + // Timeout) are transient server-side issues that typically resolve on their + // own. 4xx errors (invalid_grant, invalid_client) are permanent auth failures. + if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok { + if retrieveErr.Response != nil && retrieveErr.Response.StatusCode >= 500 { + slog.Debug("treating OAuth server error as transient", + "status_code", retrieveErr.Response.StatusCode, + ) + return true + } + return false + } + + // Non-JSON responses from the OAuth server (e.g. load balancer HTML pages). + // The oauth2 library returns a plain error (not *RetrieveError) when the + // HTTP status is 2xx but the body cannot be parsed as JSON. + if isOAuthParseError(err) { + return true + } + // DNS lookup failures — covers VPN-disconnect scenarios where the corporate DNS // resolver is unreachable. if _, ok := errors.AsType[*net.DNSError](err); ok { @@ -360,6 +381,21 @@ func isTransientNetworkError(err error) bool { return false } +// isOAuthParseError detects errors from the oauth2 library that indicate the +// token endpoint returned an unparsable response body on a 2xx status. This +// typically happens when a load balancer, CDN, or reverse proxy intercepts the +// request and returns its own HTML page instead of the expected JSON token +// response. The oauth2 library uses fmt.Errorf with %v (not %w) for these +// errors, so string matching is the only reliable detection method. +func isOAuthParseError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "oauth2: cannot parse json") || + strings.Contains(msg, "oauth2: cannot parse response") +} + // markAsUnauthenticated marks the workload as unauthenticated and stops background monitoring. func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string) { _ = mts.statusUpdater.SetWorkloadStatus( diff --git a/pkg/auth/monitored_token_source_test.go b/pkg/auth/monitored_token_source_test.go index 342082820a..9a2fa527f0 100644 --- a/pkg/auth/monitored_token_source_test.go +++ b/pkg/auth/monitored_token_source_test.go @@ -6,6 +6,7 @@ package auth import ( "context" "errors" + "fmt" "net" "net/http" "net/url" @@ -561,11 +562,21 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T {name: "context.DeadlineExceeded", err: context.DeadlineExceeded, isTransient: false}, {name: "oauth2.RetrieveError 401", err: createRetrieveError(http.StatusUnauthorized, "unauthorized"), isTransient: false}, {name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveError(http.StatusBadRequest, "invalid_grant"), isTransient: false}, + {name: "oauth2.RetrieveError nil response", err: &oauth2.RetrieveError{}, isTransient: false}, // Transient: network-level errors must be retried. {name: "*net.DNSError timeout", err: &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}, isTransient: true}, {name: "*net.OpError connection refused", err: &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}, isTransient: true}, {name: "*url.Error wrapping *net.OpError", err: &url.Error{Op: "Post", URL: "https://example.com/token", Err: &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}}, isTransient: true}, {name: "net.Error timeout", err: &timeoutNetError{}, isTransient: true}, + // Transient: OAuth server 5xx errors (load balancer, server restart). + {name: "oauth2.RetrieveError 500", err: createRetrieveError(http.StatusInternalServerError, "Internal Server Error"), isTransient: true}, + {name: "oauth2.RetrieveError 502", err: createRetrieveError(http.StatusBadGateway, "Bad Gateway"), isTransient: true}, + {name: "oauth2.RetrieveError 503", err: createRetrieveError(http.StatusServiceUnavailable, "Service Unavailable"), isTransient: true}, + {name: "oauth2.RetrieveError 504", err: createRetrieveError(http.StatusGatewayTimeout, "Gateway Timeout"), isTransient: true}, + // Transient: unparsable OAuth responses (HTML from load balancer on 200). + {name: "oauth2 cannot parse json", err: fmt.Errorf("oauth2: cannot parse json: invalid character '<'"), isTransient: true}, + {name: "wrapped oauth2 parse error", err: fmt.Errorf("refresh failed: %w", fmt.Errorf("oauth2: cannot parse json: invalid character '<'")), isTransient: true}, + {name: "oauth2 cannot parse response", err: fmt.Errorf("oauth2: cannot parse response: invalid URL escape"), isTransient: true}, } for _, tt := range tests {