diff --git a/CHANGELOG.md b/CHANGELOG.md index a060f878..2f52c6a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to Authorizer will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- **`--rate-limit-fail-closed`**: when the rate-limit backend returns an error, respond with `503` instead of allowing the request (default remains fail-open). +- **`--metrics-host`**: bind address for the dedicated `/metrics` listener (default `127.0.0.1`). Use `0.0.0.0` when a scraper on another host/pod must reach the metrics port over the network; keep the metrics port off public ingress. + +### Changed + +- **Prometheus `/metrics`**: always served on a **dedicated** HTTP listener (`--metrics-host`:`--metrics-port`, default `127.0.0.1:8081`). **`--http-port` and `--metrics-port` must differ**; `/metrics` is not registered on the main Gin server. +- **HTTP metrics**: unmatched Gin routes use the fixed path label `unmatched` instead of the raw request URL (prevents cardinality attacks). +- **GraphQL metrics**: the `operation` label is now `anonymous` or `op_` so client-supplied operation names cannot explode time-series cardinality. +- **Health/readiness JSON**: failure responses return a generic `error` string; details remain in server logs. +- **OAuth callback JSON**: generic OAuth-style error body on provider processing failure; details remain in logs. +- **`/playground`** is subject to the same per-IP rate limits as other routes (health and OIDC discovery paths stay exempt). **`/metrics`** is not on the main HTTP router. + +### Removed + +- **`authorizer_client_id_not_found_total`**: replaced by **`authorizer_client_id_header_missing_total`**, which matches the actual behavior (header omitted, request still allowed). Update dashboards and alerts accordingly. + ## [2.0.0] - 2025-02-28 ### Added diff --git a/Dockerfile b/Dockerfile index 0bdcf282..9e061416 100644 --- a/Dockerfile +++ b/Dockerfile @@ -69,7 +69,18 @@ RUN addgroup -g 1000 authorizer && \ USER authorizer +# Ports (see docs: deployment/docker, deployment/kubernetes) +# - EXPOSE is documentation only: it does NOT publish ports on the Docker host. +# - 8080: main HTTP API (OAuth, GraphQL, health on /healthz, etc.). This is what you +# typically map with -p 8080:8080 or put behind an Ingress / load balancer. +# - 8081: dedicated Prometheus /metrics listener. By default the process binds it to +# 127.0.0.1, so other containers cannot scrape until you pass --metrics-host=0.0.0.0. +# Even then: do not map 8081 to the public internet; keep scraping on internal networks +# only (Docker internal network, Kubernetes ClusterIP / pod network). EXPOSE 8080 8081 -HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD wget -qO- http://localhost:8080/ || exit 1 + +# Liveness uses the main HTTP server only (metrics may be loopback-only). +HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD wget -qO- http://127.0.0.1:8080/healthz || exit 1 + ENTRYPOINT [ "./authorizer" ] CMD [] diff --git a/MIGRATION.md b/MIGRATION.md index 5c5db1ef..e8df1dca 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -205,8 +205,13 @@ Use these v2 **CLI flags** instead of v1 env or dashboard config. Flag names use | `PORT` | `--http-port` (default: 8080) | | Host | `--host` (default: 0.0.0.0) | | Metrics port | `--metrics-port` (default: 8081) | +| Metrics bind | `--metrics-host` (default: `127.0.0.1`) for the dedicated metrics listener only | | `LOG_LEVEL` | `--log-level` | +**Metrics:** `GET /metrics` is **always** on a **separate** minimal HTTP server at **`--metrics-host`:`--metrics-port`** (default **`127.0.0.1:8081`**). **`--http-port` and `--metrics-port` must differ**; the main Gin server does not expose `/metrics`. Use `--metrics-host=0.0.0.0` when Prometheus scrapes from another container or pod (keep the metrics port off public load balancers). + +**Rate limiting:** `--rate-limit-fail-closed` rejects requests with `503` when the rate-limit backend errors; the default remains fail-open (allow) for availability. + ### Database diff --git a/ROADMAP_V2.md b/ROADMAP_V2.md index 89e12738..3e365b0d 100644 --- a/ROADMAP_V2.md +++ b/ROADMAP_V2.md @@ -97,21 +97,13 @@ These are table-stakes features that every competitor has. Without them, Authori **Why**: Keycloak has full Prometheus/Grafana support. Essential for production deployments. -- [ ] **`/metrics` endpoint** (OpenMetrics/Prometheus format) - - `authorizer_login_total{method,status}` -- login attempts by method and success/failure - - `authorizer_signup_total{method,status}` -- signup attempts - - `authorizer_token_issued_total{type}` -- tokens issued by type - - `authorizer_active_sessions` -- current active sessions gauge - - `authorizer_request_duration_seconds{endpoint,method}` -- request latency histogram - - `authorizer_db_query_duration_seconds` -- database query latency - - `authorizer_failed_login_total` -- failed logins (for alerting) - - `authorizer_account_lockouts_total` -- lockout events - - Go runtime metrics (goroutines, memory, GC) +- [x] **`/metrics` endpoint** (OpenMetrics/Prometheus format) — implemented (`authorizer_*` metrics; always on dedicated `--metrics-host`:`--metrics-port`). Further metric parity (below) remains roadmap. + - Planned / partial vs Keycloak-style names: `authorizer_login_total{method,status}`, `authorizer_signup_total{method,status}`, `authorizer_token_issued_total{type}`, `authorizer_db_query_duration_seconds`, `authorizer_failed_login_total`, `authorizer_account_lockouts_total`, Go runtime metrics (goroutines, memory, GC) - [ ] **Enhanced `/health` endpoint** returning JSON with component status ```json {"status": "healthy", "db": "ok", "redis": "ok", "uptime": "72h"} ``` -- [ ] **Readiness/liveness probes** (`/healthz`, `/readyz`) for Kubernetes +- [x] **Readiness/liveness probes** (`/healthz`, `/readyz`, `/health`) for Kubernetes ### 1.5 Session Security Enhancements diff --git a/cmd/root.go b/cmd/root.go index b8aaef3a..a00dcb00 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "fmt" "os" "os/signal" "strings" @@ -32,6 +33,7 @@ import ( // Default values for flags (single source of truth for init and applyFlagDefaults). var ( defaultHost = "0.0.0.0" + defaultMetricsHost = "127.0.0.1" defaultLogLevel = "debug" defaultHTTPPort = 8080 defaultMetricsPort = 8081 @@ -52,8 +54,9 @@ var ( defaultDiscordScopes = []string{"identify", "email"} defaultTwitterScopes = []string{"tweet.read", "users.read"} defaultRobloxScopes = []string{"openid", "profile"} - defaultRateLimitRPS = float64(10) - defaultRateLimitBurst = 20 + // Default RPS cap per IP; raised from 10 to reduce false positives on busy UIs. + defaultRateLimitRPS = float64(30) + defaultRateLimitBurst = 20 ) var ( @@ -74,7 +77,8 @@ func init() { // Server flags f.StringVar(&rootArgs.server.Host, "host", defaultHost, "Host address to listen on") f.IntVar(&rootArgs.server.HTTPPort, "http-port", defaultHTTPPort, "Port to serve HTTP requests on") - f.IntVar(&rootArgs.server.MetricsPort, "metrics-port", defaultMetricsPort, "Port to serve metrics requests on") + f.IntVar(&rootArgs.server.MetricsPort, "metrics-port", defaultMetricsPort, "Port for the dedicated /metrics listener (must differ from --http-port)") + f.StringVar(&rootArgs.server.MetricsHost, "metrics-host", defaultMetricsHost, "Bind address for the dedicated /metrics listener (default loopback; use 0.0.0.0 when Prometheus scrapes from another host/pod)") // Logging flags f.StringVar(&rootArgs.logLevel, "log-level", defaultLogLevel, "Log level to use") @@ -159,6 +163,7 @@ func init() { // Rate limiting flags f.Float64Var(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting") f.IntVar(&rootArgs.config.RateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, "Maximum burst size per IP for rate limiting") + f.BoolVar(&rootArgs.config.RateLimitFailClosed, "rate-limit-fail-closed", false, "On rate-limit backend errors, reject with 503 instead of allowing the request") // JWT flags f.StringVar(&rootArgs.config.JWTType, "jwt-type", "", "Type of JWT to use") @@ -230,6 +235,9 @@ func applyFlagDefaults() { if s.MetricsPort == 0 { s.MetricsPort = defaultMetricsPort } + if strings.TrimSpace(s.MetricsHost) == "" { + s.MetricsHost = defaultMetricsHost + } if strings.TrimSpace(rootArgs.logLevel) == "" { rootArgs.logLevel = defaultLogLevel } @@ -298,6 +306,10 @@ func applyFlagDefaults() { // Run the service func runRoot(c *cobra.Command, args []string) { applyFlagDefaults() + if rootArgs.server.HTTPPort == rootArgs.server.MetricsPort { + fmt.Fprintf(os.Stderr, "invalid server ports: --http-port and --metrics-port must differ (metrics are always served on a dedicated listener)\n") + os.Exit(1) + } // Prepare logger ctx := context.Background() @@ -340,6 +352,11 @@ func runRoot(c *cobra.Command, args []string) { if err != nil { log.Fatal().Err(err).Msg("failed to create storage provider") } + defer func() { + if err := storageProvider.Close(); err != nil { + log.Error().Err(err).Msg("failed to close storage provider") + } + }() // Authenticator provider authenticatorProvider, err := authenticators.New(&rootArgs.config, &authenticators.Dependencies{ diff --git a/internal/config/config.go b/internal/config/config.go index 2eb246a1..77f94795 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,6 +4,10 @@ package config type Config struct { // Env is the environment of the authorizer instance Env string + // SkipTestEndpointSSRFValidation relaxes SSRF checks for the admin TestEndpoint GraphQL + // mutation (e.g. to hit localhost in tests). Must remain false in production; integration + // tests enable it together with Env=test. + SkipTestEndpointSSRFValidation bool // OrganizationLogo is the logo of the organization OrganizationLogo string // OrganizationName is the name of the organization @@ -253,4 +257,6 @@ type Config struct { RateLimitRPS float64 // RateLimitBurst is the maximum burst size per IP RateLimitBurst int + // RateLimitFailClosed rejects requests when the rate limit backend errors (default: fail-open). + RateLimitFailClosed bool } diff --git a/internal/graphql/test_endpoint.go b/internal/graphql/test_endpoint.go index 48f77233..80a90c17 100644 --- a/internal/graphql/test_endpoint.go +++ b/internal/graphql/test_endpoint.go @@ -72,10 +72,14 @@ func (g *graphqlProvider) TestEndpoint(ctx context.Context, params *model.TestEn return nil, err } - // SSRF protection: validate endpoint URL and resolved IPs - if err := validators.ValidateEndpointURL(params.Endpoint); err != nil { - log.Debug().Err(err).Str("endpoint", params.Endpoint).Msg("endpoint URL rejected by SSRF filter") - return nil, fmt.Errorf("invalid endpoint: %s", err.Error()) + // SSRF protection: validate endpoint URL and resolved IPs. Skipped only when tests + // explicitly set SkipTestEndpointSSRFValidation (never enable that flag in production). + skipSSRF := g.Config.Env == constants.TestEnv && g.Config.SkipTestEndpointSSRFValidation + if !skipSSRF { + if err := validators.ValidateEndpointURL(params.Endpoint); err != nil { + log.Debug().Err(err).Str("endpoint", params.Endpoint).Msg("endpoint URL rejected by SSRF filter") + return nil, fmt.Errorf("invalid endpoint: %w", err) + } } req, err := http.NewRequest("POST", params.Endpoint, bytes.NewBuffer(requestBody)) @@ -107,4 +111,3 @@ func (g *graphqlProvider) TestEndpoint(ctx context.Context, params *model.TestEn Response: refs.NewStringRef(string(body)), }, nil } - diff --git a/internal/http_handlers/app.go b/internal/http_handlers/app.go index 4c1bd793..675d8625 100644 --- a/internal/http_handlers/app.go +++ b/internal/http_handlers/app.go @@ -68,6 +68,7 @@ func (h *httpProvider) AppHandler() gin.HandlerFunc { "state": state, "organizationName": orgName, "organizationLogo": orgLogo, + "clientId": h.Config.ClientID, }, }) } diff --git a/internal/http_handlers/client_check.go b/internal/http_handlers/client_check.go index f92b4f1c..e38c8eaa 100644 --- a/internal/http_handlers/client_check.go +++ b/internal/http_handlers/client_check.go @@ -2,8 +2,11 @@ package http_handlers import ( "net/http" + "strings" "github.com/gin-gonic/gin" + + "github.com/authorizerdev/authorizer/internal/metrics" ) // ClientCheckMiddleware is a middleware to verify the client ID @@ -12,16 +15,25 @@ import ( // (e.g., OAuth callbacks, JWKS, OpenID configuration, health checks). // The middleware only rejects requests with an explicitly wrong client ID. func (h *httpProvider) ClientCheckMiddleware() gin.HandlerFunc { - log := h.Log.With().Str("func", "ClientCheckMiddleware").Logger() return func(c *gin.Context) { + log := h.Log.With().Str("func", "ClientCheckMiddleware"). + Str("path", c.Request.URL.Path). + Logger() clientID := c.Request.Header.Get("X-Authorizer-Client-ID") if clientID == "" { log.Debug().Msg("request received without client ID header") + metrics.RecordClientIDHeaderMissing() c.Next() return } if clientID != h.Config.ClientID { + // Record metric for client-id mismatch, but skip dashboard and app UI routes + // as those are internal requests that should not trigger security alerts. + path := c.Request.URL.Path + if !strings.HasPrefix(path, "/dashboard") && !strings.HasPrefix(path, "/app") { + metrics.RecordSecurityEvent("client_id_mismatch", "invalid_client_id") + } log.Debug().Str("client_id", clientID).Msg("Client ID is invalid") c.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_client_id", diff --git a/internal/http_handlers/context.go b/internal/http_handlers/context.go index 3c4ca263..08867459 100644 --- a/internal/http_handlers/context.go +++ b/internal/http_handlers/context.go @@ -1,20 +1,15 @@ package http_handlers import ( - "context" - "github.com/gin-gonic/gin" -) -// Define a custom type for context key -type contextKey string - -const ginContextKey contextKey = "GinContextKey" + "github.com/authorizerdev/authorizer/internal/utils" +) // ContextMiddleware is a middleware to add gin context in context func (h *httpProvider) ContextMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - ctx := context.WithValue(c.Request.Context(), ginContextKey, c) + ctx := utils.ContextWithGin(c.Request.Context(), c) c.Request = c.Request.WithContext(ctx) c.Next() } diff --git a/internal/http_handlers/graphql.go b/internal/http_handlers/graphql.go index b8f3dc28..ffca64a7 100644 --- a/internal/http_handlers/graphql.go +++ b/internal/http_handlers/graphql.go @@ -3,6 +3,8 @@ package http_handlers import ( "context" "net/http" + "sort" + "sync" "time" gql "github.com/99designs/gqlgen/graphql" @@ -17,22 +19,48 @@ import ( "github.com/authorizerdev/authorizer/internal/graph/generated" "github.com/authorizerdev/authorizer/internal/graphql" "github.com/authorizerdev/authorizer/internal/metrics" + "github.com/authorizerdev/authorizer/internal/utils" ) -func (h *httpProvider) gqlLoggingMiddleware() gql.FieldMiddleware { - return func(ctx context.Context, next gql.Resolver) (res interface{}, err error) { - // Get details of the current operation - oc := gql.GetOperationContext(ctx) - field := gql.GetFieldContext(ctx) +type gqlResolvedFieldsCtxKey struct{} - // Log operation details - h.Dependencies.Log.Info(). - Str("operation", oc.OperationName). - Str("query", field.Field.Name). - // Interface("arguments", field.Args). // Enable only for debugging purpose else sensitive data will be logged - Msg("GraphQL field resolved") +// resolvedFieldsCollector gathers unique GraphQL field names for one operation. +type resolvedFieldsCollector struct { + mu sync.Mutex + fields map[string]struct{} +} - // Call the next resolver +func (c *resolvedFieldsCollector) add(name string) { + if name == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + if c.fields == nil { + c.fields = make(map[string]struct{}) + } + c.fields[name] = struct{}{} +} + +func (c *resolvedFieldsCollector) sortedUnique() []string { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]string, 0, len(c.fields)) + for f := range c.fields { + out = append(out, f) + } + sort.Strings(out) + return out +} + +// gqlCollectResolvedFieldsMiddleware records each resolved field name into the per-operation collector. +func (*httpProvider) gqlCollectResolvedFieldsMiddleware() gql.FieldMiddleware { + return func(ctx context.Context, next gql.Resolver) (interface{}, error) { + if col, ok := ctx.Value(gqlResolvedFieldsCtxKey{}).(*resolvedFieldsCollector); ok && col != nil { + if fc := gql.GetFieldContext(ctx); fc != nil && fc.Field.Field != nil { + col.add(fc.Field.Name) + } + } return next(ctx) } } @@ -41,31 +69,50 @@ func (h *httpProvider) gqlLoggingMiddleware() gql.FieldMiddleware { // It captures errors returned in HTTP 200 responses (GraphQL convention). func (h *httpProvider) gqlMetricsMiddleware() gql.OperationMiddleware { return func(ctx context.Context, next gql.OperationHandler) gql.ResponseHandler { - oc := gql.GetOperationContext(ctx) - operationName := oc.OperationName - if operationName == "" { - operationName = "anonymous" + operationName := "" + if oc := gql.GetOperationContext(ctx); oc != nil { + operationName = oc.OperationName } + opMetricLabel := metrics.GraphQLOperationPrometheusLabel(operationName) start := time.Now() + collector := &resolvedFieldsCollector{} + ctx = context.WithValue(ctx, gqlResolvedFieldsCtxKey{}, collector) + responseHandler := next(ctx) return func(ctx context.Context) *gql.Response { resp := responseHandler(ctx) - if resp != nil { - duration := time.Since(start).Seconds() - metrics.GraphQLRequestDuration.WithLabelValues(operationName).Observe(duration) + fields := collector.sortedUnique() + if resp == nil { + h.Dependencies.Log.Warn(). + Str("operation", operationName). + Str("operation_metric_label", opMetricLabel). + Strs("resolved_fields", fields). + Msg("GraphQL operation returned no response") + return resp + } + duration := time.Since(start).Seconds() + metrics.GraphQLRequestDuration.WithLabelValues(opMetricLabel).Observe(duration) - if len(resp.Errors) > 0 { - metrics.RecordGraphQLError(operationName) - } + if len(resp.Errors) > 0 { + metrics.RecordGraphQLError(operationName) } + logEvt := h.Dependencies.Log.Info(). + Str("operation", operationName). + Str("operation_metric_label", opMetricLabel). + Int("resolved_field_count", len(fields)) + logEvt.Msg("GraphQL operation completed") + h.Dependencies.Log.Debug(). + Str("operation", operationName). + Strs("resolved_fields", fields). + Msg("GraphQL resolved fields") return resp } } } -// GraphqlHandler is the main handler that handels all the graphql requests +// GraphqlHandler is the main handler that handles all GraphQL requests. func (h *httpProvider) GraphqlHandler() gin.HandlerFunc { gqlProvider, err := graphql.New(h.Config, &graphql.Dependencies{ Log: h.Log, @@ -80,7 +127,12 @@ func (h *httpProvider) GraphqlHandler() gin.HandlerFunc { }) if err != nil { h.Log.Error().Err(err).Msg("Failed to create graphql provider") - return nil + return func(c *gin.Context) { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "graphql_unavailable", + "error_description": "GraphQL service failed to initialize.", + }) + } } // NewExecutableSchema and Config are in the generated.go file @@ -94,7 +146,7 @@ func (h *httpProvider) GraphqlHandler() gin.HandlerFunc { srv.AddTransport(transport.POST{}) srv.SetQueryCache(lru.New[*ast.QueryDocument](1000)) - srv.AroundFields(h.gqlLoggingMiddleware()) + srv.AroundFields(h.gqlCollectResolvedFieldsMiddleware()) srv.AroundOperations(h.gqlMetricsMiddleware()) if h.Config.EnableGraphQLIntrospection { srv.Use(extension.Introspection{}) @@ -109,7 +161,7 @@ func (h *httpProvider) GraphqlHandler() gin.HandlerFunc { // Create a custom handler that ensures gin context is available handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Ensure the gin context is available in the request context - ctx := context.WithValue(r.Context(), "GinContextKey", c) + ctx := utils.ContextWithGin(r.Context(), c) r = r.WithContext(ctx) srv.ServeHTTP(w, r) }) diff --git a/internal/http_handlers/health.go b/internal/http_handlers/health.go index d73944ac..fe9778f3 100644 --- a/internal/http_handlers/health.go +++ b/internal/http_handlers/health.go @@ -17,7 +17,7 @@ func (h *httpProvider) HealthHandler() gin.HandlerFunc { metrics.DBHealthCheckTotal.WithLabelValues("unhealthy").Inc() c.JSON(http.StatusServiceUnavailable, gin.H{ "status": "unhealthy", - "error": err.Error(), + "error": "storage unavailable", }) return } @@ -35,7 +35,7 @@ func (h *httpProvider) ReadyHandler() gin.HandlerFunc { metrics.DBHealthCheckTotal.WithLabelValues("unhealthy").Inc() c.JSON(http.StatusServiceUnavailable, gin.H{ "status": "not ready", - "error": err.Error(), + "error": "storage unavailable", }) return } diff --git a/internal/http_handlers/metrics.go b/internal/http_handlers/metrics.go index 11b95db1..f2fb2cf4 100644 --- a/internal/http_handlers/metrics.go +++ b/internal/http_handlers/metrics.go @@ -16,7 +16,8 @@ func (h *httpProvider) MetricsMiddleware() gin.HandlerFunc { start := time.Now() path := c.FullPath() if path == "" { - path = c.Request.URL.Path + // Avoid raw URL paths as Prometheus labels (unbounded cardinality on 404 scans). + path = "unmatched" } c.Next() @@ -24,6 +25,10 @@ func (h *httpProvider) MetricsMiddleware() gin.HandlerFunc { duration := time.Since(start).Seconds() status := fmt.Sprintf("%d", c.Writer.Status()) + if metrics.SkipHTTPRequestMetrics(path) { + return + } + metrics.HTTPRequestsTotal.WithLabelValues(c.Request.Method, path, status).Inc() metrics.HTTPRequestDuration.WithLabelValues(c.Request.Method, path).Observe(duration) } diff --git a/internal/http_handlers/oauth_callback.go b/internal/http_handlers/oauth_callback.go index 90be5a1f..55ab14c7 100644 --- a/internal/http_handlers/oauth_callback.go +++ b/internal/http_handlers/oauth_callback.go @@ -134,7 +134,10 @@ func (h *httpProvider) OAuthCallbackHandler() gin.HandlerFunc { IPAddress: utils.GetIP(ctx.Request), UserAgent: utils.GetUserAgent(ctx.Request), }) - ctx.JSON(400, gin.H{"error": err.Error()}) + ctx.JSON(400, gin.H{ + "error": "oauth_callback_failed", + "error_description": "OAuth callback could not be completed. Please try again.", + }) return } if user == nil { diff --git a/internal/http_handlers/rate_limit.go b/internal/http_handlers/rate_limit.go index 34fffe52..1af474dc 100644 --- a/internal/http_handlers/rate_limit.go +++ b/internal/http_handlers/rate_limit.go @@ -15,13 +15,12 @@ var exemptPrefixes = []string{ } // exemptPaths are exact paths that bypass rate limiting. +// /metrics is not served on this router (dedicated listener). /playground is rate-limited like other API routes. var exemptPaths = map[string]bool{ "/": true, "/health": true, "/healthz": true, "/readyz": true, - "/metrics": true, - "/playground": true, "/.well-known/openid-configuration": true, "/.well-known/jwks.json": true, } @@ -59,7 +58,16 @@ func (h *httpProvider) RateLimitMiddleware() gin.HandlerFunc { allowed, err := h.Dependencies.RateLimitProvider.Allow(c.Request.Context(), c.ClientIP()) if err != nil { log := h.Dependencies.Log.With().Str("func", "RateLimitMiddleware").Logger() - log.Error().Err(err).Msg("rate limit check failed, allowing request") + log.Error().Err(err).Msg("rate limit check failed") + if h.Config.RateLimitFailClosed { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "rate_limit_unavailable", + "error_description": "Rate limiting is temporarily unavailable. Please try again later.", + }) + c.Abort() + return + } + log.Warn().Msg("rate limit fail-open: allowing request after backend error") c.Next() return } diff --git a/internal/integration_tests/add_webhook_test.go b/internal/integration_tests/add_webhook_test.go index c60796aa..f2b927a9 100644 --- a/internal/integration_tests/add_webhook_test.go +++ b/internal/integration_tests/add_webhook_test.go @@ -87,9 +87,12 @@ func TestAddWebhookTest(t *testing.T) { h, err := crypto.EncryptPassword(cfg.AdminSecret) assert.Nil(t, err) + // Use UserDeactivatedWebhookEvent to avoid data leakage from other tests + // that create webhooks with more commonly used event names + uniqueEventName := constants.UserDeactivatedWebhookEvent req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) addedWebhook, err := ts.GraphQLProvider.AddWebhook(ctx, &model.AddWebhookRequest{ - EventName: constants.UserCreatedWebhookEvent, + EventName: uniqueEventName, EventDescription: refs.NewStringRef("test"), Endpoint: "test", Enabled: false, @@ -100,7 +103,7 @@ func TestAddWebhookTest(t *testing.T) { require.NoError(t, err) assert.NotNil(t, addedWebhook) - res, err := ts.StorageProvider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) + res, err := ts.StorageProvider.GetWebhookByEventName(ctx, uniqueEventName) require.NoError(t, err) assert.NotNil(t, res) assert.Equal(t, 1, len(res)) diff --git a/internal/integration_tests/custom_access_token_script_test.go b/internal/integration_tests/custom_access_token_script_test.go new file mode 100644 index 00000000..80a58fbf --- /dev/null +++ b/internal/integration_tests/custom_access_token_script_test.go @@ -0,0 +1,365 @@ +package integration_tests + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authorizerdev/authorizer/internal/config" + "github.com/authorizerdev/authorizer/internal/graph/model" +) + +// parseTestJWTClaims parses a JWT token's claims without validation (for test inspection only). +func parseTestJWTClaims(t *testing.T, tokenString string) jwt.MapClaims { + t.Helper() + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{}) + require.NoError(t, err) + claims, ok := token.Claims.(jwt.MapClaims) + require.True(t, ok) + return claims +} + +// TestCustomAccessTokenScript tests the custom access token script functionality +// including the 5-second execution timeout added for DoS protection. +func TestCustomAccessTokenScript(t *testing.T) { + runForEachDB(t, func(t *testing.T, cfg *config.Config) { + t.Run("should_add_custom_claims_from_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { custom_claim: "hello", user_email: user.email }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "custom_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + // Parse the access token and verify custom claims are present + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.Equal(t, "hello", claims["custom_claim"]) + assert.Equal(t, email, claims["user_email"]) + }) + + t.Run("should_not_override_reserved_claims", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { sub: "hacked", iss: "hacked", roles: ["admin"], custom_field: "allowed" }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "reserved_claims_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + // Reserved claims must NOT be overridden + assert.NotEqual(t, "hacked", claims["sub"]) + assert.NotEqual(t, "hacked", claims["iss"]) + // Roles should NOT be overridden to admin + roles, ok := claims["roles"].([]interface{}) + if ok { + for _, r := range roles { + assert.NotEqual(t, "admin", r, "reserved 'roles' claim must not be overridden by script") + } + } + // Custom (non-reserved) claims should be added + assert.Equal(t, "allowed", claims["custom_field"]) + }) + + t.Run("should_timeout_infinite_loop_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + while(true) {} // infinite loop — should be killed after 5 seconds + return { never: "reached" }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "timeout_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + // Measure execution time to verify the timeout works + start := time.Now() + + // Login should still succeed — the timeout is handled gracefully, + // custom claims are skipped but the token is still created. + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + elapsed := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + // The token should be valid but without the custom claim from the timed-out script + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.Nil(t, claims["never"], "timed-out script claims must not appear in token") + // Standard claims should still be present + assert.NotEmpty(t, claims["sub"]) + assert.NotEmpty(t, claims["iss"]) + + // Verify the timeout kicked in within a reasonable range (5-8 seconds for the access + id token) + assert.Less(t, elapsed, 20*time.Second, "login with infinite loop script should complete within 20 seconds (two 5s timeouts + overhead)") + }) + + t.Run("should_handle_script_error_gracefully", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + throw new Error("intentional error"); + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "error_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + // Login should still succeed even with a broken script + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + }) + + t.Run("should_work_without_custom_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = "" + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "no_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.NotEmpty(t, claims["sub"]) + // Ensure no unexpected claims were added + assert.Nil(t, claims["custom_claim"]) + }) + + t.Run("should_have_custom_claims_in_id_token_too", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { team: "engineering" }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "id_token_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.IDToken) + + // The custom script runs for both access token and ID token + claims := parseTestJWTClaims(t, *loginRes.IDToken) + assert.Equal(t, "engineering", claims["team"]) + }) + }) +} + +// TestClientIDMismatchMetric verifies that client ID mismatch records a security metric. +func TestClientIDMismatchMetric(t *testing.T) { + runForEachDB(t, func(t *testing.T, cfg *config.Config) { + ts := initTestSetup(t, cfg) + + router := setupTestRouter(ts) + + t.Run("records_metric_on_client_id_mismatch", func(t *testing.T) { + // Send request with wrong client ID to /graphql (not dashboard/app) + body := `{"query":"{ meta { version } }"}` + w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-Client-ID": "wrong-client-id", + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", + }) + + assert.Equal(t, 400, w.Code) + assert.Contains(t, w.Body.String(), "invalid_client_id") + + // Check that the security metric was recorded + metricsBody := getMetricsBody(t, router) + assert.Contains(t, metricsBody, `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`) + }) + + t.Run("no_metric_for_valid_client_id", func(t *testing.T) { + body := `{"query":"{ meta { version } }"}` + w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-Client-ID": cfg.ClientID, + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", + }) + + // Should not be 400 + assert.NotEqual(t, 400, w.Code) + }) + + t.Run("no_metric_for_dashboard_path_mismatch", func(t *testing.T) { + mark := `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}` + before := prometheusCounterSample(t, getMetricsBody(t, router), mark) + w := sendTestRequest(t, router, "GET", "/dashboard/", "", map[string]string{ + "X-Authorizer-Client-ID": "wrong-client-id", + }) + assert.Equal(t, 400, w.Code) + after := prometheusCounterSample(t, getMetricsBody(t, router), mark) + assert.Equal(t, before, after, "dashboard path mismatch must not increment client_id_mismatch metric") + }) + + t.Run("records_client_id_header_missing_metric", func(t *testing.T) { + mark := "authorizer_client_id_header_missing_total" + before := prometheusCounterSample(t, getMetricsBody(t, router), mark) + sendTestRequest(t, router, "POST", "/graphql", `{"query":"{ meta { version } }"}`, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", + }) + after := prometheusCounterSample(t, getMetricsBody(t, router), mark) + assert.Greater(t, after, before) + }) + }) +} + +// Helper functions for cleaner test code + +func setupTestRouter(ts *testSetup) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ts.HttpProvider.CORSMiddleware()) + router.Use(ts.HttpProvider.CSRFMiddleware()) + router.Use(ts.HttpProvider.ClientCheckMiddleware()) + router.Use(ts.HttpProvider.ContextMiddleware()) + router.POST("/graphql", ts.HttpProvider.GraphqlHandler()) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + // Dashboard route to test path exclusion + router.GET("/dashboard/", func(c *gin.Context) { + c.JSON(200, gin.H{"status": "ok"}) + }) + return router +} + +func sendTestRequest(t *testing.T, router *gin.Engine, method, path, body string, headers map[string]string) *httptest.ResponseRecorder { + t.Helper() + w := httptest.NewRecorder() + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequest(method, path, strings.NewReader(body)) + } else { + req, err = http.NewRequest(method, path, nil) + } + require.NoError(t, err) + for k, v := range headers { + req.Header.Set(k, v) + } + router.ServeHTTP(w, req) + return w +} + +func getMetricsBody(t *testing.T, router *gin.Engine) string { + t.Helper() + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + return w.Body.String() +} + +// prometheusCounterSample parses the numeric value of the first Prometheus text exposition +// sample line for the given metric name prefix (name plus label set, e.g. "foo" or `bar{a="b"}`). +func prometheusCounterSample(t *testing.T, body, namePrefix string) float64 { + t.Helper() + prefix := namePrefix + " " + for _, line := range strings.Split(body, "\n") { + line = strings.TrimSpace(line) + if line == "" || line[0] == '#' { + continue + } + if !strings.HasPrefix(line, prefix) { + continue + } + valStr := strings.TrimSpace(strings.TrimPrefix(line, prefix)) + v, err := strconv.ParseFloat(valStr, 64) + require.NoError(t, err) + return v + } + return 0 +} diff --git a/internal/integration_tests/health_test.go b/internal/integration_tests/health_test.go index b8cf43db..cdb17163 100644 --- a/internal/integration_tests/health_test.go +++ b/internal/integration_tests/health_test.go @@ -1,18 +1,32 @@ package integration_tests import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/authorizerdev/authorizer/internal/config" + "github.com/authorizerdev/authorizer/internal/http_handlers" + "github.com/authorizerdev/authorizer/internal/storage" ) +// failingHealthStorage wraps a real storage provider but fails HealthCheck (for probe tests). +type failingHealthStorage struct { + storage.Provider +} + +func (*failingHealthStorage) HealthCheck(ctx context.Context) error { + return errors.New("test: forced storage health failure") +} + // TestHealthHandler verifies the /healthz liveness probe endpoint behaviour. func TestHealthHandler(t *testing.T) { runForEachDB(t, func(t *testing.T, cfg *config.Config) { @@ -62,3 +76,61 @@ func TestReadyHandler(t *testing.T) { }) }) } + +// TestHealthHandlersUnhealthyStorage verifies liveness/readiness and DB metrics when HealthCheck fails. +func TestHealthHandlersUnhealthyStorage(t *testing.T) { + runForEachDB(t, func(t *testing.T, cfg *config.Config) { + logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + realStorage, err := storage.New(cfg, &storage.Dependencies{Log: &logger}) + require.NoError(t, err) + t.Cleanup(func() { _ = realStorage.Close() }) + + wrapped := &failingHealthStorage{Provider: realStorage} + httpProv, err := http_handlers.New(cfg, &http_handlers.Dependencies{ + Log: &logger, + StorageProvider: wrapped, + }) + require.NoError(t, err) + + router := gin.New() + router.GET("/healthz", httpProv.HealthHandler()) + router.GET("/readyz", httpProv.ReadyHandler()) + router.GET("/metrics", httpProv.MetricsHandler()) + + t.Run("healthz_returns_503", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + var body map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, "unhealthy", body["status"]) + }) + + t.Run("readyz_returns_503", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/readyz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + var body map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, "not ready", body["status"]) + }) + + t.Run("records_unhealthy_db_check_metric", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusServiceUnavailable, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + assert.Contains(t, w2.Body.String(), `authorizer_db_health_check_total{status="unhealthy"}`) + }) + }) +} diff --git a/internal/integration_tests/metrics_test.go b/internal/integration_tests/metrics_test.go index e97fc79d..c3672ff5 100644 --- a/internal/integration_tests/metrics_test.go +++ b/internal/integration_tests/metrics_test.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,6 +32,7 @@ func TestMetricsEndpoint(t *testing.T) { metrics.RecordAuthEvent("test", "test") metrics.RecordSecurityEvent("test", "test") metrics.RecordGraphQLError("test") + metrics.RecordClientIDHeaderMissing() metrics.DBHealthCheckTotal.WithLabelValues("test").Inc() w := httptest.NewRecorder() @@ -48,6 +50,15 @@ func TestMetricsEndpoint(t *testing.T) { assert.Contains(t, body, "authorizer_security_events_total") assert.Contains(t, body, "authorizer_graphql_errors_total") assert.Contains(t, body, "authorizer_db_health_check_total") + assert.Contains(t, body, "authorizer_client_id_header_missing_total") + }) + + t.Run("post_metrics_is_not_get_ok", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.NotEqual(t, http.StatusOK, w.Code) }) }) } @@ -78,6 +89,26 @@ func TestMetricsMiddleware(t *testing.T) { body := w2.Body.String() assert.Contains(t, body, `authorizer_http_requests_total{method="GET",path="/healthz",status="200"}`) }) + + t.Run("skips_http_metrics_for_excluded_paths", func(t *testing.T) { + router.GET("/app/foo", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/app/foo", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + + body := w2.Body.String() + assert.NotContains(t, body, `authorizer_http_requests_total{method="GET",path="/app/foo"`) + }) }) } @@ -125,24 +156,6 @@ func TestAuthEventMetrics(t *testing.T) { email := "metrics_" + uuid.New().String() + "@authorizer.dev" password := "Password@123" - t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) { - loginReq := &model.LoginRequest{ - Email: &email, - Password: "wrong_password", - } - _, err := ts.GraphQLProvider.Login(ctx, loginReq) - assert.Error(t, err) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - - body := w.Body.String() - assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`) - assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`) - }) - t.Run("records_signup_and_login_success", func(t *testing.T) { signupReq := &model.SignUpRequest{ Email: &email, @@ -173,6 +186,24 @@ func TestAuthEventMetrics(t *testing.T) { router.ServeHTTP(w2, req2) assert.Contains(t, w2.Body.String(), `authorizer_auth_events_total{event="login",status="success"}`) }) + + t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) { + loginReq := &model.LoginRequest{ + Email: &email, + Password: "wrong_password", + } + _, err := ts.GraphQLProvider.Login(ctx, loginReq) + assert.Error(t, err) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + body := w.Body.String() + assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`) + assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`) + }) }) } @@ -208,7 +239,53 @@ func TestGraphQLErrorMetrics(t *testing.T) { metricsBody := w2.Body.String() assert.Contains(t, metricsBody, "authorizer_graphql_request_duration_seconds") + assert.Contains(t, metricsBody, `authorizer_graphql_errors_total{operation="anonymous"}`) + }) + + t.Run("captures_graphql_errors_with_named_operation", func(t *testing.T) { + body := `{"operationName":"LoginOp","query":"mutation LoginOp { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}` + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-authorizer-url", "http://localhost:8080") + req.Header.Set("Origin", "http://localhost:3000") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + loginOpLabel := metrics.GraphQLOperationPrometheusLabel("LoginOp") + assert.Contains(t, w2.Body.String(), `authorizer_graphql_errors_total{operation="`+loginOpLabel+`"}`) + }) + }) +} + +// TestClientIDHeaderMissingMiddlewareMetric verifies empty X-Authorizer-Client-ID increments the counter. +func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) { + runForEachDB(t, func(t *testing.T, cfg *config.Config) { + ts := initTestSetup(t, cfg) + + router := gin.New() + router.Use(ts.HttpProvider.ClientCheckMiddleware()) + router.GET("/probe", func(c *gin.Context) { + c.Status(http.StatusOK) }) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/probe", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + assert.Contains(t, w2.Body.String(), "authorizer_client_id_header_missing_total") }) } @@ -218,6 +295,15 @@ func TestRecordAuthEventHelpers(t *testing.T) { metrics.RecordAuthEvent(metrics.EventVerifyEmail, metrics.StatusSuccess) metrics.RecordAuthEvent(metrics.EventVerifyOTP, metrics.StatusFailure) metrics.RecordSecurityEvent("brute_force", "rate_limit") + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + promhttp.Handler().ServeHTTP(w, req) + body := w.Body.String() + assert.Contains(t, body, `authorizer_auth_events_total{event="verify_email",status="success"}`) + assert.Contains(t, body, `authorizer_auth_events_total{event="verify_otp",status="failure"}`) + assert.Contains(t, body, `authorizer_security_events_total{event="brute_force",reason="rate_limit"}`) }) } diff --git a/internal/integration_tests/oauth_standards_compliance_test.go b/internal/integration_tests/oauth_standards_compliance_test.go index 094bb409..000a3235 100644 --- a/internal/integration_tests/oauth_standards_compliance_test.go +++ b/internal/integration_tests/oauth_standards_compliance_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/authorizerdev/authorizer/internal/crypto" "github.com/authorizerdev/authorizer/internal/graph/model" ) @@ -522,13 +523,14 @@ func TestAuthorizeEndpointCompliance(t *testing.T) { // TestJWKSEndpointCompliance verifies /.well-known/jwks.json func TestJWKSEndpointCompliance(t *testing.T) { - cfg := getTestConfig() - ts := initTestSetup(t, cfg) + t.Run("JWKS_returns_empty_keys_for_HMAC", func(t *testing.T) { + // HMAC (symmetric) keys must NOT be exposed via JWKS. + cfg := getTestConfig() // uses HS256 + ts := initTestSetup(t, cfg) - router := gin.New() - router.GET("/.well-known/jwks.json", ts.HttpProvider.JWKsHandler()) + router := gin.New() + router.GET("/.well-known/jwks.json", ts.HttpProvider.JWKsHandler()) - t.Run("JWKS_returns_valid_keyset", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/.well-known/jwks.json", nil) router.ServeHTTP(w, req) @@ -541,7 +543,36 @@ func TestJWKSEndpointCompliance(t *testing.T) { keys, ok := body["keys"].([]interface{}) require.True(t, ok, "JWKS response MUST contain 'keys' array") - require.NotEmpty(t, keys, "JWKS 'keys' array MUST not be empty") + assert.Empty(t, keys, "JWKS 'keys' array MUST be empty for HMAC-only config to prevent secret exposure") + }) + + t.Run("JWKS_returns_valid_keyset_for_RSA", func(t *testing.T) { + cfg := getTestConfig() + // Generate RSA keys for this test + _, privateKey, publicKey, _, err := crypto.NewRSAKey("RS256", cfg.ClientID) + require.NoError(t, err) + cfg.JWTType = "RS256" + cfg.JWTPrivateKey = privateKey + cfg.JWTPublicKey = publicKey + cfg.JWTSecret = "" // not needed for RSA + ts := initTestSetup(t, cfg) + + router := gin.New() + router.GET("/.well-known/jwks.json", ts.HttpProvider.JWKsHandler()) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/.well-known/jwks.json", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var body map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &body) + require.NoError(t, err) + + keys, ok := body["keys"].([]interface{}) + require.True(t, ok, "JWKS response MUST contain 'keys' array") + require.NotEmpty(t, keys, "JWKS 'keys' array MUST not be empty for RSA config") // Each key must have required JWK fields key := keys[0].(map[string]interface{}) diff --git a/internal/integration_tests/rate_limit_test.go b/internal/integration_tests/rate_limit_test.go index 142aa32d..968e1ff1 100644 --- a/internal/integration_tests/rate_limit_test.go +++ b/internal/integration_tests/rate_limit_test.go @@ -76,14 +76,11 @@ func TestRateLimitMiddleware(t *testing.T) { router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) - router.GET("/metrics", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - }) router.GET("/.well-known/openid-configuration", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"issuer": "test"}) }) - exemptPaths := []string{"/health", "/metrics", "/.well-known/openid-configuration"} + exemptPaths := []string{"/health", "/.well-known/openid-configuration"} for _, path := range exemptPaths { // Make many requests - none should be limited for i := 0; i < 10; i++ { diff --git a/internal/integration_tests/resend_otp_test.go b/internal/integration_tests/resend_otp_test.go index 00d6d0f8..6a8b7fc2 100644 --- a/internal/integration_tests/resend_otp_test.go +++ b/internal/integration_tests/resend_otp_test.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "testing" + "time" "github.com/authorizerdev/authorizer/internal/constants" "github.com/authorizerdev/authorizer/internal/graph/model" @@ -37,8 +38,8 @@ func TestResendOTP(t *testing.T) { ts := initTestSetup(t, cfg) req, ctx := createContext(ts) - // Create a test user - mobile := "+14155552672" + // Create a test user with a unique phone number to avoid collisions + mobile := fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000) password := "Password@123" // Signup the user signupReq := &model.SignUpRequest{ @@ -48,8 +49,8 @@ func TestResendOTP(t *testing.T) { } signupRes, err := ts.GraphQLProvider.SignUp(ctx, signupReq) - assert.NoError(t, err) - assert.NotNil(t, signupRes) + require.NoError(t, err) + require.NotNil(t, signupRes) // Expect the user to be nil, as the email is not verified yet assert.Nil(t, signupRes.User) diff --git a/internal/integration_tests/reset_password_test.go b/internal/integration_tests/reset_password_test.go index 9ec65381..5b685f5f 100644 --- a/internal/integration_tests/reset_password_test.go +++ b/internal/integration_tests/reset_password_test.go @@ -72,7 +72,7 @@ func TestResetPassword(t *testing.T) { ts2 := initTestSetup(t, cfg2) req2, ctx2 := createContext(ts2) - mobile := "+14155550199" + mobile := fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000) signupReq2 := &model.SignUpRequest{ PhoneNumber: &mobile, Password: password, diff --git a/internal/integration_tests/test_helper.go b/internal/integration_tests/test_helper.go index 7a6470e1..8b2cd28c 100644 --- a/internal/integration_tests/test_helper.go +++ b/internal/integration_tests/test_helper.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "strings" "testing" @@ -21,10 +22,12 @@ import ( "github.com/authorizerdev/authorizer/internal/graphql" "github.com/authorizerdev/authorizer/internal/http_handlers" "github.com/authorizerdev/authorizer/internal/memory_store" + "github.com/authorizerdev/authorizer/internal/oauth" "github.com/authorizerdev/authorizer/internal/rate_limit" "github.com/authorizerdev/authorizer/internal/sms" "github.com/authorizerdev/authorizer/internal/storage" "github.com/authorizerdev/authorizer/internal/token" + "github.com/authorizerdev/authorizer/internal/utils" ) // testSetup represents the test setup @@ -41,13 +44,16 @@ type testSetup struct { } func createContext(s *testSetup) (*http.Request, context.Context) { - req, _ := http.NewRequest( - "POST", + req, err := http.NewRequest( + http.MethodPost, "http://"+s.HttpServer.Listener.Addr().String()+"/graphql", nil, ) + if err != nil { + panic("integration_tests.createContext: " + err.Error()) + } - ctx := context.WithValue(req.Context(), "GinContextKey", s.GinContext) + ctx := utils.ContextWithGin(req.Context(), s.GinContext) s.GinContext.Request = req return req, ctx } @@ -128,6 +134,7 @@ func getTestConfig() *config.Config { func getTestConfigForDB(dbType, dbURL string) *config.Config { cfg := &config.Config{ Env: constants.TestEnv, + SkipTestEndpointSSRFValidation: true, DatabaseType: dbType, DatabaseURL: dbURL, JWTSecret: "test-secret", @@ -179,6 +186,9 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config { func runForEachDB(t *testing.T, testFn func(t *testing.T, cfg *config.Config)) { t.Helper() dbConfigs := getTestDBs() + if len(dbConfigs) == 0 { + t.Fatal("TEST_DBS produced no runnable database configurations; check TEST_DBS and that each database type resolves to a non-empty URL") + } for _, dbCfg := range dbConfigs { t.Run("db="+dbCfg.DbType, func(t *testing.T) { @@ -193,6 +203,10 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { // Initialize logger logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + if cfg.DatabaseType == constants.DbTypeSqlite || cfg.DatabaseType == constants.DbTypeLibSQL { + cfg.DatabaseURL = filepath.Join(t.TempDir(), "authorizer_integration.db") + } + // Initialize storage provider first as it's required by other providers storageProvider, err := storage.New(cfg, &storage.Dependencies{ Log: &logger, @@ -239,6 +253,11 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { }) require.NoError(t, err) + oauthProvider, err := oauth.New(cfg, &oauth.Dependencies{ + Log: &logger, + }) + require.NoError(t, err) + // Initialize audit provider auditProvider := audit.New(&audit.Dependencies{ Log: &logger, @@ -270,6 +289,7 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { StorageProvider: storageProvider, TokenProvider: tokenProvider, RateLimitProvider: rateLimitProvider, + OAuthProvider: oauthProvider, } // Create GraphQL provider @@ -290,6 +310,15 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { server := httptest.NewServer(r) + t.Cleanup(func() { + server.Close() + if storageProvider != nil { + if err := storageProvider.Close(); err != nil { + t.Errorf("close storage provider: %v", err) + } + } + }) + return &testSetup{ GraphQLProvider: gqlProvider, HttpProvider: httpProvider, diff --git a/internal/integration_tests/update_profile_test.go b/internal/integration_tests/update_profile_test.go index 85d259d9..fbea1597 100644 --- a/internal/integration_tests/update_profile_test.go +++ b/internal/integration_tests/update_profile_test.go @@ -1,7 +1,9 @@ package integration_tests import ( + "fmt" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -110,7 +112,7 @@ func TestUpdateProfile(t *testing.T) { givenName := "John" familyName := "Doe" nickname := "Johnny" - phoneNumber := "+1234567890" + phoneNumber := fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000) updateReq := &model.UpdateProfileRequest{ GivenName: refs.NewStringRef(givenName), diff --git a/internal/integration_tests/verify_otp_test.go b/internal/integration_tests/verify_otp_test.go index 7b620834..cfe0f991 100644 --- a/internal/integration_tests/verify_otp_test.go +++ b/internal/integration_tests/verify_otp_test.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "testing" + "time" "github.com/authorizerdev/authorizer/internal/constants" "github.com/authorizerdev/authorizer/internal/graph/model" @@ -36,7 +37,7 @@ func TestVerifyOTP(t *testing.T) { req, ctx := createContext(ts) // Create a test user - mobile := "+14155552671" + mobile := fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000) password := "Password@123" // Signup the user signupReq := &model.SignUpRequest{ diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index e9166cd0..5f42dbed 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,6 +1,11 @@ +// Package metrics defines Prometheus collectors and helpers for Authorizer observability +// (HTTP traffic, auth events, GraphQL, security signals, and database health). package metrics import ( + "crypto/sha256" + "encoding/hex" + "strings" "sync" "github.com/prometheus/client_golang/prometheus" @@ -79,7 +84,7 @@ var ( GraphQLErrorsTotal = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "authorizer_graphql_errors_total", - Help: "Total number of GraphQL responses containing errors", + Help: "Total number of GraphQL responses containing errors (operation label is bounded: anonymous or op_)", }, []string{"operation"}, ) @@ -88,7 +93,7 @@ var ( GraphQLRequestDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "authorizer_graphql_request_duration_seconds", - Help: "GraphQL operation duration in seconds", + Help: "GraphQL operation duration in seconds (operation label is bounded: anonymous or op_)", Buckets: prometheus.DefBuckets, }, []string{"operation"}, @@ -102,8 +107,82 @@ var ( }, []string{"status"}, ) + + // ClientIDHeaderMissingTotal counts allowed requests with no X-Authorizer-Client-ID header. + ClientIDHeaderMissingTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "authorizer_client_id_header_missing_total", + Help: "Total requests that omitted X-Authorizer-Client-ID (allowed for some routes)", + }, + ) ) +// staticAssetPathSuffixes are path suffixes (after lowercasing) treated as static files +// for HTTP metrics filtering (images, icons, fonts, source maps, PWA manifest). +var staticAssetPathSuffixes = []string{ + ".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".ico", ".bmp", ".avif", ".jfif", + ".woff", ".woff2", ".ttf", ".otf", ".eot", + ".webmanifest", + ".map", +} + +// SkipHTTPRequestMetrics reports whether a request path should be omitted from +// HTTP request counters and histograms (UI routes, static assets, favicons, images, fonts). +func SkipHTTPRequestMetrics(path string) bool { + if path == "" { + return false + } + if path == "/app" || strings.HasPrefix(path, "/app/") { + return true + } + if path == "/dashboard" || strings.HasPrefix(path, "/dashboard/") { + return true + } + if path == "/metrics" { + return true + } + for _, seg := range strings.Split(path, "/") { + if strings.HasPrefix(seg, "chunk-") { + return true + } + } + return skipHTTPRequestMetricsStaticAsset(path) +} + +func skipHTTPRequestMetricsStaticAsset(path string) bool { + p := strings.ToLower(path) + if i := strings.Index(p, "?"); i >= 0 { + p = p[:i] + } + switch p { + case "/robots.txt", "/sitemap.xml", "/humans.txt", "/security.txt": + return true + } + for _, suf := range staticAssetPathSuffixes { + if strings.HasSuffix(p, suf) { + return true + } + } + file := p + if i := strings.LastIndex(p, "/"); i >= 0 { + file = p[i+1:] + } + if file == "" { + return false + } + if strings.HasPrefix(file, "favicon") { + return true + } + // Common browser / PWA icon filenames without matching suffix rules above. + if strings.Contains(file, "apple-touch-icon") || + strings.Contains(file, "android-chrome") || + strings.Contains(file, "safari-pinned-tab") || + strings.Contains(file, "mstile-") { + return true + } + return false +} + // Init registers all metrics with the default prometheus registry. // It is safe to call multiple times; registration happens only once. func Init() { @@ -116,20 +195,38 @@ func Init() { prometheus.MustRegister(GraphQLErrorsTotal) prometheus.MustRegister(GraphQLRequestDuration) prometheus.MustRegister(DBHealthCheckTotal) + prometheus.MustRegister(ClientIDHeaderMissingTotal) }) } +// GraphQLOperationPrometheusLabel maps an operation name to a bounded-cardinality value +// suitable for Prometheus labels (never use raw client-supplied names as labels). +func GraphQLOperationPrometheusLabel(operationName string) string { + if strings.TrimSpace(operationName) == "" { + return "anonymous" + } + sum := sha256.Sum256([]byte(operationName)) + return "op_" + hex.EncodeToString(sum[:8]) +} + // RecordAuthEvent records an authentication event with given status. +// event and status must be low-cardinality values (package constants); do not pass user input. func RecordAuthEvent(event, status string) { AuthEventsTotal.WithLabelValues(event, status).Inc() } // RecordSecurityEvent records a security-relevant event for alerting. +// event and reason must be low-cardinality values; do not pass user-controlled strings. func RecordSecurityEvent(event, reason string) { SecurityEventsTotal.WithLabelValues(event, reason).Inc() } -// RecordGraphQLError records a GraphQL error for the given operation. +// RecordGraphQLError records a GraphQL error for the given operation name. func RecordGraphQLError(operation string) { - GraphQLErrorsTotal.WithLabelValues(operation).Inc() + GraphQLErrorsTotal.WithLabelValues(GraphQLOperationPrometheusLabel(operation)).Inc() +} + +// RecordClientIDHeaderMissing records a request that had no client ID header. +func RecordClientIDHeaderMissing() { + ClientIDHeaderMissingTotal.Inc() } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 00000000..22565d6b --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,65 @@ +package metrics + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGraphQLOperationPrometheusLabel(t *testing.T) { + assert.Equal(t, "anonymous", GraphQLOperationPrometheusLabel("")) + assert.Equal(t, "anonymous", GraphQLOperationPrometheusLabel(" ")) + got := GraphQLOperationPrometheusLabel("LoginOp") + assert.True(t, strings.HasPrefix(got, "op_")) + assert.Len(t, got, len("op_")+16) // 8 bytes hex +} + +func TestSkipHTTPRequestMetrics(t *testing.T) { + tests := []struct { + path string + skip bool + }{ + {path: "", skip: false}, + {path: "/api/v1", skip: false}, + {path: "/app", skip: true}, + {path: "/app/", skip: true}, + {path: "/app/static/chunk-abc.js", skip: true}, + {path: "/dashboard", skip: true}, + {path: "/dashboard/", skip: true}, + {path: "/dashboard/users", skip: true}, + {path: "/metrics", skip: true}, + {path: "/static/chunk-vendors.js", skip: true}, + {path: "/assets/chunk-main.hash.js", skip: true}, + {path: "/favicon.ico", skip: true}, + {path: "/icons/favicon-32x32.png", skip: true}, + {path: "/apple-touch-icon.png", skip: true}, + {path: "/PWA/android-chrome-192x192.png", skip: true}, + {path: "/file.woff2", skip: true}, + {path: "/site.webmanifest", skip: true}, + {path: "/app/bundle.js.map", skip: true}, + {path: "/robots.txt", skip: true}, + {path: "/sitemap.xml", skip: true}, + {path: "/humans.txt", skip: true}, + {path: "/security.txt", skip: true}, + {path: "/logo.PNG", skip: true}, + {path: "/image.JPG?query=1", skip: true}, + {path: "/path?query=/app/foo", skip: false}, + } + for _, tt := range tests { + name := tt.path + if name == "" { + name = "(empty)" + } + t.Run(name, func(t *testing.T) { + got := SkipHTTPRequestMetrics(tt.path) + assert.Equal(t, tt.skip, got, "path=%q", tt.path) + }) + } +} + +func TestSkipHTTPRequestMetrics_chunkSegment(t *testing.T) { + // Path segment must be prefixed with "chunk-", not merely contain it. + assert.False(t, SkipHTTPRequestMetrics("/foo/mychunk-file.js")) + assert.True(t, SkipHTTPRequestMetrics("/chunk-xyz")) +} diff --git a/internal/server/http_routes.go b/internal/server/http_routes.go index 86f7852a..ca332f9d 100644 --- a/internal/server/http_routes.go +++ b/internal/server/http_routes.go @@ -26,7 +26,6 @@ func (s *server) NewRouter() *gin.Engine { router.GET("/health", s.Dependencies.HTTPProvider.HealthHandler()) router.GET("/healthz", s.Dependencies.HTTPProvider.HealthHandler()) router.GET("/readyz", s.Dependencies.HTTPProvider.ReadyHandler()) - router.GET("/metrics", s.Dependencies.HTTPProvider.MetricsHandler()) router.POST("/graphql", s.Dependencies.HTTPProvider.GraphqlHandler()) router.GET("/playground", s.Dependencies.HTTPProvider.PlaygroundHandler()) router.GET("/oauth_login/:oauth_provider", s.Dependencies.HTTPProvider.OAuthLoginHandler()) diff --git a/internal/server/server.go b/internal/server/server.go index 3c361a63..2b146486 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,8 +2,12 @@ package server import ( "context" - "fmt" + "net" + "net/http" + "strconv" + "time" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" "github.com/authorizerdev/authorizer/internal/graphql" @@ -18,6 +22,8 @@ type Config struct { HTTPPort int // Port number to serve Metrics requests on MetricsPort int + // MetricsHost is the bind address for the dedicated /metrics listener. + MetricsHost string } // Dependencies for a server @@ -42,19 +48,38 @@ type server struct { Dependencies *Dependencies } -// Run the server until the given context is canceled +// Run the server until the given context is canceled. +// The main HTTP server (Gin) and the Prometheus /metrics server always run as separate listeners. func (s *server) Run(ctx context.Context) error { - // Create new router ginRouter := s.NewRouter() - // Start the server + httpAddr := net.JoinHostPort(s.Config.Host, strconv.Itoa(s.Config.HTTPPort)) go func() { - s.Dependencies.Log.Info().Str("host", s.Config.Host).Int("port", s.Config.HTTPPort).Msg("Starting HTTP server") - err := ginRouter.Run(s.Config.Host + ":" + fmt.Sprintf("%d", s.Config.HTTPPort)) - if err != nil { + s.Dependencies.Log.Info().Str("addr", httpAddr).Msg("Starting HTTP server") + if err := ginRouter.Run(httpAddr); err != nil { s.Dependencies.Log.Error().Err(err).Msg("HTTP server failed") } }() - // Wait until context closed + + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + metricsAddr := net.JoinHostPort(s.Config.MetricsHost, strconv.Itoa(s.Config.MetricsPort)) + metricsSrv := &http.Server{ + Addr: metricsAddr, + Handler: mux, + } + go func() { + s.Dependencies.Log.Info().Str("addr", metricsAddr).Msg("Starting metrics server") + if err := metricsSrv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.Dependencies.Log.Error().Err(err).Msg("Metrics server failed") + } + }() + go func() { + <-ctx.Done() + shCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = metricsSrv.Shutdown(shCtx) + }() + <-ctx.Done() return nil } diff --git a/internal/storage/db/arangodb/audit_log.go b/internal/storage/db/arangodb/audit_log.go index 5cc3240f..1a12d6c5 100644 --- a/internal/storage/db/arangodb/audit_log.go +++ b/internal/storage/db/arangodb/audit_log.go @@ -60,7 +60,9 @@ func (p *provider) ListAuditLogs(ctx context.Context, pagination *model.Paginati bindVariables["to_timestamp"] = toTimestamp } - query := fmt.Sprintf("FOR d in %s%s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.AuditLog, filterQuery, pagination.Offset, pagination.Limit) + bindVariables["offset"] = pagination.Offset + bindVariables["limit"] = pagination.Limit + query := fmt.Sprintf("FOR d in %s%s SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.AuditLog, filterQuery) sctx := arangoDriver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, bindVariables) if err != nil { diff --git a/internal/storage/db/arangodb/email_template.go b/internal/storage/db/arangodb/email_template.go index e5cded70..4c2b22bb 100644 --- a/internal/storage/db/arangodb/email_template.go +++ b/internal/storage/db/arangodb/email_template.go @@ -48,9 +48,12 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schem // ListEmailTemplates to list EmailTemplate func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) ([]*schemas.EmailTemplate, *model.Pagination, error) { emailTemplates := []*schemas.EmailTemplate{} - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.EmailTemplate, pagination.Offset, pagination.Limit) + query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.EmailTemplate) sctx := arangoDriver.WithQueryFullCount(ctx) - cursor, err := p.db.Query(sctx, query, nil) + cursor, err := p.db.Query(sctx, query, map[string]interface{}{ + "offset": pagination.Offset, + "limit": pagination.Limit, + }) if err != nil { return nil, nil, err } diff --git a/internal/storage/db/arangodb/provider.go b/internal/storage/db/arangodb/provider.go index 81c64378..0b840197 100644 --- a/internal/storage/db/arangodb/provider.go +++ b/internal/storage/db/arangodb/provider.go @@ -354,3 +354,9 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { db: arangodb, }, err } + +// Close releases ArangoDB driver resources. The HTTP driver does not expose a pool close; +// connections are reclaimed when the provider is discarded. +func (p *provider) Close() error { + return nil +} diff --git a/internal/storage/db/arangodb/user.go b/internal/storage/db/arangodb/user.go index 17f62ff4..d45dca4c 100644 --- a/internal/storage/db/arangodb/user.go +++ b/internal/storage/db/arangodb/user.go @@ -87,8 +87,11 @@ func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) var users []*schemas.User sctx := arangoDriver.WithQueryFullCount(ctx) - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.User, pagination.Offset, pagination.Limit) - cursor, err := p.db.Query(sctx, query, nil) + query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.User) + cursor, err := p.db.Query(sctx, query, map[string]interface{}{ + "offset": pagination.Offset, + "limit": pagination.Limit, + }) if err != nil { return nil, nil, err } diff --git a/internal/storage/db/arangodb/verification_requests.go b/internal/storage/db/arangodb/verification_requests.go index 00a4a59a..41a855e6 100644 --- a/internal/storage/db/arangodb/verification_requests.go +++ b/internal/storage/db/arangodb/verification_requests.go @@ -89,8 +89,11 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) ([]*schemas.VerificationRequest, *model.Pagination, error) { var verificationRequests []*schemas.VerificationRequest sctx := arangoDriver.WithQueryFullCount(ctx) - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.VerificationRequest, pagination.Offset, pagination.Limit) - cursor, err := p.db.Query(sctx, query, nil) + query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.VerificationRequest) + cursor, err := p.db.Query(sctx, query, map[string]interface{}{ + "offset": pagination.Offset, + "limit": pagination.Limit, + }) if err != nil { return nil, nil, err } diff --git a/internal/storage/db/arangodb/webhook_log.go b/internal/storage/db/arangodb/webhook_log.go index 27bd5ba3..f614038b 100644 --- a/internal/storage/db/arangodb/webhook_log.go +++ b/internal/storage/db/arangodb/webhook_log.go @@ -33,12 +33,12 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *schemas.Webhoo func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) ([]*schemas.WebhookLog, *model.Pagination, error) { webhookLogs := []*schemas.WebhookLog{} bindVariables := map[string]interface{}{} - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.WebhookLog, pagination.Offset, pagination.Limit) + bindVariables["offset"] = pagination.Offset + bindVariables["limit"] = pagination.Limit + query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.WebhookLog) if webhookID != "" { - query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.WebhookLog, pagination.Offset, pagination.Limit) - bindVariables = map[string]interface{}{ - "webhook_id": webhookID, - } + query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT @offset, @limit RETURN d", schemas.Collections.WebhookLog) + bindVariables["webhook_id"] = webhookID } sctx := arangoDriver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, bindVariables) diff --git a/internal/storage/db/cassandradb/provider.go b/internal/storage/db/cassandradb/provider.go index bcc4edf6..ff7aa371 100644 --- a/internal/storage/db/cassandradb/provider.go +++ b/internal/storage/db/cassandradb/provider.go @@ -373,6 +373,15 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { }, err } +// Close closes the Cassandra session. +func (p *provider) Close() error { + if p == nil || p.db == nil { + return nil + } + p.db.Close() + return nil +} + // convertMapValues converts json.Number values in a map to native Go types // (int64 or float64) so gocql can marshal them into CQL bigint/double columns. func convertMapValues(m map[string]interface{}) { diff --git a/internal/storage/db/couchbase/provider.go b/internal/storage/db/couchbase/provider.go index 04de4deb..9efccfcc 100644 --- a/internal/storage/db/couchbase/provider.go +++ b/internal/storage/db/couchbase/provider.go @@ -30,6 +30,7 @@ type provider struct { config *config.Config dependencies *Dependencies + cluster *gocb.Cluster db *gocb.Scope scopeName string } @@ -105,17 +106,33 @@ func NewProvider(config *config.Config, deps *Dependencies) (*provider, error) { for i := 0; i < v.NumField(); i++ { field := v.Field(i) for _, indexQuery := range indices[field.String()] { - scope.Query(indexQuery, nil) + _, qerr := scope.Query(indexQuery, nil) + if qerr != nil { + msg := qerr.Error() + if strings.Contains(msg, "already exists") || (strings.Contains(msg, "Index") && strings.Contains(msg, "already")) { + continue + } + return nil, fmt.Errorf("couchbase secondary index: %s: %w", indexQuery, qerr) + } } } return &provider{ config: config, dependencies: deps, + cluster: cluster, db: scope, scopeName: scopeIdentifier, }, nil } +// Close shuts down the Couchbase cluster connection. +func (p *provider) Close() error { + if p.cluster == nil { + return nil + } + return p.cluster.Close(&gocb.ClusterCloseOptions{}) +} + func createBucketAndScope(cluster *gocb.Cluster, bucketName string, scopeName string, ramQuota string, waitTimeout time.Duration) (*gocb.Bucket, error) { if ramQuota == "" { ramQuota = "1000" @@ -192,9 +209,9 @@ func getIndex(scopeName string) map[string][]string { // WebhookLog index webhookLogIndex1 := fmt.Sprintf("CREATE INDEX webhookLogIdIndex ON %s.%s(webhook_id)", scopeName, schemas.Collections.WebhookLog) - indices[schemas.Collections.Webhook] = []string{webhookLogIndex1} + indices[schemas.Collections.WebhookLog] = []string{webhookLogIndex1} - // WebhookLog index + // EmailTemplate index emailTempIndex1 := fmt.Sprintf("CREATE INDEX EmailTemplateEventNameIndex ON %s.%s(event_name)", scopeName, schemas.Collections.EmailTemplate) indices[schemas.Collections.EmailTemplate] = []string{emailTempIndex1} diff --git a/internal/storage/db/dynamodb/provider.go b/internal/storage/db/dynamodb/provider.go index 3a8bc802..7a305309 100644 --- a/internal/storage/db/dynamodb/provider.go +++ b/internal/storage/db/dynamodb/provider.go @@ -1,6 +1,10 @@ package dynamodb import ( + "context" + "fmt" + "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" @@ -29,43 +33,65 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { awsAccessKeyID := cfg.AWSAccessKeyID awsSecretAccessKey := cfg.AWSSecretAccessKey - config := aws.Config{ + awsCfg := aws.Config{ MaxRetries: aws.Int(3), CredentialsChainVerboseErrors: aws.Bool(true), // for full error logs } if awsRegion != "" { - config.Region = aws.String(awsRegion) + awsCfg.Region = aws.String(awsRegion) } // custom awsAccessKeyID, awsSecretAccessKey took first priority, if not then fetch config from aws credentials if awsAccessKeyID != "" && awsSecretAccessKey != "" { - config.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, "") + awsCfg.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, "") } else if dbURL != "" { deps.Log.Info().Msg("Using DB URL for dynamodb") // static config in case of testing or local-setup - config.Credentials = credentials.NewStaticCredentials("key", "key", "") - config.Endpoint = aws.String(dbURL) + awsCfg.Credentials = credentials.NewStaticCredentials("key", "key", "") + awsCfg.Endpoint = aws.String(dbURL) } else { deps.Log.Info().Msg("Using default AWS credentials config from system for dynamodb") } - session := session.Must(session.NewSession(&config)) - db := dynamo.New(session) - db.CreateTable(schemas.Collections.User, schemas.User{}).Wait() - db.CreateTable(schemas.Collections.Session, schemas.Session{}).Wait() - db.CreateTable(schemas.Collections.EmailTemplate, schemas.EmailTemplate{}).Wait() - db.CreateTable(schemas.Collections.Env, schemas.Env{}).Wait() - db.CreateTable(schemas.Collections.OTP, schemas.OTP{}).Wait() - db.CreateTable(schemas.Collections.VerificationRequest, schemas.VerificationRequest{}).Wait() - db.CreateTable(schemas.Collections.Webhook, schemas.Webhook{}).Wait() - db.CreateTable(schemas.Collections.WebhookLog, schemas.WebhookLog{}).Wait() - db.CreateTable(schemas.Collections.Authenticators, schemas.Authenticator{}).Wait() - db.CreateTable(schemas.Collections.SessionToken, schemas.SessionToken{}).Wait() - db.CreateTable(schemas.Collections.MFASession, schemas.MFASession{}).Wait() - db.CreateTable(schemas.Collections.OAuthState, schemas.OAuthState{}).Wait() - db.CreateTable(schemas.Collections.AuditLog, schemas.AuditLog{}).Wait() + sess, err := session.NewSession(&awsCfg) + if err != nil { + return nil, fmt.Errorf("dynamodb session: %w", err) + } + db := dynamo.New(sess) + + createCtx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + + tables := []struct { + name string + model interface{} + }{ + {schemas.Collections.User, schemas.User{}}, + {schemas.Collections.Session, schemas.Session{}}, + {schemas.Collections.EmailTemplate, schemas.EmailTemplate{}}, + {schemas.Collections.Env, schemas.Env{}}, + {schemas.Collections.OTP, schemas.OTP{}}, + {schemas.Collections.VerificationRequest, schemas.VerificationRequest{}}, + {schemas.Collections.Webhook, schemas.Webhook{}}, + {schemas.Collections.WebhookLog, schemas.WebhookLog{}}, + {schemas.Collections.Authenticators, schemas.Authenticator{}}, + {schemas.Collections.SessionToken, schemas.SessionToken{}}, + {schemas.Collections.MFASession, schemas.MFASession{}}, + {schemas.Collections.OAuthState, schemas.OAuthState{}}, + {schemas.Collections.AuditLog, schemas.AuditLog{}}, + } + for _, tbl := range tables { + if werr := db.CreateTable(tbl.name, tbl.model).WaitWithContext(createCtx); werr != nil { + return nil, fmt.Errorf("dynamodb create/wait table %q: %w", tbl.name, werr) + } + } return &provider{ db: db, config: cfg, dependencies: deps, }, nil } + +// Close is a no-op; the AWS SDK session needs no explicit shutdown for typical use. +func (p *provider) Close() error { + return nil +} diff --git a/internal/storage/db/mongodb/provider.go b/internal/storage/db/mongodb/provider.go index 24466715..273107fc 100644 --- a/internal/storage/db/mongodb/provider.go +++ b/internal/storage/db/mongodb/provider.go @@ -202,3 +202,10 @@ func NewProvider(config *config.Config, deps *Dependencies) (*provider, error) { db: mongodb, }, nil } + +// Close disconnects the MongoDB client. +func (p *provider) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.db.Client().Disconnect(ctx) +} diff --git a/internal/storage/db/provider_template/provider.go b/internal/storage/db/provider_template/provider.go index 10303201..b224bed9 100644 --- a/internal/storage/db/provider_template/provider.go +++ b/internal/storage/db/provider_template/provider.go @@ -46,3 +46,15 @@ func NewProvider( db: sqlDB, }, nil } + +// Close closes the underlying database pool when initialized. +func (p *provider) Close() error { + if p.db == nil { + return nil + } + sqlDB, err := p.db.DB() + if err != nil { + return err + } + return sqlDB.Close() +} diff --git a/internal/storage/db/sql/provider.go b/internal/storage/db/sql/provider.go index 7ee576e0..c9c7eaa0 100644 --- a/internal/storage/db/sql/provider.go +++ b/internal/storage/db/sql/provider.go @@ -120,3 +120,15 @@ func NewProvider( db: sqlDB, }, nil } + +// Close closes the underlying SQL connection pool. +func (p *provider) Close() error { + if p.db == nil { + return nil + } + sqlDB, err := p.db.DB() + if err != nil { + return err + } + return sqlDB.Close() +} diff --git a/internal/storage/provider.go b/internal/storage/provider.go index 7e24706f..bbf1b551 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -18,7 +18,7 @@ import ( "github.com/authorizerdev/authorizer/internal/storage/schemas" ) -// Dependencies struct the data store provider +// Dependencies carries shared resources for constructing a storage Provider. type Dependencies struct { Log *zerolog.Logger } @@ -171,6 +171,9 @@ type Provider interface { // HealthCheck verifies that the storage backend is reachable and responsive. HealthCheck(ctx context.Context) error + + // Close releases resources held by the provider (e.g. database connection pools). + Close() error } // New creates a new database provider based on the configuration diff --git a/internal/storage/provider_test.go b/internal/storage/provider_test.go index 7504257b..cab6f6e6 100644 --- a/internal/storage/provider_test.go +++ b/internal/storage/provider_test.go @@ -112,6 +112,11 @@ func TestStorageProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, provider) + t.Run("HealthCheck", func(t *testing.T) { + err := provider.HealthCheck(ctx) + assert.NoError(t, err, "HealthCheck should succeed when the test database is reachable") + }) + t.Run("Authenticator Operations", func(t *testing.T) { testAuthenticatorOperations(t, ctx, provider) }) diff --git a/internal/utils/gin_context.go b/internal/utils/gin_context.go index 0491cbe5..801fd13a 100644 --- a/internal/utils/gin_context.go +++ b/internal/utils/gin_context.go @@ -7,9 +7,17 @@ import ( "github.com/gin-gonic/gin" ) -// GinContext to get gin context from context +type ginContextKey struct{} + +// ContextWithGin stores c in ctx for GinContextFromContext. Use this instead of ad-hoc +// context keys so lookups stay consistent across HTTP handlers and tests. +func ContextWithGin(ctx context.Context, c *gin.Context) context.Context { + return context.WithValue(ctx, ginContextKey{}, c) +} + +// GinContextFromContext returns the gin.Context previously stored with ContextWithGin. func GinContextFromContext(ctx context.Context) (*gin.Context, error) { - ginContext := ctx.Value("GinContextKey") + ginContext := ctx.Value(ginContextKey{}) if ginContext == nil { err := fmt.Errorf("could not retrieve gin.Context") return nil, err diff --git a/web/app/package-lock.json b/web/app/package-lock.json index 66a2a438..5fd2a7ab 100644 --- a/web/app/package-lock.json +++ b/web/app/package-lock.json @@ -9,7 +9,7 @@ "version": "1.0.0", "license": "ISC", "dependencies": { - "@authorizerdev/authorizer-react": "^2.0.0", + "@authorizerdev/authorizer-react": "^2.0.7", "react": "^18.3.1", "react-dom": "^18.3.1", "react-is": "^18.3.1", @@ -28,12 +28,12 @@ } }, "node_modules/@authorizerdev/authorizer-js": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-3.0.0.tgz", - "integrity": "sha512-g16Knpr7jHDbMaD88sJQhm0eax/khvBq352fegIbCSW2zRlhlfrVIfhj0654fwyaZjX9+xtcbnE60IkIxpguFg==", + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-3.0.4.tgz", + "integrity": "sha512-vkXg1inxC6U2eLra/EQmhTVKzdlpCF4+a93tOfUgSyISYnE8v9np54OAOrs//4aTVOwFTIhahSTvpERKj2NZAQ==", "license": "MIT", "dependencies": { - "cross-fetch": "^3.1.5" + "cross-fetch": "^4.1.0" }, "engines": { "node": ">=16" @@ -43,12 +43,12 @@ } }, "node_modules/@authorizerdev/authorizer-react": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-2.0.0.tgz", - "integrity": "sha512-yZRNh0EFZKMCwEpwuooROqda71yVxotS/grPAYMsXUzjXSgaOABKQmiV7K9J+MmxJfiNeoIrC3n7AiHgbGEuzw==", + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-2.0.7.tgz", + "integrity": "sha512-+qBrbdE6VyljOge1Ad2AmVIpoheymlLsMxCZHW0hnCeKf0R0wJz3MvoKBiaHAcRF/Bf8Wc6+NBCctAnzXjiwMg==", "license": "MIT", "dependencies": { - "@authorizerdev/authorizer-js": "3.0.0", + "@authorizerdev/authorizer-js": "3.0.4", "@storybook/preset-scss": "^1.0.3", "validator": "^13.11.0" }, @@ -1760,9 +1760,9 @@ "license": "MIT" }, "node_modules/cross-fetch": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.2.0.tgz", - "integrity": "sha512-Q+xVJLoGOeIMXZmbUK4HYk+69cQH6LudR0Vu/pRm2YlU/hDV9CiS0gKUMaWY5f2NeUH9C1nV3bsTlCo0FsTV1Q==", + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-4.1.0.tgz", + "integrity": "sha512-uKm5PU+MHTootlWEY+mZ4vvXoCn4fLQxT9dSc1sXVMSFkINTJVN8cAQROpwcKm8bJ/c7rgZVIBWzH5T78sNZZw==", "license": "MIT", "dependencies": { "node-fetch": "^2.7.0" @@ -2295,9 +2295,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", - "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", + "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", "peer": true, diff --git a/web/app/package.json b/web/app/package.json index 38093659..69b0d731 100644 --- a/web/app/package.json +++ b/web/app/package.json @@ -13,7 +13,7 @@ "author": "Lakhan Samani", "license": "ISC", "dependencies": { - "@authorizerdev/authorizer-react": "^2.0.0", + "@authorizerdev/authorizer-react": "^2.0.7", "react": "^18.3.1", "react-dom": "^18.3.1", "react-is": "^18.3.1", diff --git a/web/app/src/App.tsx b/web/app/src/App.tsx index 281f6f24..b863beac 100644 --- a/web/app/src/App.tsx +++ b/web/app/src/App.tsx @@ -32,6 +32,7 @@ export default function App() { ...window['__authorizer__'], ...urlProps, }; + console.log(globalState); return (
diff --git a/web/app/src/Root.tsx b/web/app/src/Root.tsx index 3cfa1f48..18516679 100644 --- a/web/app/src/Root.tsx +++ b/web/app/src/Root.tsx @@ -4,18 +4,6 @@ import { useAuthorizer } from '@authorizerdev/authorizer-react'; import SetupPassword from './pages/setup-password'; import { hasWindow, createRandomString } from './utils/common'; -function isValidRedirectUri(uri: string): boolean { - try { - const url = new URL(uri, window.location.origin); - if (url.origin === window.location.origin) return true; - // Only allow http/https protocols to prevent javascript: or data: URIs - if (url.protocol !== 'http:' && url.protocol !== 'https:') return false; - return false; - } catch { - return false; - } -} - const ResetPassword = lazy(() => import('./pages/rest-password')); const Login = lazy(() => import('./pages/login')); const Dashboard = lazy(() => import('./pages/dashboard')); @@ -82,7 +70,10 @@ export default function Root({ const rawRedirectURL = searchParams.get('redirect_uri') || searchParams.get('redirectURL'); - if (rawRedirectURL && isValidRedirectUri(rawRedirectURL, config?.redirectURL)) { + if ( + rawRedirectURL && + isValidRedirectUri(rawRedirectURL, config?.redirectURL) + ) { urlProps.redirectURL = rawRedirectURL; } else { urlProps.redirectURL = hasWindow() ? window.location.origin : '/';