Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions pkg/auth/monitored_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"
Expand Down Expand Up @@ -173,10 +174,11 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} {
return mts.stopped
}

// Token retrieves a token, retrying with exponential backoff on transient network
// errors (DNS failures, TCP errors). On non-transient errors (OAuth 4xx, TLS failures)
// it marks the workload as unauthenticated and returns immediately. Context cancellation
// (workload removal) stops the retry without marking the workload as unauthenticated.
// Token retrieves a token, retrying with exponential backoff on transient errors
// (see isTransientNetworkError for the full list). On non-transient errors
// (OAuth 4xx, TLS failures) it marks the workload as unauthenticated and returns
// immediately. Context cancellation (workload removal) stops the retry without
// marking the workload as unauthenticated.
//
// Concurrent callers are deduplicated via singleflight so that only one retry
// loop runs at a time during transient failures.
Expand Down Expand Up @@ -322,20 +324,39 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
return false, wait
}

// isTransientNetworkError reports whether err represents a transient network condition
// (DNS failure, TCP transport error, timeout) that is likely to resolve when the network
// recovers — for example, after a VPN reconnects.
// isTransientNetworkError reports whether err represents a transient condition
// (DNS failure, TCP transport error, timeout, OAuth server 5xx, unparsable
// token response) that is likely to resolve on its own.
//
// OAuth2 HTTP-level auth failures (invalid_grant, 401, 400) and TLS errors
// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors
// (certificate verification, handshake failure) are NOT considered transient and
// return false so the workload is marked unauthenticated immediately.
func isTransientNetworkError(err error) bool {
if err == nil ||
errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) ||
errors.As(err, new(*oauth2.RetrieveError)) {
errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}

// OAuth HTTP-level errors: 5xx (Bad Gateway, Service Unavailable, Gateway
// Timeout) are transient server-side issues that typically resolve on their
// own. Every other RetrieveError is treated as permanent: 4xx responses
// (invalid_grant, invalid_client) as well as errors that carry no HTTP
// response at all.
if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok {
if retrieveErr.Response != nil && retrieveErr.Response.StatusCode >= 500 {
slog.Debug("treating OAuth server error as transient",
"status_code", retrieveErr.Response.StatusCode,
)
return true
}
return false
}

// Non-JSON responses from the OAuth server (e.g. load balancer HTML pages).
// The oauth2 library returns a plain error (not *RetrieveError) when the
// HTTP status is 2xx but the body cannot be parsed as JSON.
if isOAuthParseError(err) {
return true
}

// DNS lookup failures — covers VPN-disconnect scenarios where the corporate DNS
// resolver is unreachable.
if _, ok := errors.AsType[*net.DNSError](err); ok {
Expand All @@ -360,6 +381,21 @@ func isTransientNetworkError(err error) bool {
return false
}

// isOAuthParseError reports whether err looks like the oauth2 library failing
// to decode the body returned by the token endpoint on a 2xx status. This
// typically means something in front of the OAuth server (load balancer, CDN,
// reverse proxy) answered with its own HTML page instead of the expected
// token payload.
//
// The oauth2 library builds these errors with fmt.Errorf and %v rather than
// %w, so there is no typed or sentinel error to unwrap; the rendered message
// is searched for the library's known prefixes instead. Because the full
// message is inspected, errors that wrap a parse error are also matched.
// A nil err is never a parse error.
func isOAuthParseError(err error) bool {
	if err == nil {
		return false
	}

	// Message fragments emitted by the oauth2 library for an undecodable body.
	markers := [...]string{
		"oauth2: cannot parse json",
		"oauth2: cannot parse response",
	}

	text := err.Error()
	for _, marker := range markers {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}

// markAsUnauthenticated marks the workload as unauthenticated and stops background monitoring.
func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string) {
_ = mts.statusUpdater.SetWorkloadStatus(
Expand Down
11 changes: 11 additions & 0 deletions pkg/auth/monitored_token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package auth
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -561,11 +562,21 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
{name: "context.DeadlineExceeded", err: context.DeadlineExceeded, isTransient: false},
{name: "oauth2.RetrieveError 401", err: createRetrieveError(http.StatusUnauthorized, "unauthorized"), isTransient: false},
{name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveError(http.StatusBadRequest, "invalid_grant"), isTransient: false},
{name: "oauth2.RetrieveError nil response", err: &oauth2.RetrieveError{}, isTransient: false},
// Transient: network-level errors must be retried.
{name: "*net.DNSError timeout", err: &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}, isTransient: true},
{name: "*net.OpError connection refused", err: &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}, isTransient: true},
{name: "*url.Error wrapping *net.OpError", err: &url.Error{Op: "Post", URL: "https://example.com/token", Err: &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}}, isTransient: true},
{name: "net.Error timeout", err: &timeoutNetError{}, isTransient: true},
// Transient: OAuth server 5xx errors (load balancer, server restart).
{name: "oauth2.RetrieveError 500", err: createRetrieveError(http.StatusInternalServerError, "Internal Server Error"), isTransient: true},
{name: "oauth2.RetrieveError 502", err: createRetrieveError(http.StatusBadGateway, "Bad Gateway"), isTransient: true},
{name: "oauth2.RetrieveError 503", err: createRetrieveError(http.StatusServiceUnavailable, "Service Unavailable"), isTransient: true},
{name: "oauth2.RetrieveError 504", err: createRetrieveError(http.StatusGatewayTimeout, "Gateway Timeout"), isTransient: true},
// Transient: unparsable OAuth responses (HTML from load balancer on 200).
{name: "oauth2 cannot parse json", err: fmt.Errorf("oauth2: cannot parse json: invalid character '<'"), isTransient: true},
{name: "wrapped oauth2 parse error", err: fmt.Errorf("refresh failed: %w", fmt.Errorf("oauth2: cannot parse json: invalid character '<'")), isTransient: true},
{name: "oauth2 cannot parse response", err: fmt.Errorf("oauth2: cannot parse response: invalid URL escape"), isTransient: true},
}

for _, tt := range tests {
Expand Down
Loading