Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ type AuthorizationService interface {

// SecurityScheme for OpenAPI
type SecurityScheme struct {
Type string `json:"type"`
Scheme string `json:"scheme,omitempty"`
BearerFormat string `json:"bearerFormat,omitempty"`
In string `json:"in,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Flows map[string]interface{} `json:"flows,omitempty"`
Type string `json:"type" yaml:"type"`
Scheme string `json:"scheme,omitempty" yaml:"scheme,omitempty"`
BearerFormat string `json:"bearerFormat,omitempty" yaml:"bearerFormat,omitempty"`
In string `json:"in,omitempty" yaml:"in,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Flows map[string]interface{} `json:"flows,omitempty" yaml:"flows,omitempty"`
}

// GetAuthContext extracts the authentication context from Fiber
Expand Down
92 changes: 92 additions & 0 deletions auth_schemes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,98 @@ func TestPerRouteSecurity_OverridesGlobalDefault(t *testing.T) {
}
}

func TestSecurityScheme_YAMLSpec_CorrectKeys(t *testing.T) {
app := fiber.New()

oapi := New(app, Config{
EnableOpenAPIDocs: true,
SecuritySchemes: map[string]SecurityScheme{
"bearerAuth": {
Type: "http",
Scheme: "bearer",
BearerFormat: "JWT",
Description: "JWT Bearer token",
},
"apiKey": {
Type: "apiKey",
In: "header",
Name: "X-API-Key",
},
},
DefaultSecurity: []map[string][]string{
{"bearerAuth": {}},
},
})

// Register a dummy route so the spec is non-empty
Get(oapi, "/ping", func(c *fiber.Ctx, input struct{}) (fiber.Map, *ErrorResponse) {
return fiber.Map{"ok": true}, nil
}, OpenAPIOptions{Summary: "Ping"})

yamlSpec, err := oapi.GenerateOpenAPISpecYAML()
if err != nil {
t.Fatalf("GenerateOpenAPISpecYAML failed: %v", err)
}

// Bearer scheme: must use camelCase bearerFormat, and must NOT contain in/name/flows
t.Run("bearer scheme uses correct YAML keys", func(t *testing.T) {
if !containsString(yamlSpec, "bearerFormat: JWT") {
t.Error("Expected YAML to contain 'bearerFormat: JWT' (camelCase F)")
}
if containsString(yamlSpec, "bearerformat:") {
t.Error("YAML must not contain lowercase 'bearerformat' (wrong casing)")
}
})

// For type: http, the fields in/name/flows must be omitted
t.Run("http scheme omits apiKey-only fields", func(t *testing.T) {
// Parse the YAML to inspect the bearerAuth scheme specifically
spec := oapi.GenerateOpenAPISpec()
components := spec["components"].(map[string]interface{})
schemes := components["securitySchemes"].(map[string]SecurityScheme)
bearer := schemes["bearerAuth"]

if bearer.In != "" {
t.Errorf("Expected 'in' to be empty for http scheme, got %q", bearer.In)
}
if bearer.Name != "" {
t.Errorf("Expected 'name' to be empty for http scheme, got %q", bearer.Name)
}
if bearer.Flows != nil {
t.Errorf("Expected 'flows' to be nil for http scheme, got %v", bearer.Flows)
}

// Also verify these don't appear in the YAML output for the bearer section
// by checking the full YAML doesn't have empty-valued fields
if containsString(yamlSpec, "in: \"\"") || containsString(yamlSpec, "name: \"\"") || containsString(yamlSpec, "flows: {}") {
t.Error("YAML must not contain empty 'in', 'name', or 'flows' fields for http scheme")
}
})

// API Key scheme: must have in and name
t.Run("apiKey scheme has required fields", func(t *testing.T) {
if !containsString(yamlSpec, "in: header") {
t.Error("Expected YAML to contain 'in: header' for apiKey scheme")
}
if !containsString(yamlSpec, "name: X-API-Key") {
t.Error("Expected YAML to contain 'name: X-API-Key' for apiKey scheme")
}
})
}

func containsString(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && fmt.Sprintf("%s", s) != "" && stringContains(s, substr)
}

func stringContains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

func parseJSONResponse(resp *http.Response, target interface{}) error {
defer resp.Body.Close()
return json.NewDecoder(resp.Body).Decode(target)
Expand Down
Loading