From 84ab8d93fe0247b9346748ab5d1e9bdb08c75146 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Mon, 9 Feb 2026 08:04:17 +1000 Subject: [PATCH 1/7] feat: add autofixer system with fixes for 32 lint rules --- cmd/openapi/commands/openapi/lint.go | 157 ++- cmd/openapi/commands/openapi/list_rules.go | 172 +++ cmd/openapi/commands/openapi/root.go | 1 + linter/doc.go | 2 + linter/fix/engine.go | 243 ++++ linter/fix/engine_test.go | 332 ++++++ linter/fix/registry.go | 56 + linter/fix/registry_test.go | 112 ++ linter/fix/terminal_prompter.go | 139 +++ linter/fix/terminal_prompter_test.go | 239 ++++ linter/format/format_test.go | 130 +++ linter/format/json.go | 4 +- linter/format/summary.go | 105 ++ linter/format/text.go | 7 +- linter/linter.go | 6 + openapi/linter/customrules/go.mod | 2 + openapi/linter/customrules/go.sum | 2 - openapi/linter/customrules/jsfix.go | 164 +++ openapi/linter/customrules/jsfix_test.go | 149 +++ openapi/linter/customrules/loader.go | 2 +- openapi/linter/customrules/runtime.go | 63 +- openapi/linter/customrules/shim/types-shim.js | 4 +- openapi/linter/rules/component_description.go | 104 +- openapi/linter/rules/contact_properties.go | 41 +- .../linter/rules/duplicated_entry_in_enum.go | 43 +- openapi/linter/rules/fix_available.go | 37 + openapi/linter/rules/fix_helpers.go | 574 ++++++++++ openapi/linter/rules/fix_helpers_test.go | 1014 +++++++++++++++++ openapi/linter/rules/fix_integration_test.go | 516 +++++++++ openapi/linter/rules/host_not_example.go | 13 +- openapi/linter/rules/host_trailing_slash.go | 34 +- openapi/linter/rules/info_contact.go | 14 +- openapi/linter/rules/info_description.go | 51 +- openapi/linter/rules/info_license.go | 14 +- openapi/linter/rules/license_url.go | 14 +- openapi/linter/rules/oas3_api_servers.go | 13 +- openapi/linter/rules/oas3_no_nullable.go | 70 +- openapi/linter/rules/operation_description.go | 14 +- openapi/linter/rules/operation_tag_defined.go | 43 +- openapi/linter/rules/operation_tags.go | 14 +- ...owasp_additional_properties_constrained.go | 13 +- openapi/linter/rules/owasp_array_limit.go | 13 +- .../rules/owasp_define_error_responses_401.go | 13 +- .../rules/owasp_define_error_responses_429.go | 13 +- .../rules/owasp_define_error_responses_500.go | 13 +- .../rules/owasp_define_error_validation.go | 13 +- openapi/linter/rules/owasp_integer_format.go | 13 +- openapi/linter/rules/owasp_integer_limit.go | 13 +- .../linter/rules/owasp_jwt_best_practices.go | 61 +- .../rules/owasp_no_additional_properties.go | 43 +- .../rules/owasp_rate_limit_retry_after.go | 31 +- .../rules/owasp_security_hosts_https_oas3.go | 32 +- openapi/linter/rules/owasp_string_limit.go | 13 +- openapi/linter/rules/parameter_description.go | 14 +- openapi/linter/rules/path_trailing_slash.go | 35 +- openapi/linter/rules/rule_fixes_test.go | 563 +++++++++ openapi/linter/rules/tag_description.go | 14 +- openapi/linter/rules/tags_alphabetical.go | 51 +- openapi/linter/rules/unused_components.go | 86 +- validation/errors.go | 6 - validation/fix.go | 70 ++ validation/prompter.go | 21 + 62 files changed, 5563 insertions(+), 295 deletions(-) create mode 100644 cmd/openapi/commands/openapi/list_rules.go create mode 100644 linter/fix/engine.go create mode 100644 linter/fix/engine_test.go create mode 100644 linter/fix/registry.go create mode 100644 linter/fix/registry_test.go create mode 100644 linter/fix/terminal_prompter.go create mode 100644 linter/fix/terminal_prompter_test.go create mode 100644 linter/format/summary.go create mode 100644 openapi/linter/customrules/jsfix.go create mode 100644 openapi/linter/customrules/jsfix_test.go create mode 100644 openapi/linter/rules/fix_available.go create mode 100644 openapi/linter/rules/fix_helpers.go create mode 100644 openapi/linter/rules/fix_helpers_test.go create mode 100644 openapi/linter/rules/fix_integration_test.go create mode 100644 openapi/linter/rules/rule_fixes_test.go create mode 100644 validation/fix.go create mode 100644 validation/prompter.go diff --git a/cmd/openapi/commands/openapi/lint.go b/cmd/openapi/commands/openapi/lint.go index d3011696..76d11dae 100644 --- a/cmd/openapi/commands/openapi/lint.go +++ b/cmd/openapi/commands/openapi/lint.go @@ -7,6 +7,7 @@ import ( "path/filepath" "github.com/speakeasy-api/openapi/linter" + "github.com/speakeasy-api/openapi/linter/fix" "github.com/speakeasy-api/openapi/openapi" openapiLinter "github.com/speakeasy-api/openapi/openapi/linter" "github.com/spf13/cobra" @@ -56,17 +57,28 @@ in your rules directory: Then configure the paths in your lint.yaml under custom_rules.paths. +AUTOFIXING: + +Use --fix to automatically apply non-interactive fixes. Use --fix-interactive to +also be prompted for fixes that require user input (choosing values, entering text). +Use --dry-run with either flag to preview what would be changed without modifying the file. + See the full documentation at: https://github.com/speakeasy-api/openapi/blob/main/cmd/openapi/commands/openapi/README.md#lint`, - Args: cobra.ExactArgs(1), - Run: runLint, + Args: cobra.ExactArgs(1), + PreRunE: validateLintFlags, + Run: runLint, } var ( - lintOutputFormat string - lintRuleset string - lintConfigFile string - lintDisableRules []string + lintOutputFormat string + lintRuleset string + lintConfigFile string + lintDisableRules []string + lintSummary bool + lintFix bool + lintFixInteractive bool + lintDryRun bool ) func init() { @@ -74,6 +86,20 @@ func init() { lintCmd.Flags().StringVarP(&lintRuleset, "ruleset", "r", "all", "Ruleset to use (default loads from config)") lintCmd.Flags().StringVarP(&lintConfigFile, "config", "c", "", "Path to lint config file (default: ~/.openapi/lint.yaml)") lintCmd.Flags().StringSliceVarP(&lintDisableRules, "disable", "d", nil, "Rule IDs to disable (can be repeated)") + lintCmd.Flags().BoolVar(&lintSummary, "summary", false, "Print a per-rule summary table of findings") + lintCmd.Flags().BoolVar(&lintFix, "fix", false, "Automatically apply non-interactive fixes and write back") + lintCmd.Flags().BoolVar(&lintFixInteractive, "fix-interactive", false, "Apply all fixes, prompting for interactive ones") + lintCmd.Flags().BoolVar(&lintDryRun, "dry-run", false, "Show what fixes would be applied without changing the file (requires --fix or --fix-interactive)") +} + +func validateLintFlags(_ *cobra.Command, _ []string) error { + if lintFix && lintFixInteractive { + return fmt.Errorf("--fix and --fix-interactive are mutually exclusive") + } + if lintDryRun && !lintFix && !lintFixInteractive { + return fmt.Errorf("--dry-run requires --fix or --fix-interactive") + } + return nil } func runLint(cmd *cobra.Command, args []string) { @@ -126,6 +152,42 @@ func lintOpenAPI(ctx context.Context, file string) error { return fmt.Errorf("linting failed: %w", err) } + // Determine fix mode + fixOpts := fix.Options{Mode: fix.ModeNone, DryRun: lintDryRun} + switch { + case lintFixInteractive: + fixOpts.Mode = fix.ModeInteractive + case lintFix: + fixOpts.Mode = fix.ModeAuto + } + + if fixOpts.Mode != fix.ModeNone { + if err := applyFixes(ctx, fixOpts, doc, output, cleanFile); err != nil { + return err + } + + // Re-lint after applying fixes (unless dry-run) to get accurate remaining count + if !lintDryRun { + // Reload and re-lint the fixed document + reloadedF, err := os.Open(cleanFile) + if err != nil { + return fmt.Errorf("failed to reopen file after fix: %w", err) + } + defer reloadedF.Close() + + reloadedDoc, reloadedValErrs, err := openapi.Unmarshal(ctx, reloadedF) + if err != nil { + return fmt.Errorf("failed to unmarshal fixed file: %w", err) + } + + reloadedDocInfo := linter.NewDocumentInfo(reloadedDoc, absPath) + output, err = lint.Lint(ctx, reloadedDocInfo, reloadedValErrs, nil) + if err != nil { + return fmt.Errorf("re-linting failed: %w", err) + } + } + } + // Format and print output switch lintOutputFormat { case "json": @@ -135,6 +197,11 @@ func lintOpenAPI(ctx context.Context, file string) error { fmt.Println(output.FormatText()) } + // Print per-rule summary if requested + if lintSummary { + fmt.Println(output.FormatSummary()) + } + // Exit with error code if there are errors if output.HasErrors() { return fmt.Errorf("linting found %d errors", output.ErrorCount()) @@ -143,6 +210,84 @@ func lintOpenAPI(ctx context.Context, file string) error { return nil } +func applyFixes(ctx context.Context, fixOpts fix.Options, doc *openapi.OpenAPI, output *linter.Output, cleanFile string) error { + // Create prompter for interactive mode + var prompter *fix.TerminalPrompter + if fixOpts.Mode == fix.ModeInteractive { + prompter = fix.NewTerminalPrompter(os.Stdin, os.Stderr) + } + + engine := fix.NewEngine(fixOpts, prompter, nil) + result, err := engine.ProcessErrors(ctx, doc, output.Results) + if err != nil { + return fmt.Errorf("fix processing failed: %w", err) + } + + // Report fix results to stderr + reportFixResults(result, fixOpts.DryRun) + + // Write modified document back if any fixes were applied (and not dry-run) + if len(result.Applied) > 0 && !fixOpts.DryRun { + processor, err := NewOpenAPIProcessor(cleanFile, "", true) + if err != nil { + return fmt.Errorf("failed to create processor: %w", err) + } + if err := processor.WriteDocument(ctx, doc); err != nil { + return fmt.Errorf("failed to write fixed document: %w", err) + } + fmt.Fprintf(os.Stderr, "Applied %d fix(es) to %s\n", len(result.Applied), cleanFile) + } + + return nil +} + +func reportFixResults(result *fix.Result, dryRun bool) { + prefix := "" + if dryRun { + prefix = "[dry-run] " + } + + if len(result.Applied) > 0 { + fmt.Fprintf(os.Stderr, "\n%sFixed:\n", prefix) + for _, af := range result.Applied { + fmt.Fprintf(os.Stderr, " [%d:%d] %s - %s\n", + af.Error.GetLineNumber(), af.Error.GetColumnNumber(), + af.Error.Rule, af.Fix.Description()) + } + } + + if len(result.Skipped) > 0 { + fmt.Fprintf(os.Stderr, "\n%sSkipped:\n", prefix) + for _, sf := range result.Skipped { + fmt.Fprintf(os.Stderr, " [%d:%d] %s - %s (%s)\n", + sf.Error.GetLineNumber(), sf.Error.GetColumnNumber(), + sf.Error.Rule, sf.Fix.Description(), skipReasonString(sf.Reason)) + } + } + + if len(result.Failed) > 0 { + fmt.Fprintf(os.Stderr, "\n%sFailed:\n", prefix) + for _, ff := range result.Failed { + fmt.Fprintf(os.Stderr, " [%d:%d] %s - %s: %v\n", + ff.Error.GetLineNumber(), ff.Error.GetColumnNumber(), + ff.Error.Rule, ff.Fix.Description(), ff.FixError) + } + } +} + +func skipReasonString(reason fix.SkipReason) string { + switch reason { + case fix.SkipInteractive: + return "requires interactive input" + case fix.SkipConflict: + return "conflict with previous fix" + case fix.SkipUser: + return "skipped by user" + default: + return "unknown" + } +} + func buildLintConfig() *linter.Config { config := linter.NewConfig() diff --git a/cmd/openapi/commands/openapi/list_rules.go b/cmd/openapi/commands/openapi/list_rules.go new file mode 100644 index 00000000..298a8c87 --- /dev/null +++ b/cmd/openapi/commands/openapi/list_rules.go @@ -0,0 +1,172 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "os" + "strings" + "text/tabwriter" + + "github.com/speakeasy-api/openapi/linter" + openapiLinter "github.com/speakeasy-api/openapi/openapi/linter" + "github.com/spf13/cobra" +) + +var listRulesCmd = &cobra.Command{ + Use: "list-rules", + Short: "List all available linting rules", + Long: `List all available linting rules with their metadata. + +Shows each rule's ID, category, default severity, description, and fix guidance. +Use --category to filter by category, or --ruleset to show only rules in a ruleset. + +Examples: + openapi spec list-rules + openapi spec list-rules --category security + openapi spec list-rules --ruleset recommended + openapi spec list-rules --format json`, + Run: runListRules, +} + +var ( + listRulesFormat string + listRulesCategory string + listRulesRuleset string +) + +func init() { + listRulesCmd.Flags().StringVarP(&listRulesFormat, "format", "f", "text", "Output format: text or json") + listRulesCmd.Flags().StringVar(&listRulesCategory, "category", "", "Filter by category (e.g., security, style, semantic)") + listRulesCmd.Flags().StringVar(&listRulesRuleset, "ruleset", "", "Filter by ruleset (e.g., recommended, security, all)") +} + +// howToFixer is the interface satisfied by rules that provide fix guidance. +type howToFixer interface { + HowToFix() string +} + +type ruleInfo struct { + ID string `json:"id"` + Category string `json:"category"` + DefaultSeverity string `json:"defaultSeverity"` + Summary string `json:"summary"` + Description string `json:"description"` + HowToFix string `json:"howToFix,omitempty"` + FixAvailable bool `json:"fixAvailable,omitempty"` + Link string `json:"link,omitempty"` + Rulesets []string `json:"rulesets"` +} + +func runListRules(cmd *cobra.Command, _ []string) { + config := linter.NewConfig() + lint, err := openapiLinter.NewLinter(config) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + registry := lint.Registry() + allRules := registry.AllRules() + + var infos []ruleInfo + for _, rule := range allRules { + // Apply category filter + if listRulesCategory != "" && rule.Category() != listRulesCategory { + continue + } + + // Apply ruleset filter + if listRulesRuleset != "" { + rulesets := registry.RulesetsContaining(rule.ID()) + found := false + for _, rs := range rulesets { + if rs == listRulesRuleset { + found = true + break + } + } + if !found { + continue + } + } + + info := ruleInfo{ + ID: rule.ID(), + Category: rule.Category(), + DefaultSeverity: rule.DefaultSeverity().String(), + Summary: rule.Summary(), + Description: rule.Description(), + Link: rule.Link(), + Rulesets: registry.RulesetsContaining(rule.ID()), + } + + if fixer, ok := rule.(howToFixer); ok { + info.HowToFix = fixer.HowToFix() + } + + if fixable, ok := rule.(interface{ FixAvailable() bool }); ok { + info.FixAvailable = fixable.FixAvailable() + } + + infos = append(infos, info) + } + + switch listRulesFormat { + case "json": + printRulesJSON(infos) + default: + printRulesText(cmd, infos, registry.AllCategories()) + } +} + +func printRulesText(_ *cobra.Command, infos []ruleInfo, categories []string) { + if len(infos) == 0 { + fmt.Println("No rules found matching the specified filters.") + return + } + + // Group by category + byCategory := make(map[string][]ruleInfo) + for _, info := range infos { + byCategory[info.Category] = append(byCategory[info.Category], info) + } + + // Print in category order + for _, cat := range categories { + rules, ok := byCategory[cat] + if !ok { + continue + } + + fmt.Printf("\n%s (%d rules)\n", strings.ToUpper(cat), len(rules)) + fmt.Println(strings.Repeat("─", 80)) + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + for _, info := range rules { + fixMarker := "" + if info.FixAvailable { + fixMarker = " [fixable]" + } + fmt.Fprintf(w, " %s\t%s\t[%s]%s\n", info.ID, info.Summary, info.DefaultSeverity, fixMarker) + if info.HowToFix != "" { + fmt.Fprintf(w, " \tFix: %s\n", info.HowToFix) + } + if info.Link != "" { + fmt.Fprintf(w, " \tDocs: %s\n", info.Link) + } + fmt.Fprintf(w, " \tRulesets: %s\n", strings.Join(info.Rulesets, ", ")) + } + w.Flush() + } + + fmt.Printf("\n%d rules total\n", len(infos)) +} + +func printRulesJSON(infos []ruleInfo) { + bytes, err := json.MarshalIndent(infos, "", " ") + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + fmt.Println(string(bytes)) +} diff --git a/cmd/openapi/commands/openapi/root.go b/cmd/openapi/commands/openapi/root.go index f2619205..5f4c614f 100644 --- a/cmd/openapi/commands/openapi/root.go +++ b/cmd/openapi/commands/openapi/root.go @@ -6,6 +6,7 @@ import "github.com/spf13/cobra" func Apply(rootCmd *cobra.Command) { rootCmd.AddCommand(validateCmd) rootCmd.AddCommand(lintCmd) + rootCmd.AddCommand(listRulesCmd) rootCmd.AddCommand(upgradeCmd) rootCmd.AddCommand(inlineCmd) rootCmd.AddCommand(cleanCmd) diff --git a/linter/doc.go b/linter/doc.go index efb743f8..77465e7b 100644 --- a/linter/doc.go +++ b/linter/doc.go @@ -54,6 +54,8 @@ func (g *DocGenerator[T]) GenerateRuleDoc(rule RuleRunner[T]) *RuleDoc { doc.BadExample = documented.BadExample() doc.Rationale = documented.Rationale() doc.FixAvailable = documented.FixAvailable() + } else if fixable, ok := any(rule).(interface{ FixAvailable() bool }); ok { + doc.FixAvailable = fixable.FixAvailable() } // Check for configuration interface diff --git a/linter/fix/engine.go b/linter/fix/engine.go new file mode 100644 index 00000000..6d236075 --- /dev/null +++ b/linter/fix/engine.go @@ -0,0 +1,243 @@ +package fix + +import ( + "context" + "errors" + "sort" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/validation" + "gopkg.in/yaml.v3" +) + +// Mode controls how fixes are applied. +type Mode int + +const ( + // ModeNone means no fixing (normal lint). + ModeNone Mode = iota + // ModeAuto applies only non-interactive fixes. + ModeAuto + // ModeInteractive applies all fixes, prompting for interactive ones. + ModeInteractive +) + +// Options configures fix engine behavior. +type Options struct { + // Mode controls which fixes are applied. + Mode Mode + // DryRun when true reports what would be fixed without applying changes. + // Acts as a modifier on ModeAuto or ModeInteractive. + DryRun bool +} + +// SkipReason explains why a fix was skipped. +type SkipReason int + +const ( + // SkipInteractive means the fix requires user input but the mode is non-interactive. + SkipInteractive SkipReason = iota + // SkipConflict means another fix already modified the same location. + SkipConflict + // SkipUser means the user chose to skip the fix in interactive mode. + SkipUser +) + +// AppliedFix records a successfully applied fix. +type AppliedFix struct { + Error *validation.Error + Fix validation.Fix +} + +// SkippedFix records a fix that was skipped. +type SkippedFix struct { + Error *validation.Error + Fix validation.Fix + Reason SkipReason +} + +// FailedFix records a fix that failed to apply. +type FailedFix struct { + Error *validation.Error + Fix validation.Fix + FixError error +} + +// Result tracks what the engine did. +type Result struct { + Applied []AppliedFix + Skipped []SkippedFix + Failed []FailedFix +} + +// Engine applies fixes to an OpenAPI document. +type Engine struct { + opts Options + prompter validation.Prompter + registry *FixRegistry +} + +// NewEngine creates a new fix engine. +func NewEngine(opts Options, prompter validation.Prompter, registry *FixRegistry) *Engine { + return &Engine{ + opts: opts, + prompter: prompter, + registry: registry, + } +} + +// conflictKey identifies a document location for conflict detection. +type conflictKey struct { + Line int + Column int +} + +// ProcessErrors takes lint output errors and applies fixes where available. +// The doc is modified in-place by successful fixes. +func (e *Engine) ProcessErrors(ctx context.Context, doc *openapi.OpenAPI, errs []error) (*Result, error) { + if e.opts.Mode == ModeNone { + return &Result{}, nil + } + + // Collect fixable errors + type fixableError struct { + vErr *validation.Error + fix validation.Fix + } + + var fixable []fixableError + + for _, err := range errs { + var vErr *validation.Error + if !errors.As(err, &vErr) { + continue + } + + fix := vErr.Fix + + // If no fix attached to the error, check the registry + if fix == nil && e.registry != nil { + fix = e.registry.GetFix(vErr) + } + + if fix != nil { + fixable = append(fixable, fixableError{vErr: vErr, fix: fix}) + } + } + + if len(fixable) == 0 { + return &Result{}, nil + } + + // Sort by document location for deterministic ordering + sort.Slice(fixable, func(i, j int) bool { + li, ci := fixable[i].vErr.GetLineNumber(), fixable[i].vErr.GetColumnNumber() + lj, cj := fixable[j].vErr.GetLineNumber(), fixable[j].vErr.GetColumnNumber() + if li != lj { + return li < lj + } + return ci < cj + }) + + result := &Result{} + modified := make(map[conflictKey]bool) + + for _, fe := range fixable { + fix := fe.fix + vErr := fe.vErr + + // Check for conflicts at the same location + key := conflictKey{Line: vErr.GetLineNumber(), Column: vErr.GetColumnNumber()} + if key.Line >= 0 && modified[key] { + result.Skipped = append(result.Skipped, SkippedFix{ + Error: vErr, + Fix: fix, + Reason: SkipConflict, + }) + continue + } + + // Skip interactive fixes in auto mode or when no prompter is available + if fix.Interactive() && (e.opts.Mode == ModeAuto || e.prompter == nil) { + result.Skipped = append(result.Skipped, SkippedFix{ + Error: vErr, + Fix: fix, + Reason: SkipInteractive, + }) + continue + } + + // Dry-run: record what would happen without applying + if e.opts.DryRun { + result.Applied = append(result.Applied, AppliedFix{Error: vErr, Fix: fix}) + if key.Line >= 0 { + modified[key] = true + } + continue + } + + // Handle interactive input + if fix.Interactive() && e.prompter != nil { + responses, err := e.prompter.PromptFix(vErr, fix) + if err != nil { + if errors.Is(err, validation.ErrSkipFix) { + result.Skipped = append(result.Skipped, SkippedFix{ + Error: vErr, + Fix: fix, + Reason: SkipUser, + }) + continue + } + result.Failed = append(result.Failed, FailedFix{ + Error: vErr, Fix: fix, FixError: err, + }) + continue + } + + if err := fix.SetInput(responses); err != nil { + result.Failed = append(result.Failed, FailedFix{ + Error: vErr, Fix: fix, FixError: err, + }) + continue + } + } + + // Apply the fix + var applyErr error + if nodeFix, ok := fix.(validation.NodeFix); ok { + rootNode := doc.GetRootNode() + if rootNode != nil { + applyErr = nodeFix.ApplyNode(rootNode) + } else { + applyErr = fix.Apply(doc) + } + } else { + applyErr = fix.Apply(doc) + } + + if applyErr != nil { + result.Failed = append(result.Failed, FailedFix{ + Error: vErr, Fix: fix, FixError: applyErr, + }) + continue + } + + // Mark location as modified for conflict detection + if key.Line >= 0 { + modified[key] = true + } + + result.Applied = append(result.Applied, AppliedFix{Error: vErr, Fix: fix}) + } + + return result, nil +} + +// ApplyNodeFix is a helper that applies a NodeFix if the fix implements the interface, +// otherwise falls back to Apply. +func ApplyNodeFix(fix validation.Fix, doc *openapi.OpenAPI, rootNode *yaml.Node) error { + if nodeFix, ok := fix.(validation.NodeFix); ok && rootNode != nil { + return nodeFix.ApplyNode(rootNode) + } + return fix.Apply(doc) +} diff --git a/linter/fix/engine_test.go b/linter/fix/engine_test.go new file mode 100644 index 00000000..3c57f913 --- /dev/null +++ b/linter/fix/engine_test.go @@ -0,0 +1,332 @@ +package fix_test + +import ( + "errors" + "testing" + + "github.com/speakeasy-api/openapi/linter/fix" + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// mockFix is a non-interactive fix for testing. +type mockFix struct { + description string + applied bool + applyErr error +} + +func (f *mockFix) Description() string { return f.description } +func (f *mockFix) Interactive() bool { return false } +func (f *mockFix) Prompts() []validation.Prompt { return nil } +func (f *mockFix) SetInput([]string) error { return nil } +func (f *mockFix) Apply(doc any) error { + f.applied = true + return f.applyErr +} + +// mockInteractiveFix is an interactive fix for testing. +type mockInteractiveFix struct { + description string + prompts []validation.Prompt + applied bool + inputs []string +} + +func (f *mockInteractiveFix) Description() string { return f.description } +func (f *mockInteractiveFix) Interactive() bool { return true } +func (f *mockInteractiveFix) Prompts() []validation.Prompt { return f.prompts } +func (f *mockInteractiveFix) SetInput(responses []string) error { + f.inputs = responses + return nil +} +func (f *mockInteractiveFix) Apply(doc any) error { + f.applied = true + return nil +} + +// mockNodeFix implements NodeFix for testing. +type mockNodeFix struct { + mockFix + nodeApplied bool +} + +func (f *mockNodeFix) ApplyNode(rootNode *yaml.Node) error { + f.nodeApplied = true + return nil +} + +// mockPrompter is a test prompter that returns predefined responses. +type mockPrompter struct { + responses []string + err error + called bool +} + +func (p *mockPrompter) PromptFix(_ *validation.Error, _ validation.Fix) ([]string, error) { + p.called = true + return p.responses, p.err +} + +func (p *mockPrompter) Confirm(_ string) (bool, error) { + return true, nil +} + +func makeError(rule string, line, col int, msg string, f validation.Fix) error { + return &validation.Error{ + UnderlyingError: errors.New(msg), + Node: &yaml.Node{Line: line, Column: col}, + Severity: validation.SeverityWarning, + Rule: rule, + Fix: f, + } +} + +func TestEngine_ModeNone(t *testing.T) { + t.Parallel() + + engine := fix.NewEngine(fix.Options{Mode: fix.ModeNone}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "some error", &mockFix{description: "fix it"}), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply any fixes in ModeNone") + assert.Empty(t, result.Skipped, "should not skip any fixes in ModeNone") + assert.Empty(t, result.Failed, "should not fail any fixes in ModeNone") +} + +func TestEngine_ModeAuto_NonInteractive(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "auto fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply the non-interactive fix") + assert.True(t, f.applied, "fix should have been applied") +} + +func TestEngine_ModeAuto_SkipsInteractive(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + } + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply interactive fix in auto mode") + assert.Len(t, result.Skipped, 1, "should skip the interactive fix") + assert.False(t, f.applied, "fix should not have been applied") +} + +func TestEngine_ModeInteractive_PromptsUser(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + } + prompter := &mockPrompter{responses: []string{"user answer"}} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, prompter, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.True(t, prompter.called, "prompter should have been called") + assert.Len(t, result.Applied, 1, "should apply the fix after prompting") + assert.True(t, f.applied, "fix should have been applied") + assert.Equal(t, []string{"user answer"}, f.inputs, "fix should have received user input") +} + +func TestEngine_ModeInteractive_UserSkips(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + } + prompter := &mockPrompter{err: validation.ErrSkipFix} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, prompter, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply skipped fix") + assert.Len(t, result.Skipped, 1, "should record skipped fix") + assert.False(t, f.applied, "fix should not have been applied") +} + +func TestEngine_DryRun(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "auto fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto, DryRun: true}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should record the fix as would-apply") + assert.False(t, f.applied, "fix should NOT have been actually applied in dry-run") +} + +func TestEngine_ConflictDetection(t *testing.T) { + t.Parallel() + + f1 := &mockFix{description: "first fix"} + f2 := &mockFix{description: "second fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("rule-a", 5, 3, "issue 1", f1), + makeError("rule-b", 5, 3, "issue 2", f2), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply the first fix") + assert.Len(t, result.Skipped, 1, "should skip the second fix as conflict") + assert.True(t, f1.applied, "first fix should have been applied") + assert.False(t, f2.applied, "second fix should not have been applied") +} + +func TestEngine_FailedFix(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "broken fix", applyErr: errors.New("fix failed")} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail even when fixes fail") + assert.Empty(t, result.Applied, "should not record failed fix as applied") + assert.Len(t, result.Failed, 1, "should record the failed fix") +} + +func TestEngine_NodeFix_FallsBackToApply(t *testing.T) { + t.Parallel() + + // When doc.GetRootNode() returns nil, NodeFix falls back to Apply() + f := &mockNodeFix{mockFix: mockFix{description: "node fix"}} + doc := &openapi.OpenAPI{} + + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), doc, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply the fix") + assert.False(t, f.nodeApplied, "ApplyNode should not be called when root node is nil") + assert.True(t, f.applied, "Apply should be called as fallback") +} + +func TestEngine_RegistryFix(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "registry fix"} + registry := fix.NewFixRegistry() + registry.Register("validation-empty-value", func(_ *validation.Error) validation.Fix { + return f + }) + + // Error without a fix attached, but registry provides one + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, registry) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("validation-empty-value", 1, 1, "empty value", nil), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply the registry-provided fix") + assert.True(t, f.applied, "fix should have been applied") +} + +func TestEngine_NoFixableErrors(t *testing.T) { + t.Parallel() + + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue without fix", nil), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should have no applied fixes") + assert.Empty(t, result.Skipped, "should have no skipped fixes") + assert.Empty(t, result.Failed, "should have no failed fixes") +} + +func TestEngine_ModeInteractive_SkipsWhenNoPrompter(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + } + // Interactive mode but nil prompter — interactive fixes should be skipped + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply interactive fix without prompter") + assert.Len(t, result.Skipped, 1, "should skip the interactive fix") + assert.Equal(t, fix.SkipInteractive, result.Skipped[0].Reason, "skip reason should be SkipInteractive") + assert.False(t, f.applied, "fix should not have been applied") +} + +func TestEngine_ModeInteractive_NonInteractiveFixAppliesWithoutPrompter(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "auto fix"} + // Interactive mode but nil prompter — non-interactive fixes should still apply + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply non-interactive fix even without prompter") + assert.True(t, f.applied, "fix should have been applied") +} + +func TestEngine_SortsByLocation(t *testing.T) { + t.Parallel() + + var order []string + makeFix := func(name string) *mockFix { + return &mockFix{description: name} + } + + f1 := makeFix("fix-line10") + f2 := makeFix("fix-line2") + f3 := makeFix("fix-line5") + + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("rule", 10, 1, "issue at line 10", f1), + makeError("rule", 2, 1, "issue at line 2", f2), + makeError("rule", 5, 1, "issue at line 5", f3), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 3, "should apply all fixes") + + for _, af := range result.Applied { + order = append(order, af.Fix.Description()) + } + assert.Equal(t, []string{"fix-line2", "fix-line5", "fix-line10"}, order, "fixes should be applied in location order") +} diff --git a/linter/fix/registry.go b/linter/fix/registry.go new file mode 100644 index 00000000..8e94b342 --- /dev/null +++ b/linter/fix/registry.go @@ -0,0 +1,56 @@ +package fix + +import ( + "sync" + + "github.com/speakeasy-api/openapi/validation" +) + +// FixProvider creates a Fix for a specific validation error. +// It receives the error and returns a Fix, or nil if no fix is applicable +// for this particular error instance. +type FixProvider func(err *validation.Error) validation.Fix + +// FixRegistry maps validation rule IDs to fix providers. +// This allows registering fix providers for pre-existing validation errors +// that don't come from linter rules (e.g., errors from unmarshal/indexing). +type FixRegistry struct { + mu sync.RWMutex + providers map[string][]FixProvider +} + +// NewFixRegistry creates a new fix registry. +func NewFixRegistry() *FixRegistry { + return &FixRegistry{ + providers: make(map[string][]FixProvider), + } +} + +// Register registers a fix provider for a validation rule ID. +// Multiple providers can be registered for the same rule ID; the first one +// that returns a non-nil Fix wins. +func (r *FixRegistry) Register(ruleID string, provider FixProvider) { + r.mu.Lock() + defer r.mu.Unlock() + r.providers[ruleID] = append(r.providers[ruleID], provider) +} + +// GetFix returns a Fix for the given validation error, or nil if no provider +// can fix it. +func (r *FixRegistry) GetFix(err *validation.Error) validation.Fix { + r.mu.RLock() + defer r.mu.RUnlock() + + providers, ok := r.providers[err.Rule] + if !ok { + return nil + } + + for _, provider := range providers { + if fix := provider(err); fix != nil { + return fix + } + } + + return nil +} diff --git a/linter/fix/registry_test.go b/linter/fix/registry_test.go new file mode 100644 index 00000000..fdb6698e --- /dev/null +++ b/linter/fix/registry_test.go @@ -0,0 +1,112 @@ +package fix_test + +import ( + "errors" + "testing" + + "github.com/speakeasy-api/openapi/linter/fix" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFixRegistry_GetFix_Success(t *testing.T) { + t.Parallel() + + registry := fix.NewFixRegistry() + expectedFix := &mockFix{description: "test fix"} + + registry.Register("validation-empty-value", func(_ *validation.Error) validation.Fix { + return expectedFix + }) + + vErr := &validation.Error{ + UnderlyingError: errors.New("empty value"), + Rule: "validation-empty-value", + } + + result := registry.GetFix(vErr) + require.NotNil(t, result, "should return a fix") + assert.Equal(t, "test fix", result.Description(), "should return the registered fix") +} + +func TestFixRegistry_GetFix_NoProvider(t *testing.T) { + t.Parallel() + + registry := fix.NewFixRegistry() + + vErr := &validation.Error{ + UnderlyingError: errors.New("some error"), + Rule: "unknown-rule", + } + + result := registry.GetFix(vErr) + assert.Nil(t, result, "should return nil for unregistered rules") +} + +func TestFixRegistry_GetFix_ProviderReturnsNil(t *testing.T) { + t.Parallel() + + registry := fix.NewFixRegistry() + registry.Register("test-rule", func(_ *validation.Error) validation.Fix { + return nil + }) + + vErr := &validation.Error{ + UnderlyingError: errors.New("test"), + Rule: "test-rule", + } + + result := registry.GetFix(vErr) + assert.Nil(t, result, "should return nil when provider returns nil") +} + +func TestFixRegistry_GetFix_MultipleProviders(t *testing.T) { + t.Parallel() + + registry := fix.NewFixRegistry() + + // First provider returns nil + registry.Register("test-rule", func(_ *validation.Error) validation.Fix { + return nil + }) + + // Second provider returns a fix + expectedFix := &mockFix{description: "second provider fix"} + registry.Register("test-rule", func(_ *validation.Error) validation.Fix { + return expectedFix + }) + + vErr := &validation.Error{ + UnderlyingError: errors.New("test"), + Rule: "test-rule", + } + + result := registry.GetFix(vErr) + require.NotNil(t, result, "should return fix from second provider") + assert.Equal(t, "second provider fix", result.Description(), "first non-nil provider wins") +} + +func TestFixRegistry_GetFix_FirstProviderWins(t *testing.T) { + t.Parallel() + + registry := fix.NewFixRegistry() + + firstFix := &mockFix{description: "first"} + registry.Register("test-rule", func(_ *validation.Error) validation.Fix { + return firstFix + }) + + registry.Register("test-rule", func(_ *validation.Error) validation.Fix { + return &mockFix{description: "second"} + }) + + vErr := &validation.Error{ + UnderlyingError: errors.New("test"), + Rule: "test-rule", + } + + result := registry.GetFix(vErr) + require.NotNil(t, result, "should return a fix") + assert.Equal(t, "first", result.Description(), "first non-nil provider should win") +} diff --git a/linter/fix/terminal_prompter.go b/linter/fix/terminal_prompter.go new file mode 100644 index 00000000..9773f7a1 --- /dev/null +++ b/linter/fix/terminal_prompter.go @@ -0,0 +1,139 @@ +package fix + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" + + "github.com/speakeasy-api/openapi/validation" +) + +// TerminalPrompter implements Prompter using stdin/stdout for terminal interaction. +type TerminalPrompter struct { + reader *bufio.Reader + writer io.Writer +} + +// NewTerminalPrompter creates a new terminal-based prompter. +func NewTerminalPrompter(in io.Reader, out io.Writer) *TerminalPrompter { + return &TerminalPrompter{ + reader: bufio.NewReader(in), + writer: out, + } +} + +// writef writes formatted output to the prompter's writer, ignoring write errors +// since terminal output failures are not recoverable. +func (p *TerminalPrompter) writef(format string, args ...any) { + _, _ = fmt.Fprintf(p.writer, format, args...) +} + +func (p *TerminalPrompter) PromptFix(finding *validation.Error, fix validation.Fix) ([]string, error) { + // Display context about the error + p.writef("\n[%d:%d] %s %s\n", finding.GetLineNumber(), finding.GetColumnNumber(), finding.Rule, finding.UnderlyingError.Error()) + p.writef(" Fix: %s\n", fix.Description()) + + prompts := fix.Prompts() + responses := make([]string, len(prompts)) + + for i, prompt := range prompts { + response, err := p.promptOne(prompt) + if err != nil { + return nil, err + } + responses[i] = response + } + + return responses, nil +} + +func (p *TerminalPrompter) promptOne(prompt validation.Prompt) (string, error) { + switch prompt.Type { + case validation.PromptChoice: + return p.promptChoice(prompt) + case validation.PromptFreeText: + return p.promptFreeText(prompt) + default: + return "", fmt.Errorf("unknown prompt type: %d", prompt.Type) + } +} + +func (p *TerminalPrompter) promptChoice(prompt validation.Prompt) (string, error) { + p.writef(" %s\n", prompt.Message) + for j, choice := range prompt.Choices { + p.writef(" [%d] %s\n", j+1, choice) + } + p.writef(" [s] Skip\n") + + for { + if prompt.Default != "" { + p.writef(" (default: %s) > ", prompt.Default) + } else { + p.writef(" > ") + } + + line, err := p.reader.ReadString('\n') + if err != nil { + return "", fmt.Errorf("reading input: %w", err) + } + line = strings.TrimSpace(line) + + if line == "s" || line == "skip" { + return "", validation.ErrSkipFix + } + + if line == "" && prompt.Default != "" { + return prompt.Default, nil + } + + idx, err := strconv.Atoi(line) + if err != nil || idx < 1 || idx > len(prompt.Choices) { + p.writef(" Invalid choice: %s (enter 1-%d or s to skip)\n", line, len(prompt.Choices)) + continue + } + + return prompt.Choices[idx-1], nil + } +} + +func (p *TerminalPrompter) promptFreeText(prompt validation.Prompt) (string, error) { + p.writef(" %s", prompt.Message) + if prompt.Default != "" { + p.writef(" (default: %s)", prompt.Default) + } + p.writef(" [s to skip]: ") + + line, err := p.reader.ReadString('\n') + if err != nil { + return "", fmt.Errorf("reading input: %w", err) + } + line = strings.TrimSpace(line) + + if line == "s" || line == "skip" { + return "", validation.ErrSkipFix + } + + if line == "" && prompt.Default != "" { + return prompt.Default, nil + } + + if line == "" { + return "", validation.ErrSkipFix + } + + return line, nil +} + +func (p *TerminalPrompter) Confirm(message string) (bool, error) { + p.writef("%s [y/n]: ", message) + + line, err := p.reader.ReadString('\n') + if err != nil { + return false, fmt.Errorf("reading input: %w", err) + } + line = strings.ToLower(strings.TrimSpace(line)) + + return line == "y" || line == "yes", nil +} diff --git a/linter/fix/terminal_prompter_test.go b/linter/fix/terminal_prompter_test.go new file mode 100644 index 00000000..bceccf53 --- /dev/null +++ b/linter/fix/terminal_prompter_test.go @@ -0,0 +1,239 @@ +package fix_test + +import ( + "bytes" + "errors" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/linter/fix" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestTerminalPrompter_Choice_Success(t *testing.T) { + t.Parallel() + + input := strings.NewReader("2\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("missing tag"), + Node: &yaml.Node{Line: 10, Column: 5}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "choose a tag", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Select a tag:", + Choices: []string{"users", "accounts", "admin"}, + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should not fail") + assert.Equal(t, []string{"accounts"}, responses, "should return the selected choice") + assert.Contains(t, output.String(), "choose a tag", "should display fix description") + assert.Contains(t, output.String(), "[1] users", "should display choices") +} + +func TestTerminalPrompter_Choice_Default(t *testing.T) { + t.Parallel() + + input := strings.NewReader("\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "pick", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Choose:", + Choices: []string{"a", "b"}, + Default: "a", + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should not fail") + assert.Equal(t, []string{"a"}, responses, "should return default") +} + +func TestTerminalPrompter_Choice_Skip(t *testing.T) { + t.Parallel() + + input := strings.NewReader("s\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "pick", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Choose:", + Choices: []string{"a", "b"}, + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should return error on skip") + assert.ErrorIs(t, err, validation.ErrSkipFix, "should return ErrSkipFix") +} + +func TestTerminalPrompter_FreeText_Success(t *testing.T) { + t.Parallel() + + input := strings.NewReader("my description\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("missing description"), + Node: &yaml.Node{Line: 5, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "add description", + prompts: []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter description", + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should not fail") + assert.Equal(t, []string{"my description"}, responses, "should return entered text") +} + +func TestTerminalPrompter_FreeText_Skip(t *testing.T) { + t.Parallel() + + input := strings.NewReader("s\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "add value", + prompts: []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter value", + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should return error on skip") + assert.ErrorIs(t, err, validation.ErrSkipFix, "should return ErrSkipFix") +} + +func TestTerminalPrompter_Choice_InvalidThenValid(t *testing.T) { + t.Parallel() + + // First input "abc" is invalid, then "2" is valid + input := strings.NewReader("abc\n2\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "pick", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Choose:", + Choices: []string{"a", "b"}, + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should succeed after re-prompt") + assert.Equal(t, []string{"b"}, responses, "should return the choice from second attempt") + assert.Contains(t, output.String(), "Invalid choice", "should show invalid choice message") +} + +func TestTerminalPrompter_Choice_OutOfRangeThenValid(t *testing.T) { + t.Parallel() + + // "99" is out of range, then "1" is valid + input := strings.NewReader("99\n1\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "pick", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Choose:", + Choices: []string{"x", "y"}, + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should succeed after re-prompt") + assert.Equal(t, []string{"x"}, responses, "should return the choice from second attempt") + assert.Contains(t, output.String(), "Invalid choice", "should show invalid choice message") +} + +func TestTerminalPrompter_Confirm_Yes(t *testing.T) { + t.Parallel() + + input := strings.NewReader("y\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + result, err := prompter.Confirm("Apply fix?") + require.NoError(t, err, "Confirm should not fail") + assert.True(t, result, "should return true for 'y'") +} + +func TestTerminalPrompter_Confirm_No(t *testing.T) { + t.Parallel() + + input := strings.NewReader("n\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + result, err := prompter.Confirm("Apply fix?") + require.NoError(t, err, "Confirm should not fail") + assert.False(t, result, "should return false for 'n'") +} diff --git a/linter/format/format_test.go b/linter/format/format_test.go index 766d7452..fb97e5f2 100644 --- a/linter/format/format_test.go +++ b/linter/format/format_test.go @@ -1,6 +1,7 @@ package format_test import ( + "encoding/json" "errors" "strings" "testing" @@ -12,6 +13,18 @@ import ( "gopkg.in/yaml.v3" ) +// testFix is a minimal validation.Fix for testing formatter output. +type testFix struct { + description string + interactive bool +} + +func (f *testFix) Description() string { return f.description } +func (f *testFix) Interactive() bool { return f.interactive } +func (f *testFix) Prompts() []validation.Prompt { return nil } +func (f *testFix) SetInput([]string) error { return nil } +func (f *testFix) Apply(any) error { return nil } + func TestTextFormatter_Format(t *testing.T) { t.Parallel() @@ -138,3 +151,120 @@ func TestJSONFormatter_Format(t *testing.T) { }) } } + +func TestTextFormatter_FixableMarker(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fix validation.Fix + shouldHave string + shouldNotHave string + }{ + { + name: "error with fix shows fixable marker", + fix: &testFix{description: "auto fix", interactive: false}, + shouldHave: "[fixable]", + }, + { + name: "error without fix has no fixable marker", + fix: nil, + shouldNotHave: "[fixable]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + errs := []error{ + &validation.Error{ + UnderlyingError: errors.New("test issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "test-rule", + Fix: tt.fix, + }, + } + + formatter := format.NewTextFormatter() + result, err := formatter.Format(errs) + require.NoError(t, err) + + if tt.shouldHave != "" { + assert.Contains(t, result, tt.shouldHave, "text output should contain fixable marker") + } + if tt.shouldNotHave != "" { + assert.NotContains(t, result, tt.shouldNotHave, "text output should not contain fixable marker") + } + }) + } +} + +func TestJSONFormatter_FixMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fix validation.Fix + expectFix bool + expectInteract bool + }{ + { + name: "non-interactive fix", + fix: &testFix{description: "trim slash", interactive: false}, + expectFix: true, + expectInteract: false, + }, + { + name: "interactive fix", + fix: &testFix{description: "add description", interactive: true}, + expectFix: true, + expectInteract: true, + }, + { + name: "no fix", + fix: nil, + expectFix: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + errs := []error{ + &validation.Error{ + UnderlyingError: errors.New("test issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "test-rule", + Fix: tt.fix, + }, + } + + formatter := format.NewJSONFormatter() + result, err := formatter.Format(errs) + require.NoError(t, err) + + var output struct { + Results []struct { + Fix *struct { + Description string `json:"description"` + Interactive bool `json:"interactive,omitempty"` + } `json:"fix,omitempty"` + } `json:"results"` + } + require.NoError(t, json.Unmarshal([]byte(result), &output), "should be valid JSON") + require.Len(t, output.Results, 1, "should have one result") + + if tt.expectFix { + require.NotNil(t, output.Results[0].Fix, "should have fix metadata") + assert.Equal(t, tt.fix.Description(), output.Results[0].Fix.Description, "should have correct description") + assert.Equal(t, tt.expectInteract, output.Results[0].Fix.Interactive, "should have correct interactive flag") + } else { + assert.Nil(t, output.Results[0].Fix, "should not have fix metadata") + } + }) + } +} diff --git a/linter/format/json.go b/linter/format/json.go index 6a60e199..d908d1e7 100644 --- a/linter/format/json.go +++ b/linter/format/json.go @@ -37,6 +37,7 @@ type jsonLocation struct { type jsonFix struct { Description string `json:"description"` + Interactive bool `json:"interactive,omitempty"` } type jsonSummary struct { @@ -76,7 +77,8 @@ func (f *JSONFormatter) Format(results []error) (string, error) { if vErr.Fix != nil { result.Fix = &jsonFix{ - Description: vErr.Fix.FixDescription(), + Description: vErr.Fix.Description(), + Interactive: vErr.Fix.Interactive(), } } diff --git a/linter/format/summary.go b/linter/format/summary.go new file mode 100644 index 00000000..40345873 --- /dev/null +++ b/linter/format/summary.go @@ -0,0 +1,105 @@ +package format + +import ( + "errors" + "fmt" + "sort" + "strings" + + "github.com/speakeasy-api/openapi/validation" +) + +// SummaryFormatter formats results as a per-rule summary table. +type SummaryFormatter struct{} + +// NewSummaryFormatter creates a new SummaryFormatter. +func NewSummaryFormatter() *SummaryFormatter { + return &SummaryFormatter{} +} + +type ruleSummary struct { + rule string + category string + severity validation.Severity + count int +} + +// Format outputs a per-rule summary table sorted by count descending. +func (f *SummaryFormatter) Format(results []error) (string, error) { + byRule := make(map[string]*ruleSummary) + + errorCount := 0 + warningCount := 0 + hintCount := 0 + + for _, err := range results { + var vErr *validation.Error + if errors.As(err, &vErr) { + rs, ok := byRule[vErr.Rule] + if !ok { + category := "unknown" + if idx := strings.Index(vErr.Rule, "-"); idx > 0 { + category = vErr.Rule[:idx] + } + rs = &ruleSummary{ + rule: vErr.Rule, + category: category, + severity: vErr.Severity, + } + byRule[vErr.Rule] = rs + } + rs.count++ + + switch vErr.Severity { + case validation.SeverityError: + errorCount++ + case validation.SeverityWarning: + warningCount++ + case validation.SeverityHint: + hintCount++ + } + } else { + rs, ok := byRule["internal"] + if !ok { + rs = &ruleSummary{ + rule: "internal", + category: "internal", + severity: validation.SeverityError, + } + byRule["internal"] = rs + } + rs.count++ + errorCount++ + } + } + + // Sort by count descending, then by rule name + sorted := make([]*ruleSummary, 0, len(byRule)) + for _, rs := range byRule { + sorted = append(sorted, rs) + } + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].count != sorted[j].count { + return sorted[i].count > sorted[j].count + } + return sorted[i].rule < sorted[j].rule + }) + + var sb strings.Builder + + // Header + fmt.Fprintf(&sb, "%-50s %8s %10s %8s\n", "Rule", "Severity", "Category", "Count") + sb.WriteString(strings.Repeat("─", 80)) + sb.WriteString("\n") + + for _, rs := range sorted { + fmt.Fprintf(&sb, "%-50s %8s %10s %8d\n", rs.rule, rs.severity, rs.category, rs.count) + } + + sb.WriteString(strings.Repeat("─", 80)) + sb.WriteString("\n") + fmt.Fprintf(&sb, "✖ %d problems (%d errors, %d warnings, %d hints) across %d rules\n", + len(results), errorCount, warningCount, hintCount, len(byRule)) + + return sb.String(), nil +} diff --git a/linter/format/text.go b/linter/format/text.go index cdf30f8c..9974778f 100644 --- a/linter/format/text.go +++ b/linter/format/text.go @@ -33,7 +33,12 @@ func (f *TextFormatter) Format(results []error) (string, error) { msg = fmt.Sprintf("%s (document: %s)", msg, vErr.DocumentLocation) } - sb.WriteString(fmt.Sprintf("%d:%d\t%s\t%s\t%s\n", line, col, severity, rule, msg)) + fixable := "" + if vErr.Fix != nil { + fixable = " [fixable]" + } + + fmt.Fprintf(&sb, "%d:%d\t%s\t%s\t%s%s\n", line, col, severity, rule, msg, fixable) switch severity { case validation.SeverityError: diff --git a/linter/linter.go b/linter/linter.go index 8d067bda..780cff61 100644 --- a/linter/linter.go +++ b/linter/linter.go @@ -405,3 +405,9 @@ func (o *Output) FormatJSON() string { s, _ := f.Format(o.Results) return s } + +func (o *Output) FormatSummary() string { + f := format.NewSummaryFormatter() + s, _ := f.Format(o.Results) + return s +} diff --git a/openapi/linter/customrules/go.mod b/openapi/linter/customrules/go.mod index 2d46926f..f659bde1 100644 --- a/openapi/linter/customrules/go.mod +++ b/openapi/linter/customrules/go.mod @@ -2,6 +2,8 @@ module github.com/speakeasy-api/openapi/openapi/linter/customrules go 1.24.3 +replace github.com/speakeasy-api/openapi => ../../../ + require ( github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3 github.com/evanw/esbuild v0.27.2 diff --git a/openapi/linter/customrules/go.sum b/openapi/linter/customrules/go.sum index 273eae57..dac72d6d 100644 --- a/openapi/linter/customrules/go.sum +++ b/openapi/linter/customrules/go.sum @@ -25,8 +25,6 @@ github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEV github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= github.com/speakeasy-api/jsonpath v0.6.2 h1:Mys71yd6u8kuowNCR0gCVPlVAHCmKtoGXYoAtcEbqXQ= github.com/speakeasy-api/jsonpath v0.6.2/go.mod h1:ymb2iSkyOycmzKwbEAYPJV/yi2rSmvBCLZJcyD+VVWw= -github.com/speakeasy-api/openapi v1.15.2-0.20260206122839-792c2b51b2a8 h1:0KCP/92O9NMTus758qW5HEYS2NlcGEoJSffRq0USerU= -github.com/speakeasy-api/openapi v1.15.2-0.20260206122839-792c2b51b2a8/go.mod h1:aiVj+JnirrwZDtKegt0hQrj/ixl3v17EkN2YGnTuSro= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= diff --git a/openapi/linter/customrules/jsfix.go b/openapi/linter/customrules/jsfix.go new file mode 100644 index 00000000..1d95fbb8 --- /dev/null +++ b/openapi/linter/customrules/jsfix.go @@ -0,0 +1,164 @@ +package customrules + +import ( + "fmt" + "time" + + "github.com/dop251/goja" + "github.com/speakeasy-api/openapi/validation" +) + +// JSFix bridges a JavaScript fix object to the Go Fix interface. +type JSFix struct { + rt *Runtime + config *Config + description string + interactive bool + prompts []validation.Prompt + applyFn goja.Callable + jsFixObj goja.Value + inputs []string +} + +func (f *JSFix) Description() string { return f.description } +func (f *JSFix) Interactive() bool { return f.interactive } +func (f *JSFix) Prompts() []validation.Prompt { return f.prompts } + +func (f *JSFix) SetInput(responses []string) error { + if len(responses) != len(f.prompts) { + return fmt.Errorf("expected %d responses, got %d", len(f.prompts), len(responses)) + } + f.inputs = responses + return nil +} + +func (f *JSFix) Apply(doc any) error { + // Set up timeout using the same pattern as rule.Run() + timeout := f.config.GetTimeout() + timer := time.AfterFunc(timeout, func() { + f.rt.Interrupt("fix apply timeout exceeded") + }) + defer timer.Stop() + defer f.rt.ClearInterrupt() + + args := []goja.Value{f.rt.ToValue(doc)} + + if f.interactive && len(f.inputs) > 0 { + jsInputs := make([]interface{}, len(f.inputs)) + for i, input := range f.inputs { + jsInputs[i] = input + } + args = append(args, f.rt.ToValue(jsInputs)) + } + + _, err := f.applyFn(f.jsFixObj, args...) + if err != nil { + return fmt.Errorf("fix apply error: %w", err) + } + return nil +} + +// newJSFix creates a JSFix from a JavaScript options object. +// Expected JS shape: { description: string, interactive?: bool, prompts?: [...], apply: (doc, inputs?) => void } +func newJSFix(rt *Runtime, config *Config, optionsVal goja.Value) (*JSFix, error) { + obj := optionsVal.ToObject(rt.vm) + if obj == nil { + return nil, fmt.Errorf("createFix: argument must be an object") + } + + // description (required) + descVal := obj.Get("description") + if descVal == nil || goja.IsUndefined(descVal) { + return nil, fmt.Errorf("createFix: description is required") + } + + // apply (required) + applyVal := obj.Get("apply") + if applyVal == nil || goja.IsUndefined(applyVal) { + return nil, fmt.Errorf("createFix: apply function is required") + } + applyFn, ok := goja.AssertFunction(applyVal) + if !ok { + return nil, fmt.Errorf("createFix: apply must be a function") + } + + fix := &JSFix{ + rt: rt, + config: config, + description: descVal.String(), + applyFn: applyFn, + jsFixObj: optionsVal, + } + + // interactive (optional) + interVal := obj.Get("interactive") + if interVal != nil && !goja.IsUndefined(interVal) { + fix.interactive = interVal.ToBoolean() + } + + // prompts (optional) + promptsVal := obj.Get("prompts") + if promptsVal != nil && !goja.IsUndefined(promptsVal) && !goja.IsNull(promptsVal) { + prompts, err := parseJSPrompts(promptsVal) + if err != nil { + return nil, fmt.Errorf("createFix: %w", err) + } + fix.prompts = prompts + } + + return fix, nil +} + +// parseJSPrompts converts a JS array of prompt objects to Go Prompt slice. +func parseJSPrompts(val goja.Value) ([]validation.Prompt, error) { + exported := val.Export() + arr, ok := exported.([]interface{}) + if !ok { + return nil, fmt.Errorf("prompts must be an array") + } + + prompts := make([]validation.Prompt, 0, len(arr)) + for i, item := range arr { + m, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("prompts[%d] must be an object", i) + } + + prompt := validation.Prompt{} + + // type + if t, ok := m["type"].(string); ok { + switch t { + case "choice": + prompt.Type = validation.PromptChoice + case "text": + prompt.Type = validation.PromptFreeText + default: + return nil, fmt.Errorf("prompts[%d]: unknown type %q (use \"choice\" or \"text\")", i, t) + } + } + + // message + if msg, ok := m["message"].(string); ok { + prompt.Message = msg + } + + // default + if def, ok := m["default"].(string); ok { + prompt.Default = def + } + + // choices (for choice type) + if choices, ok := m["choices"].([]interface{}); ok { + for _, c := range choices { + if s, ok := c.(string); ok { + prompt.Choices = append(prompt.Choices, s) + } + } + } + + prompts = append(prompts, prompt) + } + + return prompts, nil +} diff --git a/openapi/linter/customrules/jsfix_test.go b/openapi/linter/customrules/jsfix_test.go new file mode 100644 index 00000000..2b7ed2d7 --- /dev/null +++ b/openapi/linter/customrules/jsfix_test.go @@ -0,0 +1,149 @@ +package customrules_test + +import ( + "testing" + + "github.com/speakeasy-api/openapi/openapi/linter/customrules" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSFix_NonInteractive(t *testing.T) { + t.Parallel() + + rt, err := customrules.NewRuntime(&testLogger{}, nil) + require.NoError(t, err, "creating runtime should succeed") + + // Create a non-interactive fix via JS + _, err = rt.RunScript("test", ` + var fix = createFix({ + description: "remove trailing slash", + apply: function(doc) { + // no-op for test + } + }); + `) + require.NoError(t, err, "creating fix should succeed") + + // Create an error with fix via JS + result, err := rt.RunScript("test2", ` + var err = createValidationErrorWithFix("warning", "test-rule", "has trailing slash", null, fix); + err; + `) + require.NoError(t, err, "creating error with fix should succeed") + + exported := result.Export() + vErr, ok := exported.(*validation.Error) + require.True(t, ok, "result should be a validation.Error") + require.NotNil(t, vErr.Fix, "error should have a fix attached") + assert.Equal(t, "remove trailing slash", vErr.Fix.Description(), "fix description should match") + assert.False(t, vErr.Fix.Interactive(), "fix should be non-interactive") + assert.Nil(t, vErr.Fix.Prompts(), "non-interactive fix should have no prompts") +} + +func TestJSFix_Interactive(t *testing.T) { + t.Parallel() + + rt, err := customrules.NewRuntime(&testLogger{}, nil) + require.NoError(t, err, "creating runtime should succeed") + + _, err = rt.RunScript("test", ` + var fix = createFix({ + description: "add description", + interactive: true, + prompts: [ + { type: "text", message: "Enter a description", "default": "A sample API" }, + { type: "choice", message: "Pick a tag", choices: ["users", "admin"] } + ], + apply: function(doc, inputs) { + // no-op for test + } + }); + `) + require.NoError(t, err, "creating interactive fix should succeed") + + result, err := rt.RunScript("test2", ` + var err = createValidationErrorWithFix("warning", "test-rule", "missing description", null, fix); + err; + `) + require.NoError(t, err, "creating error with interactive fix should succeed") + + exported := result.Export() + vErr, ok := exported.(*validation.Error) + require.True(t, ok, "result should be a validation.Error") + require.NotNil(t, vErr.Fix, "error should have a fix attached") + + fix := vErr.Fix + assert.True(t, fix.Interactive(), "fix should be interactive") + assert.Len(t, fix.Prompts(), 2, "fix should have 2 prompts") + + prompts := fix.Prompts() + assert.Equal(t, validation.PromptFreeText, prompts[0].Type, "first prompt should be free text") + assert.Equal(t, "Enter a description", prompts[0].Message, "first prompt message should match") + assert.Equal(t, "A sample API", prompts[0].Default, "first prompt default should match") + assert.Equal(t, validation.PromptChoice, prompts[1].Type, "second prompt should be choice") + assert.Equal(t, []string{"users", "admin"}, prompts[1].Choices, "second prompt choices should match") +} + +func TestJSFix_SetInput(t *testing.T) { + t.Parallel() + + rt, err := customrules.NewRuntime(&testLogger{}, nil) + require.NoError(t, err, "creating runtime should succeed") + + _, err = rt.RunScript("test", ` + var fix = createFix({ + description: "needs input", + interactive: true, + prompts: [{ type: "text", message: "Enter value" }], + apply: function(doc, inputs) {} + }); + `) + require.NoError(t, err, "creating fix should succeed") + + result, err := rt.RunScript("test2", ` + createValidationErrorWithFix("warning", "test-rule", "msg", null, fix); + `) + require.NoError(t, err) + + vErr := result.Export().(*validation.Error) + fix := vErr.Fix + + // Wrong number of inputs + require.Error(t, fix.SetInput([]string{"a", "b"}), "SetInput with wrong count should fail") + + // Correct number of inputs + require.NoError(t, fix.SetInput([]string{"my value"}), "SetInput with correct count should succeed") +} + +func TestJSFix_CreateFix_MissingDescription(t *testing.T) { + t.Parallel() + + rt, err := customrules.NewRuntime(&testLogger{}, nil) + require.NoError(t, err, "creating runtime should succeed") + + _, err = rt.RunScript("test", ` + createFix({ apply: function(doc) {} }); + `) + require.Error(t, err, "createFix without description should fail") +} + +func TestJSFix_CreateFix_MissingApply(t *testing.T) { + t.Parallel() + + rt, err := customrules.NewRuntime(&testLogger{}, nil) + require.NoError(t, err, "creating runtime should succeed") + + _, err = rt.RunScript("test", ` + createFix({ description: "test" }); + `) + require.Error(t, err, "createFix without apply should fail") +} + +// testLogger is a simple logger for testing. +type testLogger struct{} + +func (l *testLogger) Log(args ...any) {} +func (l *testLogger) Warn(args ...any) {} +func (l *testLogger) Error(args ...any) {} diff --git a/openapi/linter/customrules/loader.go b/openapi/linter/customrules/loader.go index 1028d837..0f80901e 100644 --- a/openapi/linter/customrules/loader.go +++ b/openapi/linter/customrules/loader.go @@ -267,7 +267,7 @@ func (l *Loader) loadRules(transpiled []*TranspiledRule, config *Config) ([]base for _, tr := range transpiled { // Create a new runtime for each rule file // (goja runtimes are not thread-safe) - rt, err := NewRuntime(config.GetLogger()) + rt, err := NewRuntime(config.GetLogger(), config) if err != nil { return nil, fmt.Errorf("creating runtime for %q: %w", tr.SourceFile, err) } diff --git a/openapi/linter/customrules/runtime.go b/openapi/linter/customrules/runtime.go index f2d40d1c..6b906b81 100644 --- a/openapi/linter/customrules/runtime.go +++ b/openapi/linter/customrules/runtime.go @@ -16,13 +16,14 @@ import ( type Runtime struct { vm *goja.Runtime logger Logger + config *Config // registeredRules holds rules registered via registerRule() registeredRules []goja.Value } // NewRuntime creates a new JavaScript runtime configured for custom rules. -func NewRuntime(logger Logger) (*Runtime, error) { +func NewRuntime(logger Logger, config *Config) (*Runtime, error) { vm := goja.New() // Use uncapitalized field/method mapper for JS-style naming @@ -32,6 +33,7 @@ func NewRuntime(logger Logger) (*Runtime, error) { rt := &Runtime{ vm: vm, logger: logger, + config: config, } // Set up console object @@ -102,6 +104,16 @@ func (rt *Runtime) setupGlobals() error { return err } + // createFix(options) - creates a fix object for attaching to validation errors + if err := rt.vm.Set("createFix", rt.createFix); err != nil { + return err + } + + // createValidationErrorWithFix(severity, ruleId, message, node, fix) - creates a validation error with a fix + if err := rt.vm.Set("createValidationErrorWithFix", rt.createValidationErrorWithFix); err != nil { + return err + } + return nil } @@ -138,6 +150,55 @@ func (rt *Runtime) createValidationError(call goja.FunctionCall) goja.Value { return rt.vm.ToValue(err) } +// createFix is the JS-callable function for creating a fix object. +func (rt *Runtime) createFix(call goja.FunctionCall) goja.Value { + if len(call.Arguments) < 1 { + panic(rt.vm.ToValue("createFix requires an options argument")) + } + + fix, err := newJSFix(rt, rt.config, call.Arguments[0]) + if err != nil { + panic(rt.vm.ToValue(err.Error())) + } + + return rt.vm.ToValue(fix) +} + +// createValidationErrorWithFix is the JS-callable function for creating validation errors with fixes. +func (rt *Runtime) createValidationErrorWithFix(call goja.FunctionCall) goja.Value { + if len(call.Arguments) < 5 { + panic(rt.vm.ToValue("createValidationErrorWithFix requires 5 arguments: severity, ruleId, message, node, fix")) + } + + severityStr := call.Arguments[0].String() + ruleID := call.Arguments[1].String() + message := call.Arguments[2].String() + nodeVal := call.Arguments[3].Export() + fixVal := call.Arguments[4].Export() + + severity := parseSeverity(severityStr) + + var node *yaml.Node + if nodeVal != nil { + if n, ok := nodeVal.(*yaml.Node); ok { + node = n + } + } + + vErr := &validation.Error{ + UnderlyingError: errors.New(message), + Node: node, + Severity: severity, + Rule: ruleID, + } + + if fix, ok := fixVal.(validation.Fix); ok { + vErr.Fix = fix + } + + return rt.vm.ToValue(vErr) +} + // parseSeverity converts a string to validation.Severity. func parseSeverity(s string) validation.Severity { switch s { diff --git a/openapi/linter/customrules/shim/types-shim.js b/openapi/linter/customrules/shim/types-shim.js index 85a78c14..942053b0 100644 --- a/openapi/linter/customrules/shim/types-shim.js +++ b/openapi/linter/customrules/shim/types-shim.js @@ -6,6 +6,8 @@ // These are defined in runtime.go's setupGlobals() var registerRule = globalThis.registerRule; var createValidationError = globalThis.createValidationError; +var createFix = globalThis.createFix; +var createValidationErrorWithFix = globalThis.createValidationErrorWithFix; // Base Rule class - users can extend this or implement RuleRunner directly export class Rule { @@ -20,4 +22,4 @@ export class Rule { } // Export the globals so user rules can import them -export { registerRule, createValidationError }; +export { registerRule, createValidationError, createFix, createValidationErrorWithFix }; diff --git a/openapi/linter/rules/component_description.go b/openapi/linter/rules/component_description.go index 0c52ab3b..98a55475 100644 --- a/openapi/linter/rules/component_description.go +++ b/openapi/linter/rules/component_description.go @@ -76,12 +76,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := schema.GetDescription() if description == "" { node := componentsCore.Schemas.GetMapKeyNodeOrRoot(schemaKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`schemas` component `%s` is missing a description", schemaKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`schemas` component `%s` is missing a description", schemaKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: jsonSchema.GetRootNode(), targetLabel: "schema '" + schemaKey + "'"}, + }) } } } @@ -99,12 +100,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := param.GetDescription() if description == "" { node := componentsCore.Parameters.GetMapKeyNodeOrRoot(paramKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`parameters` component `%s` is missing a description", paramKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`parameters` component `%s` is missing a description", paramKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: param.GetRootNode(), targetLabel: "parameter '" + paramKey + "'"}, + }) } } } @@ -122,12 +124,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := rb.GetDescription() if description == "" { node := componentsCore.RequestBodies.GetMapKeyNodeOrRoot(rbKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`requestBodies` component `%s` is missing a description", rbKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`requestBodies` component `%s` is missing a description", rbKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: rb.GetRootNode(), targetLabel: "requestBody '" + rbKey + "'"}, + }) } } } @@ -145,12 +148,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := resp.GetDescription() if description == "" { node := componentsCore.Responses.GetMapKeyNodeOrRoot(respKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`responses` component `%s` is missing a description", respKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`responses` component `%s` is missing a description", respKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: resp.GetRootNode(), targetLabel: "response '" + respKey + "'"}, + }) } } } @@ -168,12 +172,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := ex.GetDescription() if description == "" { node := componentsCore.Examples.GetMapKeyNodeOrRoot(exKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`examples` component `%s` is missing a description", exKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`examples` component `%s` is missing a description", exKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: ex.GetRootNode(), targetLabel: "example '" + exKey + "'"}, + }) } } } @@ -191,12 +196,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := header.GetDescription() if description == "" { node := componentsCore.Headers.GetMapKeyNodeOrRoot(headerKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`headers` component `%s` is missing a description", headerKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`headers` component `%s` is missing a description", headerKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: header.GetRootNode(), targetLabel: "header '" + headerKey + "'"}, + }) } } } @@ -214,12 +220,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := link.GetDescription() if description == "" { node := componentsCore.Links.GetMapKeyNodeOrRoot(linkKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`links` component `%s` is missing a description", linkKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`links` component `%s` is missing a description", linkKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: link.GetRootNode(), targetLabel: "link '" + linkKey + "'"}, + }) } } } @@ -237,12 +244,13 @@ func (r *ComponentDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu description := ss.GetDescription() if description == "" { node := componentsCore.SecuritySchemes.GetMapKeyNodeOrRoot(ssKey, componentsRoot) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleComponentDescription, - fmt.Errorf("`securitySchemes` component `%s` is missing a description", ssKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("`securitySchemes` component `%s` is missing a description", ssKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleComponentDescription, + Fix: &addDescriptionFix{targetNode: ss.GetRootNode(), targetLabel: "securityScheme '" + ssKey + "'"}, + }) } } } diff --git a/openapi/linter/rules/contact_properties.go b/openapi/linter/rules/contact_properties.go index 22889e4d..93141329 100644 --- a/openapi/linter/rules/contact_properties.go +++ b/openapi/linter/rules/contact_properties.go @@ -67,31 +67,36 @@ func (r *ContactPropertiesRule) Run(ctx context.Context, docInfo *linter.Documen url := contact.GetURL() email := contact.GetEmail() + contactRoot := contact.GetRootNode() + if name == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleContactProperties, - errors.New("`contact` section must contain a `name`"), - contact.GetRootNode(), - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("`contact` section must contain a `name`"), + Node: contactRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleContactProperties, + Fix: &addContactPropertyFix{contactNode: contactRoot, property: "name"}, + }) } if url == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleContactProperties, - errors.New("`contact` section must contain a `url`"), - contact.GetRootNode(), - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("`contact` section must contain a `url`"), + Node: contactRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleContactProperties, + Fix: &addContactPropertyFix{contactNode: contactRoot, property: "url"}, + }) } if email == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleContactProperties, - errors.New("`contact` section must contain an `email`"), - contact.GetRootNode(), - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("`contact` section must contain an `email`"), + Node: contactRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleContactProperties, + Fix: &addContactPropertyFix{contactNode: contactRoot, property: "email"}, + }) } return errs diff --git a/openapi/linter/rules/duplicated_entry_in_enum.go b/openapi/linter/rules/duplicated_entry_in_enum.go index 81b554c0..5cf94507 100644 --- a/openapi/linter/rules/duplicated_entry_in_enum.go +++ b/openapi/linter/rules/duplicated_entry_in_enum.go @@ -57,19 +57,20 @@ func (r *DuplicatedEnumRule) Run(ctx context.Context, docInfo *linter.DocumentIn } // Check for duplicates + coreSchema := schema.GetCore() duplicateIndices := findDuplicateIndices(enumValues) for _, indices := range duplicateIndices { // Report on first duplicate occurrence (second index in the list) if len(indices) > 1 { displayValue := nodeToDisplayString(enumValues[indices[1]]) - errs = append(errs, validation.NewSliceError( - config.GetSeverity(r.DefaultSeverity()), - RuleSemanticDuplicatedEnum, - fmt.Errorf("enum contains a duplicate: `%s`", displayValue), - schema.GetCore(), - schema.GetCore().Enum, - indices[1], // Report at second occurrence - )) + errNode := coreSchema.Enum.GetSliceValueNodeOrRoot(indices[1], refSchema.GetRootNode()) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("enum contains a duplicate: `%s`", displayValue), + Node: errNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleSemanticDuplicatedEnum, + Fix: &removeDuplicateEnumFix{enumNode: coreSchema.Enum.ValueNode, duplicateIndices: indices[1:]}, + }) } } } @@ -97,6 +98,32 @@ func findDuplicateIndices(enumValues []*yaml.Node) map[string][]int { return duplicates } +// removeDuplicateEnumFix removes duplicate entries from an enum sequence node. +type removeDuplicateEnumFix struct { + enumNode *yaml.Node // the sequence node containing enum values + duplicateIndices []int // indices of duplicate entries to remove +} + +func (f *removeDuplicateEnumFix) Description() string { return "Remove duplicate enum entries" } +func (f *removeDuplicateEnumFix) Interactive() bool { return false } +func (f *removeDuplicateEnumFix) Prompts() []validation.Prompt { return nil } +func (f *removeDuplicateEnumFix) SetInput([]string) error { return nil } +func (f *removeDuplicateEnumFix) Apply(doc any) error { return nil } + +func (f *removeDuplicateEnumFix) ApplyNode(_ *yaml.Node) error { + if f.enumNode == nil || len(f.duplicateIndices) == 0 { + return nil + } + // Remove from last index first to preserve earlier indices + for i := len(f.duplicateIndices) - 1; i >= 0; i-- { + idx := f.duplicateIndices[i] + if idx < len(f.enumNode.Content) { + f.enumNode.Content = append(f.enumNode.Content[:idx], f.enumNode.Content[idx+1:]...) + } + } + return nil +} + // nodeToString converts a yaml.Node to a string representation for comparison // This includes type prefixes to distinguish between different types of the same value func nodeToString(node *yaml.Node) string { diff --git a/openapi/linter/rules/fix_available.go b/openapi/linter/rules/fix_available.go new file mode 100644 index 00000000..6455ae3e --- /dev/null +++ b/openapi/linter/rules/fix_available.go @@ -0,0 +1,37 @@ +package rules + +// FixAvailable returns true for rules that provide auto-fix suggestions. +// This satisfies the linter.DocumentedRule interface's FixAvailable() method. + +func (r *PathTrailingSlashRule) FixAvailable() bool { return true } +func (r *OAS3HostTrailingSlashRule) FixAvailable() bool { return true } +func (r *OwaspSecurityHostsHttpsOAS3Rule) FixAvailable() bool { return true } +func (r *DuplicatedEnumRule) FixAvailable() bool { return true } +func (r *OAS3NoNullableRule) FixAvailable() bool { return true } +func (r *OperationTagDefinedRule) FixAvailable() bool { return true } +func (r *TagsAlphabeticalRule) FixAvailable() bool { return true } +func (r *OwaspJWTBestPracticesRule) FixAvailable() bool { return true } +func (r *OwaspNoAdditionalPropertiesRule) FixAvailable() bool { return true } +func (r *OwaspDefineErrorResponses401Rule) FixAvailable() bool { return true } +func (r *OwaspDefineErrorResponses429Rule) FixAvailable() bool { return true } +func (r *OwaspDefineErrorResponses500Rule) FixAvailable() bool { return true } +func (r *OwaspDefineErrorValidationRule) FixAvailable() bool { return true } +func (r *OwaspRateLimitRetryAfterRule) FixAvailable() bool { return true } +func (r *InfoDescriptionRule) FixAvailable() bool { return true } +func (r *InfoContactRule) FixAvailable() bool { return true } +func (r *InfoLicenseRule) FixAvailable() bool { return true } +func (r *LicenseURLRule) FixAvailable() bool { return true } +func (r *ComponentDescriptionRule) FixAvailable() bool { return true } +func (r *TagDescriptionRule) FixAvailable() bool { return true } +func (r *OperationDescriptionRule) FixAvailable() bool { return true } +func (r *OAS3ParameterDescriptionRule) FixAvailable() bool { return true } +func (r *OperationTagsRule) FixAvailable() bool { return true } +func (r *ContactPropertiesRule) FixAvailable() bool { return true } +func (r *OAS3HostNotExampleRule) FixAvailable() bool { return true } +func (r *OAS3APIServersRule) FixAvailable() bool { return true } +func (r *OwaspIntegerFormatRule) FixAvailable() bool { return true } +func (r *OwaspStringLimitRule) FixAvailable() bool { return true } +func (r *OwaspArrayLimitRule) FixAvailable() bool { return true } +func (r *OwaspIntegerLimitRule) FixAvailable() bool { return true } +func (r *OwaspAdditionalPropertiesConstrainedRule) FixAvailable() bool { return true } +func (r *UnusedComponentRule) FixAvailable() bool { return true } diff --git a/openapi/linter/rules/fix_helpers.go b/openapi/linter/rules/fix_helpers.go new file mode 100644 index 00000000..509f9c86 --- /dev/null +++ b/openapi/linter/rules/fix_helpers.go @@ -0,0 +1,574 @@ +package rules + +import ( + "context" + "fmt" + "strconv" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" +) + +// addErrorResponseFix adds a skeleton error response to an operation's responses node. +type addErrorResponseFix struct { + responsesNode *yaml.Node // the responses mapping node + statusCode string // e.g. "401", "429", "500", "400" + description string // e.g. "Unauthorized" +} + +func (f *addErrorResponseFix) Description() string { + return "Add " + f.statusCode + " response: " + f.description +} +func (f *addErrorResponseFix) Interactive() bool { return false } +func (f *addErrorResponseFix) Prompts() []validation.Prompt { return nil } +func (f *addErrorResponseFix) SetInput([]string) error { return nil } +func (f *addErrorResponseFix) Apply(doc any) error { return nil } + +func (f *addErrorResponseFix) ApplyNode(_ *yaml.Node) error { + if f.responsesNode == nil || f.responsesNode.Kind != yaml.MappingNode { + return nil + } + + ctx := context.Background() + + // Idempotency: check if status code already exists + _, _, found := yml.GetMapElementNodes(ctx, f.responsesNode, f.statusCode) + if found { + return nil + } + + // Create: "statusCode": { description: "..." } + responseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode(f.description), + }) + + yml.CreateOrUpdateMapNodeElement(ctx, f.statusCode, nil, responseNode, f.responsesNode) + return nil +} + +// addRetryAfterHeaderFix adds a Retry-After header to a 429 response node. +type addRetryAfterHeaderFix struct { + responseNode *yaml.Node // the 429 response mapping node +} + +func (f *addRetryAfterHeaderFix) Description() string { + return "Add Retry-After header to 429 response" +} +func (f *addRetryAfterHeaderFix) Interactive() bool { return false } +func (f *addRetryAfterHeaderFix) Prompts() []validation.Prompt { return nil } +func (f *addRetryAfterHeaderFix) SetInput([]string) error { return nil } +func (f *addRetryAfterHeaderFix) Apply(doc any) error { return nil } + +func (f *addRetryAfterHeaderFix) ApplyNode(_ *yaml.Node) error { + if f.responseNode == nil || f.responseNode.Kind != yaml.MappingNode { + return nil + } + + ctx := context.Background() + + // Check if headers already exists + _, headersNode, found := yml.GetMapElementNodes(ctx, f.responseNode, "headers") + if !found || headersNode == nil { + // Create headers mapping + headersNode = yml.CreateMapNode(ctx, nil) + yml.CreateOrUpdateMapNodeElement(ctx, "headers", nil, headersNode, f.responseNode) + } + + // Idempotency: check if Retry-After already exists + _, _, found = yml.GetMapElementNodes(ctx, headersNode, "Retry-After") + if found { + return nil + } + + // Create the Retry-After header: + // Retry-After: + // description: "Number of seconds to wait before retrying" + // schema: + // type: integer + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("integer"), + }) + headerNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("Number of seconds to wait before retrying"), + yml.CreateStringNode("schema"), + schemaNode, + }) + + yml.CreateOrUpdateMapNodeElement(ctx, "Retry-After", nil, headerNode, headersNode) + return nil +} + +// addDescriptionFix is an interactive fix that prompts for a description and sets it on a YAML mapping node. +type addDescriptionFix struct { + targetNode *yaml.Node // the mapping node to add/update "description" on + targetLabel string // human-readable label e.g. "tag 'users'", "operation GET /pets" + description string // filled by SetInput +} + +func (f *addDescriptionFix) Description() string { + return "Add description to " + f.targetLabel +} +func (f *addDescriptionFix) Interactive() bool { return true } +func (f *addDescriptionFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter description for " + f.targetLabel, + }, + } +} + +func (f *addDescriptionFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.description = responses[0] + return nil +} + +func (f *addDescriptionFix) Apply(doc any) error { return nil } + +func (f *addDescriptionFix) ApplyNode(_ *yaml.Node) error { + if f.targetNode == nil || f.targetNode.Kind != yaml.MappingNode || f.description == "" { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, "description", nil, yml.CreateStringNode(f.description), f.targetNode) + return nil +} + +// addContactFix prompts for contact name, URL, and email and adds them to the info node. +type addContactFix struct { + infoNode *yaml.Node + name string + url string + email string +} + +func (f *addContactFix) Description() string { return "Add contact information to info section" } +func (f *addContactFix) Interactive() bool { return true } +func (f *addContactFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "Contact name"}, + {Type: validation.PromptFreeText, Message: "Contact URL"}, + {Type: validation.PromptFreeText, Message: "Contact email"}, + } +} + +func (f *addContactFix) SetInput(responses []string) error { + if len(responses) != 3 { + return fmt.Errorf("expected 3 responses, got %d", len(responses)) + } + f.name = responses[0] + f.url = responses[1] + f.email = responses[2] + return nil +} + +func (f *addContactFix) Apply(doc any) error { return nil } + +func (f *addContactFix) ApplyNode(_ *yaml.Node) error { + if f.infoNode == nil || f.infoNode.Kind != yaml.MappingNode { + return nil + } + ctx := context.Background() + var content []*yaml.Node + if f.name != "" { + content = append(content, yml.CreateStringNode("name"), yml.CreateStringNode(f.name)) + } + if f.url != "" { + content = append(content, yml.CreateStringNode("url"), yml.CreateStringNode(f.url)) + } + if f.email != "" { + content = append(content, yml.CreateStringNode("email"), yml.CreateStringNode(f.email)) + } + if len(content) == 0 { + return nil + } + contactNode := yml.CreateMapNode(ctx, content) + yml.CreateOrUpdateMapNodeElement(ctx, "contact", nil, contactNode, f.infoNode) + return nil +} + +// addLicenseFix prompts for license name and adds a license object to the info node. +type addLicenseFix struct { + infoNode *yaml.Node + licenseName string +} + +func (f *addLicenseFix) Description() string { return "Add license to info section" } +func (f *addLicenseFix) Interactive() bool { return true } +func (f *addLicenseFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "License type", + Choices: []string{"MIT", "Apache-2.0", "BSD-3-Clause", "Other"}, + }, + } +} + +func (f *addLicenseFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.licenseName = responses[0] + return nil +} + +func (f *addLicenseFix) Apply(doc any) error { return nil } + +func (f *addLicenseFix) ApplyNode(_ *yaml.Node) error { + if f.infoNode == nil || f.infoNode.Kind != yaml.MappingNode || f.licenseName == "" { + return nil + } + ctx := context.Background() + licenseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode(f.licenseName), + }) + yml.CreateOrUpdateMapNodeElement(ctx, "license", nil, licenseNode, f.infoNode) + return nil +} + +// addLicenseURLFix prompts for a license URL and sets it on the license node. +type addLicenseURLFix struct { + licenseNode *yaml.Node + url string +} + +func (f *addLicenseURLFix) Description() string { return "Add URL to license" } +func (f *addLicenseURLFix) Interactive() bool { return true } +func (f *addLicenseURLFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "License URL"}, + } +} + +func (f *addLicenseURLFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.url = responses[0] + return nil +} + +func (f *addLicenseURLFix) Apply(doc any) error { return nil } + +func (f *addLicenseURLFix) ApplyNode(_ *yaml.Node) error { + if f.licenseNode == nil || f.licenseNode.Kind != yaml.MappingNode || f.url == "" { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, "url", nil, yml.CreateStringNode(f.url), f.licenseNode) + return nil +} + +// addOperationTagFix prompts for a tag and adds it to an operation. +type addOperationTagFix struct { + operationNode *yaml.Node + tag string +} + +func (f *addOperationTagFix) Description() string { return "Add tag to operation" } +func (f *addOperationTagFix) Interactive() bool { return true } +func (f *addOperationTagFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "Tag for this operation"}, + } +} + +func (f *addOperationTagFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.tag = responses[0] + return nil +} + +func (f *addOperationTagFix) Apply(doc any) error { return nil } + +func (f *addOperationTagFix) ApplyNode(_ *yaml.Node) error { + if f.operationNode == nil || f.operationNode.Kind != yaml.MappingNode || f.tag == "" { + return nil + } + ctx := context.Background() + // Check if tags array exists + _, tagsNode, found := yml.GetMapElementNodes(ctx, f.operationNode, "tags") + if !found || tagsNode == nil { + tagsNode = &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"} + yml.CreateOrUpdateMapNodeElement(ctx, "tags", nil, tagsNode, f.operationNode) + } + tagsNode.Content = append(tagsNode.Content, yml.CreateStringNode(f.tag)) + return nil +} + +// addContactPropertyFix prompts for a single missing contact property. +type addContactPropertyFix struct { + contactNode *yaml.Node + property string // "name", "url", or "email" + value string +} + +func (f *addContactPropertyFix) Description() string { + return "Add " + f.property + " to contact" +} +func (f *addContactPropertyFix) Interactive() bool { return true } +func (f *addContactPropertyFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "Contact " + f.property}, + } +} + +func (f *addContactPropertyFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.value = responses[0] + return nil +} + +func (f *addContactPropertyFix) Apply(doc any) error { return nil } + +func (f *addContactPropertyFix) ApplyNode(_ *yaml.Node) error { + if f.contactNode == nil || f.contactNode.Kind != yaml.MappingNode || f.value == "" { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, f.property, nil, yml.CreateStringNode(f.value), f.contactNode) + return nil +} + +// replaceServerURLFix prompts for a replacement server URL. +type replaceServerURLFix struct { + urlNode *yaml.Node + newURL string +} + +func (f *replaceServerURLFix) Description() string { return "Replace server URL" } +func (f *replaceServerURLFix) Interactive() bool { return true } +func (f *replaceServerURLFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "New server URL"}, + } +} + +func (f *replaceServerURLFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.newURL = responses[0] + return nil +} + +func (f *replaceServerURLFix) Apply(doc any) error { return nil } + +func (f *replaceServerURLFix) ApplyNode(_ *yaml.Node) error { + if f.urlNode != nil && f.newURL != "" { + f.urlNode.Value = f.newURL + } + return nil +} + +// addServerFix prompts for a server URL and adds it to the document. +type addServerFix struct { + doc *openapi.OpenAPI + url string +} + +func (f *addServerFix) Description() string { return "Add server URL" } +func (f *addServerFix) Interactive() bool { return true } +func (f *addServerFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "Server URL"}, + } +} + +func (f *addServerFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.url = responses[0] + return nil +} + +func (f *addServerFix) Apply(doc any) error { + if f.url == "" { + return nil + } + oasDoc, ok := doc.(*openapi.OpenAPI) + if !ok { + return fmt.Errorf("expected *openapi.OpenAPI, got %T", doc) + } + oasDoc.Servers = append(oasDoc.Servers, &openapi.Server{URL: f.url}) + return nil +} + +// setIntegerFormatFix prompts for int32 or int64 and sets the format on a schema node. +type setIntegerFormatFix struct { + schemaNode *yaml.Node + format string +} + +func (f *setIntegerFormatFix) Description() string { return "Set integer format" } +func (f *setIntegerFormatFix) Interactive() bool { return true } +func (f *setIntegerFormatFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Integer format", + Choices: []string{"int32", "int64"}, + }, + } +} + +func (f *setIntegerFormatFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.format = responses[0] + return nil +} + +func (f *setIntegerFormatFix) Apply(doc any) error { return nil } + +func (f *setIntegerFormatFix) ApplyNode(_ *yaml.Node) error { + if f.schemaNode == nil || f.schemaNode.Kind != yaml.MappingNode || f.format == "" { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, "format", nil, yml.CreateStringNode(f.format), f.schemaNode) + return nil +} + +// setNumericPropertyFix prompts for a numeric value and sets it as a property on a schema node. +type setNumericPropertyFix struct { + schemaNode *yaml.Node + property string // e.g. "maxLength", "maxItems", "maxProperties" + label string // human-readable prompt label + value int64 +} + +func (f *setNumericPropertyFix) Description() string { + return "Set " + f.property +} +func (f *setNumericPropertyFix) Interactive() bool { return true } +func (f *setNumericPropertyFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: f.label}, + } +} + +func (f *setNumericPropertyFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + val, err := strconv.ParseInt(responses[0], 10, 64) + if err != nil { + return fmt.Errorf("invalid number %q: %w", responses[0], err) + } + f.value = val + return nil +} + +func (f *setNumericPropertyFix) Apply(doc any) error { return nil } + +func (f *setNumericPropertyFix) ApplyNode(_ *yaml.Node) error { + if f.schemaNode == nil || f.schemaNode.Kind != yaml.MappingNode { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, f.property, nil, yml.CreateIntNode(f.value), f.schemaNode) + return nil +} + +// setIntegerLimitsFix prompts for minimum and maximum values for integer schemas. +type setIntegerLimitsFix struct { + schemaNode *yaml.Node + minVal int64 + maxVal int64 +} + +func (f *setIntegerLimitsFix) Description() string { return "Set integer minimum and maximum" } +func (f *setIntegerLimitsFix) Interactive() bool { return true } +func (f *setIntegerLimitsFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + {Type: validation.PromptFreeText, Message: "Minimum value"}, + {Type: validation.PromptFreeText, Message: "Maximum value"}, + } +} + +func (f *setIntegerLimitsFix) SetInput(responses []string) error { + if len(responses) != 2 { + return fmt.Errorf("expected 2 responses, got %d", len(responses)) + } + minV, err := strconv.ParseInt(responses[0], 10, 64) + if err != nil { + return fmt.Errorf("invalid minimum %q: %w", responses[0], err) + } + maxV, err := strconv.ParseInt(responses[1], 10, 64) + if err != nil { + return fmt.Errorf("invalid maximum %q: %w", responses[1], err) + } + f.minVal = minV + f.maxVal = maxV + return nil +} + +func (f *setIntegerLimitsFix) Apply(doc any) error { return nil } + +func (f *setIntegerLimitsFix) ApplyNode(_ *yaml.Node) error { + if f.schemaNode == nil || f.schemaNode.Kind != yaml.MappingNode { + return nil + } + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, "minimum", nil, yml.CreateIntNode(f.minVal), f.schemaNode) + yml.CreateOrUpdateMapNodeElement(ctx, "maximum", nil, yml.CreateIntNode(f.maxVal), f.schemaNode) + return nil +} + +// removeUnusedComponentFix is an interactive fix that removes an unused component entry. +type removeUnusedComponentFix struct { + parentMapNode *yaml.Node // the component type's mapping node (e.g., schemas map) + componentName string // the key to remove + componentRef string // human-readable ref e.g. "#/components/schemas/Pet" + confirmed bool +} + +func (f *removeUnusedComponentFix) Description() string { + return "Remove unused component " + f.componentRef +} +func (f *removeUnusedComponentFix) Interactive() bool { return true } +func (f *removeUnusedComponentFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Remove unused component " + f.componentRef + "?", + Choices: []string{"Yes", "No"}, + }, + } +} + +func (f *removeUnusedComponentFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.confirmed = responses[0] == "Yes" + return nil +} + +func (f *removeUnusedComponentFix) Apply(doc any) error { return nil } + +func (f *removeUnusedComponentFix) ApplyNode(_ *yaml.Node) error { + if !f.confirmed || f.parentMapNode == nil || f.parentMapNode.Kind != yaml.MappingNode { + return nil + } + ctx := context.Background() + yml.DeleteMapNodeElement(ctx, f.componentName, f.parentMapNode) + return nil +} diff --git a/openapi/linter/rules/fix_helpers_test.go b/openapi/linter/rules/fix_helpers_test.go new file mode 100644 index 00000000..5c687d67 --- /dev/null +++ b/openapi/linter/rules/fix_helpers_test.go @@ -0,0 +1,1014 @@ +package rules + +import ( + "testing" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// ============================================================ +// Non-interactive fixes +// ============================================================ + +func TestAddErrorResponseFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addErrorResponseFix{statusCode: "401", description: "Unauthorized"} + assert.Equal(t, "Add 401 response: Unauthorized", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + require.NoError(t, f.SetInput(nil)) + require.NoError(t, f.Apply(nil)) + }) + + t.Run("adds response to mapping", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + responsesNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("200"), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("OK"), + }), + }) + f := &addErrorResponseFix{responsesNode: responsesNode, statusCode: "401", description: "Unauthorized"} + require.NoError(t, f.ApplyNode(nil)) + + _, val, found := yml.GetMapElementNodes(ctx, responsesNode, "401") + require.True(t, found, "401 response should be added") + _, desc, found := yml.GetMapElementNodes(ctx, val, "description") + require.True(t, found, "response should have description") + assert.Equal(t, "Unauthorized", desc.Value) + }) + + t.Run("idempotent when status exists", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + responsesNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("401"), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("Existing"), + }), + }) + f := &addErrorResponseFix{responsesNode: responsesNode, statusCode: "401", description: "Unauthorized"} + require.NoError(t, f.ApplyNode(nil)) + + _, val, _ := yml.GetMapElementNodes(ctx, responsesNode, "401") + _, desc, _ := yml.GetMapElementNodes(ctx, val, "description") + assert.Equal(t, "Existing", desc.Value, "should not overwrite existing response") + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &addErrorResponseFix{responsesNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) + + t.Run("non-mapping node is no-op", func(t *testing.T) { + t.Parallel() + f := &addErrorResponseFix{responsesNode: &yaml.Node{Kind: yaml.ScalarNode}} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestAddRetryAfterHeaderFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addRetryAfterHeaderFix{} + assert.Equal(t, "Add Retry-After header to 429 response", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + require.NoError(t, f.SetInput(nil)) + require.NoError(t, f.Apply(nil)) + }) + + t.Run("creates headers and adds Retry-After", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + responseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("Too Many Requests"), + }) + f := &addRetryAfterHeaderFix{responseNode: responseNode} + require.NoError(t, f.ApplyNode(nil)) + + _, headersNode, found := yml.GetMapElementNodes(ctx, responseNode, "headers") + require.True(t, found, "headers should be added") + _, retryAfter, found := yml.GetMapElementNodes(ctx, headersNode, "Retry-After") + require.True(t, found, "Retry-After header should be added") + _, desc, found := yml.GetMapElementNodes(ctx, retryAfter, "description") + require.True(t, found) + assert.Contains(t, desc.Value, "seconds") + _, schema, found := yml.GetMapElementNodes(ctx, retryAfter, "schema") + require.True(t, found) + _, typ, found := yml.GetMapElementNodes(ctx, schema, "type") + require.True(t, found) + assert.Equal(t, "integer", typ.Value) + }) + + t.Run("adds to existing headers mapping", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + headersNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("X-Custom"), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("Custom header"), + }), + }) + responseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("headers"), + headersNode, + }) + f := &addRetryAfterHeaderFix{responseNode: responseNode} + require.NoError(t, f.ApplyNode(nil)) + + _, hNode, _ := yml.GetMapElementNodes(ctx, responseNode, "headers") + _, _, found := yml.GetMapElementNodes(ctx, hNode, "Retry-After") + assert.True(t, found, "Retry-After should be added alongside existing headers") + _, _, found = yml.GetMapElementNodes(ctx, hNode, "X-Custom") + assert.True(t, found, "existing headers should be preserved") + }) + + t.Run("idempotent when Retry-After exists", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + retryAfterHeader := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("description"), + yml.CreateStringNode("Original"), + }) + headersNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("Retry-After"), + retryAfterHeader, + }) + responseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("headers"), + headersNode, + }) + f := &addRetryAfterHeaderFix{responseNode: responseNode} + require.NoError(t, f.ApplyNode(nil)) + + _, hNode, _ := yml.GetMapElementNodes(ctx, responseNode, "headers") + _, raNode, _ := yml.GetMapElementNodes(ctx, hNode, "Retry-After") + _, dNode, _ := yml.GetMapElementNodes(ctx, raNode, "description") + assert.Equal(t, "Original", dNode.Value, "should not overwrite existing Retry-After") + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &addRetryAfterHeaderFix{responseNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +// ============================================================ +// Interactive single-prompt fixes +// ============================================================ + +func TestAddDescriptionFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addDescriptionFix{targetLabel: "schema 'Pet'"} + assert.Equal(t, "Add description to schema 'Pet'", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + assert.Contains(t, prompts[0].Message, "schema 'Pet'") + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addDescriptionFix{} + require.NoError(t, f.SetInput([]string{"A pet object"})) + assert.Equal(t, "A pet object", f.description) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addDescriptionFix{} + require.Error(t, f.SetInput([]string{})) + require.Error(t, f.SetInput([]string{"a", "b"})) + }) + + t.Run("applies description", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + targetNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("object"), + }) + f := &addDescriptionFix{targetNode: targetNode, description: "A pet object"} + require.NoError(t, f.ApplyNode(nil)) + + _, desc, found := yml.GetMapElementNodes(ctx, targetNode, "description") + require.True(t, found) + assert.Equal(t, "A pet object", desc.Value) + }) + + t.Run("empty description is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + targetNode := yml.CreateMapNode(ctx, nil) + f := &addDescriptionFix{targetNode: targetNode, description: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, targetNode, "description") + assert.False(t, found) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &addDescriptionFix{targetNode: nil, description: "test"} + require.NoError(t, f.ApplyNode(nil)) + }) + + t.Run("apply is no-op", func(t *testing.T) { + t.Parallel() + f := &addDescriptionFix{} + require.NoError(t, f.Apply(nil)) + }) +} + +func TestAddLicenseURLFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addLicenseURLFix{} + assert.Equal(t, "Add URL to license", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addLicenseURLFix{} + require.NoError(t, f.SetInput([]string{"https://opensource.org/licenses/MIT"})) + assert.Equal(t, "https://opensource.org/licenses/MIT", f.url) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addLicenseURLFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("applies url to license node", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + licenseNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("MIT"), + }) + f := &addLicenseURLFix{licenseNode: licenseNode, url: "https://opensource.org/licenses/MIT"} + require.NoError(t, f.ApplyNode(nil)) + + _, urlNode, found := yml.GetMapElementNodes(ctx, licenseNode, "url") + require.True(t, found) + assert.Equal(t, "https://opensource.org/licenses/MIT", urlNode.Value) + }) + + t.Run("empty url is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + licenseNode := yml.CreateMapNode(ctx, nil) + f := &addLicenseURLFix{licenseNode: licenseNode, url: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, licenseNode, "url") + assert.False(t, found) + }) +} + +func TestAddContactPropertyFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addContactPropertyFix{property: "email"} + assert.Equal(t, "Add email to contact", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + assert.Contains(t, prompts[0].Message, "email") + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addContactPropertyFix{} + require.NoError(t, f.SetInput([]string{"test@example.com"})) + assert.Equal(t, "test@example.com", f.value) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addContactPropertyFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("applies property to contact node", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + contactNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("Support"), + }) + f := &addContactPropertyFix{contactNode: contactNode, property: "email", value: "support@example.com"} + require.NoError(t, f.ApplyNode(nil)) + + _, emailNode, found := yml.GetMapElementNodes(ctx, contactNode, "email") + require.True(t, found) + assert.Equal(t, "support@example.com", emailNode.Value) + }) + + t.Run("empty value is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + contactNode := yml.CreateMapNode(ctx, nil) + f := &addContactPropertyFix{contactNode: contactNode, property: "email", value: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, contactNode, "email") + assert.False(t, found) + }) +} + +func TestReplaceServerURLFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &replaceServerURLFix{} + assert.Equal(t, "Replace server URL", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &replaceServerURLFix{} + require.NoError(t, f.SetInput([]string{"https://api.real.com"})) + assert.Equal(t, "https://api.real.com", f.newURL) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &replaceServerURLFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("replaces node value", func(t *testing.T) { + t.Parallel() + urlNode := yml.CreateStringNode("https://example.com") + f := &replaceServerURLFix{urlNode: urlNode, newURL: "https://api.real.com"} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.real.com", urlNode.Value) + }) + + t.Run("empty url is no-op", func(t *testing.T) { + t.Parallel() + urlNode := yml.CreateStringNode("https://example.com") + f := &replaceServerURLFix{urlNode: urlNode, newURL: ""} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://example.com", urlNode.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &replaceServerURLFix{urlNode: nil, newURL: "https://api.real.com"} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestAddOperationTagFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addOperationTagFix{} + assert.Equal(t, "Add tag to operation", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addOperationTagFix{} + require.NoError(t, f.SetInput([]string{"users"})) + assert.Equal(t, "users", f.tag) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addOperationTagFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("creates tags array and adds tag", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + operationNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("summary"), + yml.CreateStringNode("List users"), + }) + f := &addOperationTagFix{operationNode: operationNode, tag: "users"} + require.NoError(t, f.ApplyNode(nil)) + + _, tagsNode, found := yml.GetMapElementNodes(ctx, operationNode, "tags") + require.True(t, found, "tags should be created") + assert.Equal(t, yaml.SequenceNode, tagsNode.Kind) + require.Len(t, tagsNode.Content, 1) + assert.Equal(t, "users", tagsNode.Content[0].Value) + }) + + t.Run("appends to existing tags", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + tagsNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateStringNode("existing"), + }, + } + operationNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("tags"), + tagsNode, + }) + f := &addOperationTagFix{operationNode: operationNode, tag: "newTag"} + require.NoError(t, f.ApplyNode(nil)) + + _, updatedTags, _ := yml.GetMapElementNodes(ctx, operationNode, "tags") + require.Len(t, updatedTags.Content, 2) + assert.Equal(t, "existing", updatedTags.Content[0].Value) + assert.Equal(t, "newTag", updatedTags.Content[1].Value) + }) + + t.Run("empty tag is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + operationNode := yml.CreateMapNode(ctx, nil) + f := &addOperationTagFix{operationNode: operationNode, tag: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, operationNode, "tags") + assert.False(t, found) + }) +} + +// ============================================================ +// Interactive choice fixes +// ============================================================ + +func TestAddLicenseFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addLicenseFix{} + assert.Equal(t, "Add license to info section", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptChoice, prompts[0].Type) + assert.Contains(t, prompts[0].Choices, "MIT") + assert.Contains(t, prompts[0].Choices, "Apache-2.0") + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addLicenseFix{} + require.NoError(t, f.SetInput([]string{"MIT"})) + assert.Equal(t, "MIT", f.licenseName) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addLicenseFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("adds license to info node", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + infoNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("title"), + yml.CreateStringNode("Test API"), + }) + f := &addLicenseFix{infoNode: infoNode, licenseName: "MIT"} + require.NoError(t, f.ApplyNode(nil)) + + _, licenseNode, found := yml.GetMapElementNodes(ctx, infoNode, "license") + require.True(t, found) + _, nameNode, found := yml.GetMapElementNodes(ctx, licenseNode, "name") + require.True(t, found) + assert.Equal(t, "MIT", nameNode.Value) + }) + + t.Run("empty license is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + infoNode := yml.CreateMapNode(ctx, nil) + f := &addLicenseFix{infoNode: infoNode, licenseName: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, infoNode, "license") + assert.False(t, found) + }) +} + +func TestSetIntegerFormatFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &setIntegerFormatFix{} + assert.Equal(t, "Set integer format", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptChoice, prompts[0].Type) + assert.Contains(t, prompts[0].Choices, "int32") + assert.Contains(t, prompts[0].Choices, "int64") + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &setIntegerFormatFix{} + require.NoError(t, f.SetInput([]string{"int64"})) + assert.Equal(t, "int64", f.format) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &setIntegerFormatFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("sets format on schema", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("integer"), + }) + f := &setIntegerFormatFix{schemaNode: schemaNode, format: "int32"} + require.NoError(t, f.ApplyNode(nil)) + + _, formatNode, found := yml.GetMapElementNodes(ctx, schemaNode, "format") + require.True(t, found) + assert.Equal(t, "int32", formatNode.Value) + }) + + t.Run("empty format is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, nil) + f := &setIntegerFormatFix{schemaNode: schemaNode, format: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, schemaNode, "format") + assert.False(t, found) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &setIntegerFormatFix{schemaNode: nil, format: "int32"} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +// ============================================================ +// Interactive multi-prompt fixes +// ============================================================ + +func TestAddContactFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addContactFix{} + assert.Equal(t, "Add contact information to info section", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 3) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + assert.Equal(t, validation.PromptFreeText, prompts[1].Type) + assert.Equal(t, validation.PromptFreeText, prompts[2].Type) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addContactFix{} + require.NoError(t, f.SetInput([]string{"Support", "https://support.example.com", "support@example.com"})) + assert.Equal(t, "Support", f.name) + assert.Equal(t, "https://support.example.com", f.url) + assert.Equal(t, "support@example.com", f.email) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addContactFix{} + require.Error(t, f.SetInput([]string{"only one"})) + require.Error(t, f.SetInput([]string{"a", "b"})) + }) + + t.Run("adds contact with all fields", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + infoNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("title"), + yml.CreateStringNode("Test API"), + }) + f := &addContactFix{infoNode: infoNode, name: "Support", url: "https://support.example.com", email: "support@example.com"} + require.NoError(t, f.ApplyNode(nil)) + + _, contactNode, found := yml.GetMapElementNodes(ctx, infoNode, "contact") + require.True(t, found) + _, nameNode, found := yml.GetMapElementNodes(ctx, contactNode, "name") + require.True(t, found) + assert.Equal(t, "Support", nameNode.Value) + _, urlNode, found := yml.GetMapElementNodes(ctx, contactNode, "url") + require.True(t, found) + assert.Equal(t, "https://support.example.com", urlNode.Value) + _, emailNode, found := yml.GetMapElementNodes(ctx, contactNode, "email") + require.True(t, found) + assert.Equal(t, "support@example.com", emailNode.Value) + }) + + t.Run("adds contact with partial fields", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + infoNode := yml.CreateMapNode(ctx, nil) + f := &addContactFix{infoNode: infoNode, name: "Support", url: "", email: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, contactNode, found := yml.GetMapElementNodes(ctx, infoNode, "contact") + require.True(t, found, "contact should be added even with partial fields") + _, nameNode, found := yml.GetMapElementNodes(ctx, contactNode, "name") + require.True(t, found) + assert.Equal(t, "Support", nameNode.Value) + _, _, found = yml.GetMapElementNodes(ctx, contactNode, "url") + assert.False(t, found, "empty url should not be added") + _, _, found = yml.GetMapElementNodes(ctx, contactNode, "email") + assert.False(t, found, "empty email should not be added") + }) + + t.Run("all empty fields is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + infoNode := yml.CreateMapNode(ctx, nil) + f := &addContactFix{infoNode: infoNode, name: "", url: "", email: ""} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, infoNode, "contact") + assert.False(t, found) + }) +} + +func TestSetIntegerLimitsFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{} + assert.Equal(t, "Set integer minimum and maximum", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 2) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + assert.Contains(t, prompts[0].Message, "Minimum") + assert.Equal(t, validation.PromptFreeText, prompts[1].Type) + assert.Contains(t, prompts[1].Message, "Maximum") + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{} + require.NoError(t, f.SetInput([]string{"0", "100"})) + assert.Equal(t, int64(0), f.minVal) + assert.Equal(t, int64(100), f.maxVal) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{} + require.Error(t, f.SetInput([]string{"1"})) + require.Error(t, f.SetInput([]string{"1", "2", "3"})) + }) + + t.Run("set input invalid minimum", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{} + err := f.SetInput([]string{"abc", "100"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid minimum") + }) + + t.Run("set input invalid maximum", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{} + err := f.SetInput([]string{"0", "xyz"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid maximum") + }) + + t.Run("sets minimum and maximum on schema", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("integer"), + }) + f := &setIntegerLimitsFix{schemaNode: schemaNode, minVal: -100, maxVal: 100} + require.NoError(t, f.ApplyNode(nil)) + + _, minNode, found := yml.GetMapElementNodes(ctx, schemaNode, "minimum") + require.True(t, found) + assert.Equal(t, "-100", minNode.Value) + _, maxNode, found := yml.GetMapElementNodes(ctx, schemaNode, "maximum") + require.True(t, found) + assert.Equal(t, "100", maxNode.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &setIntegerLimitsFix{schemaNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +// ============================================================ +// Interactive numeric fix +// ============================================================ + +func TestSetNumericPropertyFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &setNumericPropertyFix{property: "maxLength", label: "Maximum string length"} + assert.Equal(t, "Set maxLength", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + assert.Equal(t, "Maximum string length", prompts[0].Message) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &setNumericPropertyFix{} + require.NoError(t, f.SetInput([]string{"255"})) + assert.Equal(t, int64(255), f.value) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &setNumericPropertyFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("set input invalid number", func(t *testing.T) { + t.Parallel() + f := &setNumericPropertyFix{} + err := f.SetInput([]string{"not-a-number"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid number") + }) + + t.Run("sets property on schema", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("string"), + }) + f := &setNumericPropertyFix{schemaNode: schemaNode, property: "maxLength", value: 255} + require.NoError(t, f.ApplyNode(nil)) + + _, valNode, found := yml.GetMapElementNodes(ctx, schemaNode, "maxLength") + require.True(t, found) + assert.Equal(t, "255", valNode.Value) + }) + + t.Run("works for maxItems", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("array"), + }) + f := &setNumericPropertyFix{schemaNode: schemaNode, property: "maxItems", value: 100} + require.NoError(t, f.ApplyNode(nil)) + + _, valNode, found := yml.GetMapElementNodes(ctx, schemaNode, "maxItems") + require.True(t, found) + assert.Equal(t, "100", valNode.Value) + }) + + t.Run("works for maxProperties", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("object"), + }) + f := &setNumericPropertyFix{schemaNode: schemaNode, property: "maxProperties", value: 50} + require.NoError(t, f.ApplyNode(nil)) + + _, valNode, found := yml.GetMapElementNodes(ctx, schemaNode, "maxProperties") + require.True(t, found) + assert.Equal(t, "50", valNode.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &setNumericPropertyFix{schemaNode: nil, property: "maxLength", value: 100} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +// ============================================================ +// Interactive confirmation fix +// ============================================================ + +func TestRemoveUnusedComponentFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{componentRef: "#/components/schemas/Pet"} + assert.Equal(t, "Remove unused component #/components/schemas/Pet", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptChoice, prompts[0].Type) + assert.Contains(t, prompts[0].Choices, "Yes") + assert.Contains(t, prompts[0].Choices, "No") + assert.Contains(t, prompts[0].Message, "#/components/schemas/Pet") + }) + + t.Run("set input yes", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{} + require.NoError(t, f.SetInput([]string{"Yes"})) + assert.True(t, f.confirmed) + }) + + t.Run("set input no", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{} + require.NoError(t, f.SetInput([]string{"No"})) + assert.False(t, f.confirmed) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("confirmed removes component from mapping", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + parentMap := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("Pet"), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("object"), + }), + yml.CreateStringNode("User"), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("object"), + }), + }) + f := &removeUnusedComponentFix{parentMapNode: parentMap, componentName: "Pet", confirmed: true} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, parentMap, "Pet") + assert.False(t, found, "Pet should be removed") + _, _, found = yml.GetMapElementNodes(ctx, parentMap, "User") + assert.True(t, found, "User should be preserved") + }) + + t.Run("not confirmed is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + parentMap := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("Pet"), + yml.CreateMapNode(ctx, nil), + }) + f := &removeUnusedComponentFix{parentMapNode: parentMap, componentName: "Pet", confirmed: false} + require.NoError(t, f.ApplyNode(nil)) + + _, _, found := yml.GetMapElementNodes(ctx, parentMap, "Pet") + assert.True(t, found, "Pet should NOT be removed when not confirmed") + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{parentMapNode: nil, componentName: "Pet", confirmed: true} + require.NoError(t, f.ApplyNode(nil)) + }) + + t.Run("apply is no-op", func(t *testing.T) { + t.Parallel() + f := &removeUnusedComponentFix{} + require.NoError(t, f.Apply(nil)) + }) +} + +// ============================================================ +// Model fix (Apply instead of ApplyNode) +// ============================================================ + +func TestAddServerFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addServerFix{} + assert.Equal(t, "Add server URL", f.Description()) + assert.True(t, f.Interactive()) + prompts := f.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptFreeText, prompts[0].Type) + }) + + t.Run("set input success", func(t *testing.T) { + t.Parallel() + f := &addServerFix{} + require.NoError(t, f.SetInput([]string{"https://api.example.com"})) + assert.Equal(t, "https://api.example.com", f.url) + }) + + t.Run("set input wrong count", func(t *testing.T) { + t.Parallel() + f := &addServerFix{} + require.Error(t, f.SetInput([]string{})) + }) + + t.Run("adds server to document", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{} + f := &addServerFix{url: "https://api.example.com"} + require.NoError(t, f.Apply(doc)) + + require.Len(t, doc.Servers, 1) + assert.Equal(t, "https://api.example.com", doc.Servers[0].URL) + }) + + t.Run("appends to existing servers", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{ + Servers: []*openapi.Server{{URL: "https://existing.com"}}, + } + f := &addServerFix{url: "https://new.example.com"} + require.NoError(t, f.Apply(doc)) + + require.Len(t, doc.Servers, 2) + assert.Equal(t, "https://existing.com", doc.Servers[0].URL) + assert.Equal(t, "https://new.example.com", doc.Servers[1].URL) + }) + + t.Run("empty url is no-op", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{} + f := &addServerFix{url: ""} + require.NoError(t, f.Apply(doc)) + + assert.Empty(t, doc.Servers) + }) + + t.Run("wrong doc type returns error", func(t *testing.T) { + t.Parallel() + f := &addServerFix{url: "https://api.example.com"} + err := f.Apply("not an openapi doc") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected *openapi.OpenAPI") + }) +} diff --git a/openapi/linter/rules/fix_integration_test.go b/openapi/linter/rules/fix_integration_test.go new file mode 100644 index 00000000..d81e204f --- /dev/null +++ b/openapi/linter/rules/fix_integration_test.go @@ -0,0 +1,516 @@ +package rules + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/linter" + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/references" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// These integration tests verify that fixes actually resolve the violations +// they are meant to fix. The pattern is: +// 1. Parse a document with a known violation +// 2. Run the rule to get errors with fixes +// 3. Apply the fix +// 4. Re-parse/re-run to verify the violation is resolved + +// helper: parse OpenAPI document from YAML string +func parseOpenAPIDoc(t *testing.T, yamlStr string) *openapi.OpenAPI { + t.Helper() + ctx := t.Context() + doc, _, err := openapi.Unmarshal(ctx, strings.NewReader(yamlStr)) + require.NoError(t, err, "unmarshal should succeed") + return doc +} + +// helper: build index and create DocumentInfo +func buildDocInfo(t *testing.T, doc *openapi.OpenAPI) *linter.DocumentInfo[*openapi.OpenAPI] { + t.Helper() + ctx := t.Context() + idx := openapi.BuildIndex(ctx, doc, references.ResolveOptions{ + RootDocument: doc, + TargetDocument: doc, + TargetLocation: "test.yaml", + }) + return linter.NewDocumentInfoWithIndex(doc, "test.yaml", idx) +} + +// helper: extract NodeFix from first error +func extractNodeFix(t *testing.T, errs []error) validation.NodeFix { + t.Helper() + require.NotEmpty(t, errs, "should have at least one error") + var valErr *validation.Error + require.ErrorAs(t, errs[0], &valErr, "error should be a *validation.Error") + require.NotNil(t, valErr.Fix, "error should have a fix") + nodeFix, ok := valErr.Fix.(validation.NodeFix) + require.True(t, ok, "fix should implement NodeFix") + return nodeFix +} + +// helper: re-parse from modified YAML root node +func remarshalAndParse(t *testing.T, doc *openapi.OpenAPI) *openapi.OpenAPI { + t.Helper() + rootNode := doc.GetCore().GetRootNode() + require.NotNil(t, rootNode, "root node should exist") + + out, err := yaml.Marshal(rootNode) + require.NoError(t, err, "marshal should succeed") + + return parseOpenAPIDoc(t, string(out)) +} + +// ============================================================ +// Non-interactive fix integration tests +// ============================================================ + +func TestFixIntegration_HostTrailingSlash(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +servers: + - url: https://api.example.com/ +paths: {} +`) + + // Step 1: Run rule and get violation + rule := &OAS3HostTrailingSlashRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.Len(t, errs, 1, "should detect trailing slash") + + // Step 2: Apply fix + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + // Step 3: Re-parse and verify violation is gone + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the trailing slash violation") +} + +func TestFixIntegration_HTTPSUpgrade(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +servers: + - url: http://api.example.com +paths: {} +`) + + rule := &OwaspSecurityHostsHttpsOAS3Rule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect http:// URL") + + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the HTTPS violation") +} + +func TestFixIntegration_DuplicateEnum(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Status: + type: string + enum: + - active + - inactive + - active +`) + + rule := &DuplicatedEnumRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect duplicate enum entry") + + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the duplicate enum violation") +} + +func TestFixIntegration_AdditionalPropertiesFalse(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Pet: + type: object + additionalProperties: true +`) + + rule := &OwaspNoAdditionalPropertiesRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect additionalProperties: true") + + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the additionalProperties violation") +} + +func TestFixIntegration_TagsAlphabetical(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +tags: + - name: users + - name: admin + - name: pets +paths: {} +`) + + rule := &TagsAlphabeticalRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect unsorted tags") + + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the alphabetical tag violation") +} + +func TestFixIntegration_AddErrorResponse(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: + /pets: + get: + operationId: listPets + responses: + "200": + description: OK +`) + + rule := &OwaspDefineErrorResponses401Rule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect missing 401 response") + + nodeFix := extractNodeFix(t, errs) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + // The fix adds {401: {description: "Unauthorized"}} which resolves the "missing response" + // violation. The rule may still report "missing content schema" — that's a different, lesser violation. + for _, err := range errs2 { + assert.NotContains(t, err.Error(), "must define", "the 'must define response' violation should be resolved") + } +} + +// ============================================================ +// Interactive fix integration tests +// ============================================================ + +func TestFixIntegration_InteractiveIntegerFormat(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Age: + type: integer +`) + + rule := &OwaspIntegerFormatRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect missing integer format") + + // Extract fix and simulate user input + var valErr *validation.Error + require.ErrorAs(t, errs[0], &valErr) + require.NotNil(t, valErr.Fix) + fix := valErr.Fix + + assert.True(t, fix.Interactive(), "should be an interactive fix") + prompts := fix.Prompts() + require.Len(t, prompts, 1) + assert.Equal(t, validation.PromptChoice, prompts[0].Type) + + // Simulate user choosing "int32" + require.NoError(t, fix.SetInput([]string{"int32"})) + + nodeFix, ok := fix.(validation.NodeFix) + require.True(t, ok) + require.NoError(t, nodeFix.ApplyNode(nil)) + + // Re-parse and verify + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the integer format violation") +} + +func TestFixIntegration_InteractiveStringLimit(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Name: + type: string +`) + + rule := &OwaspStringLimitRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect missing string limit") + + var valErr *validation.Error + require.ErrorAs(t, errs[0], &valErr) + require.NotNil(t, valErr.Fix) + fix := valErr.Fix + + assert.True(t, fix.Interactive()) + + // Simulate user entering maxLength = 255 + require.NoError(t, fix.SetInput([]string{"255"})) + + nodeFix, ok := fix.(validation.NodeFix) + require.True(t, ok) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the string limit violation") +} + +func TestFixIntegration_InteractiveIntegerLimits(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Count: + type: integer + format: int32 +`) + + rule := &OwaspIntegerLimitRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect missing integer limits") + + var valErr *validation.Error + require.ErrorAs(t, errs[0], &valErr) + require.NotNil(t, valErr.Fix) + fix := valErr.Fix + + assert.True(t, fix.Interactive()) + prompts := fix.Prompts() + require.Len(t, prompts, 2) + + // Simulate user entering min=0, max=1000 + require.NoError(t, fix.SetInput([]string{"0", "1000"})) + + nodeFix, ok := fix.(validation.NodeFix) + require.True(t, ok) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the integer limit violation") +} + +func TestFixIntegration_InteractiveArrayLimit(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: {} +components: + schemas: + Items: + type: array + items: + type: string +`) + + rule := &OwaspArrayLimitRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.NotEmpty(t, errs, "should detect missing array limit") + + var valErr *validation.Error + require.ErrorAs(t, errs[0], &valErr) + require.NotNil(t, valErr.Fix) + fix := valErr.Fix + + // Simulate user entering maxItems = 100 + require.NoError(t, fix.SetInput([]string{"100"})) + + nodeFix, ok := fix.(validation.NodeFix) + require.True(t, ok) + require.NoError(t, nodeFix.ApplyNode(nil)) + + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the array limit violation") +} + +// ============================================================ +// Fix engine integration test +// ============================================================ + +func TestFixIntegration_EngineAutoMode(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // A document with multiple non-interactive violations + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +servers: + - url: http://api.example.com/ +paths: + /pets: + get: + operationId: listPets + responses: + "200": + description: OK +`) + + docInfo := buildDocInfo(t, doc) + + // Run multiple rules to collect violations + rules := []interface { + Run(ctx context.Context, docInfo *linter.DocumentInfo[*openapi.OpenAPI], config *linter.RuleConfig) []error + }{ + &OAS3HostTrailingSlashRule{}, + &OwaspSecurityHostsHttpsOAS3Rule{}, + &OwaspDefineErrorResponses401Rule{}, + &OwaspDefineErrorResponses500Rule{}, + } + + var allErrors []error + config := &linter.RuleConfig{} + for _, rule := range rules { + errs := rule.Run(ctx, docInfo, config) + allErrors = append(allErrors, errs...) + } + + // Verify we have violations + require.NotEmpty(t, allErrors, "should have multiple violations") + + // Collect fixable errors (non-interactive only) + var fixCount int + for _, err := range allErrors { + var valErr *validation.Error + if !errors.As(err, &valErr) || valErr.Fix == nil { + continue + } + fix := valErr.Fix + if fix.Interactive() { + continue + } + // Apply the fix + if nodeFix, ok := fix.(validation.NodeFix); ok { + require.NoError(t, nodeFix.ApplyNode(nil)) + fixCount++ + } + } + require.Positive(t, fixCount, "should have applied at least one fix") + + // Re-parse and verify fixes resolved violations + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + + var remainingErrors []error + for _, rule := range rules { + errs := rule.Run(ctx, docInfo2, config) + remainingErrors = append(remainingErrors, errs...) + } + assert.Less(t, len(remainingErrors), len(allErrors), "fixes should reduce the number of violations") +} diff --git a/openapi/linter/rules/host_not_example.go b/openapi/linter/rules/host_not_example.go index ac403c8b..ec26bf92 100644 --- a/openapi/linter/rules/host_not_example.go +++ b/openapi/linter/rules/host_not_example.go @@ -60,12 +60,13 @@ func (r *OAS3HostNotExampleRule) Run(ctx context.Context, docInfo *linter.Docume errNode = doc.GetRootNode() } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOAS3HostNotExample, - fmt.Errorf("server url %q must not point at example.com", server.GetURL()), - errNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("server url %q must not point at example.com", server.GetURL()), + Node: errNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOAS3HostNotExample, + Fix: &replaceServerURLFix{urlNode: errNode}, + }) } return errs diff --git a/openapi/linter/rules/host_trailing_slash.go b/openapi/linter/rules/host_trailing_slash.go index 7b106862..469aa41e 100644 --- a/openapi/linter/rules/host_trailing_slash.go +++ b/openapi/linter/rules/host_trailing_slash.go @@ -8,6 +8,7 @@ import ( "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" + "gopkg.in/yaml.v3" ) const RuleStyleOAS3HostTrailingSlash = "style-oas3-host-trailing-slash" @@ -67,14 +68,35 @@ func (r *OAS3HostTrailingSlashRule) Run(ctx context.Context, docInfo *linter.Doc errNode = server.GetRootNode() } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOAS3HostTrailingSlash, - fmt.Errorf("server url %q should not have a trailing slash", url), - errNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("server url %q should not have a trailing slash", url), + Node: errNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOAS3HostTrailingSlash, + Fix: &removeHostTrailingSlashFix{node: errNode}, + }) } } return errs } + +// removeHostTrailingSlashFix removes the trailing slash from a server URL node. +type removeHostTrailingSlashFix struct { + node *yaml.Node +} + +func (f *removeHostTrailingSlashFix) Description() string { + return "Remove trailing slash from server URL" +} +func (f *removeHostTrailingSlashFix) Interactive() bool { return false } +func (f *removeHostTrailingSlashFix) Prompts() []validation.Prompt { return nil } +func (f *removeHostTrailingSlashFix) SetInput([]string) error { return nil } +func (f *removeHostTrailingSlashFix) Apply(doc any) error { return nil } + +func (f *removeHostTrailingSlashFix) ApplyNode(_ *yaml.Node) error { + if f.node != nil { + f.node.Value = strings.TrimRight(f.node.Value, "/") + } + return nil +} diff --git a/openapi/linter/rules/info_contact.go b/openapi/linter/rules/info_contact.go index f5b5287f..b57eb32b 100644 --- a/openapi/linter/rules/info_contact.go +++ b/openapi/linter/rules/info_contact.go @@ -50,12 +50,14 @@ func (r *InfoContactRule) Run(ctx context.Context, docInfo *linter.DocumentInfo[ contact := info.GetContact() if contact == nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleInfoContact, - errors.New("info section is missing contact details"), - info.GetRootNode(), - )) + infoRoot := info.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("info section is missing contact details"), + Node: infoRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleInfoContact, + Fix: &addContactFix{infoNode: infoRoot}, + }) } return errs diff --git a/openapi/linter/rules/info_description.go b/openapi/linter/rules/info_description.go index 03e6ab53..7a4cb437 100644 --- a/openapi/linter/rules/info_description.go +++ b/openapi/linter/rules/info_description.go @@ -3,6 +3,7 @@ package rules import ( "context" "errors" + "fmt" "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" @@ -50,13 +51,51 @@ func (r *InfoDescriptionRule) Run(ctx context.Context, docInfo *linter.DocumentI description := info.GetDescription() if description == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleInfoDescription, - errors.New("info section is missing a description"), - info.GetRootNode(), - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("info section is missing a description"), + Node: info.GetRootNode(), + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleInfoDescription, + Fix: &addInfoDescriptionFix{}, + }) } return errs } + +// addInfoDescriptionFix prompts the user for a description and sets it on the info object. +type addInfoDescriptionFix struct { + description string +} + +func (f *addInfoDescriptionFix) Description() string { return "Add a description to the info section" } +func (f *addInfoDescriptionFix) Interactive() bool { return true } +func (f *addInfoDescriptionFix) Prompts() []validation.Prompt { + return []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter an API description", + }, + } +} + +func (f *addInfoDescriptionFix) SetInput(responses []string) error { + if len(responses) != 1 { + return fmt.Errorf("expected 1 response, got %d", len(responses)) + } + f.description = responses[0] + return nil +} + +func (f *addInfoDescriptionFix) Apply(doc any) error { + oasDoc, ok := doc.(*openapi.OpenAPI) + if !ok { + return fmt.Errorf("expected *openapi.OpenAPI, got %T", doc) + } + info := oasDoc.GetInfo() + if info == nil { + return errors.New("document has no info section") + } + info.Description = &f.description + return nil +} diff --git a/openapi/linter/rules/info_license.go b/openapi/linter/rules/info_license.go index 5d9341b2..2ada103a 100644 --- a/openapi/linter/rules/info_license.go +++ b/openapi/linter/rules/info_license.go @@ -50,12 +50,14 @@ func (r *InfoLicenseRule) Run(ctx context.Context, docInfo *linter.DocumentInfo[ license := info.GetLicense() if license == nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleInfoLicense, - errors.New("info section should contain a license"), - info.GetRootNode(), - )) + infoRoot := info.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("info section should contain a license"), + Node: infoRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleInfoLicense, + Fix: &addLicenseFix{infoNode: infoRoot}, + }) } return errs diff --git a/openapi/linter/rules/license_url.go b/openapi/linter/rules/license_url.go index 27fd9b04..1b7b1e9e 100644 --- a/openapi/linter/rules/license_url.go +++ b/openapi/linter/rules/license_url.go @@ -55,12 +55,14 @@ func (r *LicenseURLRule) Run(ctx context.Context, docInfo *linter.DocumentInfo[* url := license.GetURL() if url == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleLicenseURL, - errors.New("license should contain a URL"), - license.GetRootNode(), - )) + licenseRoot := license.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("license should contain a URL"), + Node: licenseRoot, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleLicenseURL, + Fix: &addLicenseURLFix{licenseNode: licenseRoot}, + }) } return errs diff --git a/openapi/linter/rules/oas3_api_servers.go b/openapi/linter/rules/oas3_api_servers.go index 9eb95786..1bb46c1b 100644 --- a/openapi/linter/rules/oas3_api_servers.go +++ b/openapi/linter/rules/oas3_api_servers.go @@ -63,12 +63,13 @@ func (r *OAS3APIServersRule) Run(ctx context.Context, docInfo *linter.DocumentIn if len(doc.Servers) == 0 { // Get the root node for error reporting rootNode := doc.GetRootNode() - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOAS3APIServers, - errors.New("no servers defined for the specification"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("no servers defined for the specification"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOAS3APIServers, + Fix: &addServerFix{doc: doc}, + }) return errs } diff --git a/openapi/linter/rules/oas3_no_nullable.go b/openapi/linter/rules/oas3_no_nullable.go index acc3faed..ec7cb719 100644 --- a/openapi/linter/rules/oas3_no_nullable.go +++ b/openapi/linter/rules/oas3_no_nullable.go @@ -7,6 +7,8 @@ import ( "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" ) const RuleOAS3NoNullable = "oas3-no-nullable" @@ -61,15 +63,71 @@ func (r *OAS3NoNullableRule) Run(ctx context.Context, docInfo *linter.DocumentIn // Check if nullable field is present in the YAML if coreSchema.Nullable.Present { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOAS3NoNullable, - errors.New("the `nullable` keyword is not supported in OpenAPI 3.1 - use `type: [actualType, \"null\"]` instead"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("the `nullable` keyword is not supported in OpenAPI 3.1 - use `type: [actualType, \"null\"]` instead"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOAS3NoNullable, + Fix: &removeNullableFix{schemaNode: rootNode, typeValueNode: coreSchema.Type.ValueNode}, + }) } } } return errs } + +// removeNullableFix removes the nullable key and adds "null" to the type array. +type removeNullableFix struct { + schemaNode *yaml.Node // the schema mapping node + typeValueNode *yaml.Node // the existing type value node (may be nil) +} + +func (f *removeNullableFix) Description() string { + return "Replace nullable with type array including null" +} +func (f *removeNullableFix) Interactive() bool { return false } +func (f *removeNullableFix) Prompts() []validation.Prompt { return nil } +func (f *removeNullableFix) SetInput([]string) error { return nil } +func (f *removeNullableFix) Apply(doc any) error { return nil } + +func (f *removeNullableFix) ApplyNode(_ *yaml.Node) error { + if f.schemaNode == nil { + return nil + } + + ctx := context.Background() + + // Remove the nullable key/value pair + yml.DeleteMapNodeElement(ctx, "nullable", f.schemaNode) + + // Add "null" to the type field + if f.typeValueNode != nil { + switch f.typeValueNode.Kind { + case yaml.ScalarNode: + // type: string → type: [string, "null"] + existingType := f.typeValueNode.Value + f.typeValueNode.Kind = yaml.SequenceNode + f.typeValueNode.Tag = "!!seq" + f.typeValueNode.Value = "" + f.typeValueNode.Content = []*yaml.Node{ + yml.CreateStringNode(existingType), + yml.CreateStringNode("null"), + } + case yaml.SequenceNode: + // type: [string, integer] → type: [string, integer, "null"] + // Check if "null" already present + for _, n := range f.typeValueNode.Content { + if n.Value == "null" { + return nil + } + } + f.typeValueNode.Content = append(f.typeValueNode.Content, yml.CreateStringNode("null")) + } + } else { + // No type field exists — add type: "null" + yml.CreateOrUpdateMapNodeElement(ctx, "type", nil, yml.CreateStringNode("null"), f.schemaNode) + } + + return nil +} diff --git a/openapi/linter/rules/operation_description.go b/openapi/linter/rules/operation_description.go index 78d8d7be..a9696902 100644 --- a/openapi/linter/rules/operation_description.go +++ b/openapi/linter/rules/operation_description.go @@ -72,12 +72,14 @@ func (r *OperationDescriptionRule) Run(ctx context.Context, docInfo *linter.Docu } } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOperationDescription, - fmt.Errorf("the %s is missing a description or summary", opIdentifier), - operation.GetRootNode(), - )) + rootNode := operation.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("the %s is missing a description or summary", opIdentifier), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOperationDescription, + Fix: &addDescriptionFix{targetNode: rootNode, targetLabel: "operation " + opIdentifier}, + }) } } diff --git a/openapi/linter/rules/operation_tag_defined.go b/openapi/linter/rules/operation_tag_defined.go index 9862bef2..9c241138 100644 --- a/openapi/linter/rules/operation_tag_defined.go +++ b/openapi/linter/rules/operation_tag_defined.go @@ -71,17 +71,44 @@ func (r *OperationTagDefinedRule) Run(ctx context.Context, docInfo *linter.Docum opTags := operation.GetTags() for i, tagName := range opTags { if tagName != "" && !globalTags[tagName] { - errs = append(errs, validation.NewSliceError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOperationTagDefined, - fmt.Errorf("tag `%s` for %s operation is not defined as a global tag", tagName, opIdentifier), - operation.GetCore(), - operation.GetCore().Tags, - i, - )) + errNode := operation.GetCore().Tags.GetSliceValueNodeOrRoot(i, operation.GetRootNode()) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("tag `%s` for %s operation is not defined as a global tag", tagName, opIdentifier), + Node: errNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOperationTagDefined, + Fix: &addGlobalTagFix{tagName: tagName}, + }) } } } return errs } + +// addGlobalTagFix adds a missing tag to the document's global tags array. +type addGlobalTagFix struct { + tagName string +} + +func (f *addGlobalTagFix) Description() string { + return "Add tag `" + f.tagName + "` to global tags" +} +func (f *addGlobalTagFix) Interactive() bool { return false } +func (f *addGlobalTagFix) Prompts() []validation.Prompt { return nil } +func (f *addGlobalTagFix) SetInput([]string) error { return nil } + +func (f *addGlobalTagFix) Apply(doc any) error { + oasDoc, ok := doc.(*openapi.OpenAPI) + if !ok { + return fmt.Errorf("expected *openapi.OpenAPI, got %T", doc) + } + // Idempotency: check if tag already exists + for _, tag := range oasDoc.Tags { + if tag != nil && tag.Name == f.tagName { + return nil + } + } + oasDoc.Tags = append(oasDoc.Tags, &openapi.Tag{Name: f.tagName}) + return nil +} diff --git a/openapi/linter/rules/operation_tags.go b/openapi/linter/rules/operation_tags.go index 17be43b3..cb706b58 100644 --- a/openapi/linter/rules/operation_tags.go +++ b/openapi/linter/rules/operation_tags.go @@ -70,12 +70,14 @@ func (r *OperationTagsRule) Run(ctx context.Context, docInfo *linter.DocumentInf } } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOperationTags, - fmt.Errorf("the %s is missing tags", opIdentifier), - operation.GetRootNode(), - )) + rootNode := operation.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("the %s is missing tags", opIdentifier), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOperationTags, + Fix: &addOperationTagFix{operationNode: rootNode}, + }) } } diff --git a/openapi/linter/rules/owasp_additional_properties_constrained.go b/openapi/linter/rules/owasp_additional_properties_constrained.go index cf538e18..cc643335 100644 --- a/openapi/linter/rules/owasp_additional_properties_constrained.go +++ b/openapi/linter/rules/owasp_additional_properties_constrained.go @@ -95,12 +95,13 @@ func (r *OwaspAdditionalPropertiesConstrainedRule) Run(ctx context.Context, docI maxProps := schema.GetMaxProperties() if maxProps == nil { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspAdditionalPropertiesConstrained, - errors.New("schema should define maxProperties when additionalProperties is set to true or a schema"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("schema should define maxProperties when additionalProperties is set to true or a schema"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspAdditionalPropertiesConstrained, + Fix: &setNumericPropertyFix{schemaNode: rootNode, property: "maxProperties", label: "Maximum number of properties"}, + }) } } } diff --git a/openapi/linter/rules/owasp_array_limit.go b/openapi/linter/rules/owasp_array_limit.go index 340b18f9..e2088c6d 100644 --- a/openapi/linter/rules/owasp_array_limit.go +++ b/openapi/linter/rules/owasp_array_limit.go @@ -71,12 +71,13 @@ func (r *OwaspArrayLimitRule) Run(ctx context.Context, docInfo *linter.DocumentI maxItems := schema.GetMaxItems() if maxItems == nil { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspArrayLimit, - errors.New("schema of type `array` must specify `maxItems`"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("schema of type `array` must specify `maxItems`"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspArrayLimit, + Fix: &setNumericPropertyFix{schemaNode: rootNode, property: "maxItems", label: "Maximum number of array items"}, + }) } } } diff --git a/openapi/linter/rules/owasp_define_error_responses_401.go b/openapi/linter/rules/owasp_define_error_responses_401.go index dec24447..ff7d9206 100644 --- a/openapi/linter/rules/owasp_define_error_responses_401.go +++ b/openapi/linter/rules/owasp_define_error_responses_401.go @@ -86,12 +86,13 @@ func (r *OwaspDefineErrorResponses401Rule) Run(ctx context.Context, docInfo *lin if !has401 { // Missing 401 response if rootNode := responses.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspDefineErrorResponses401, - fmt.Errorf("operation %s %s is missing 401 Unauthorized error response", method, path), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("operation %s %s is missing 401 Unauthorized error response", method, path), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspDefineErrorResponses401, + Fix: &addErrorResponseFix{responsesNode: rootNode, statusCode: "401", description: "Unauthorized"}, + }) } continue } diff --git a/openapi/linter/rules/owasp_define_error_responses_429.go b/openapi/linter/rules/owasp_define_error_responses_429.go index 0c49d8f4..eb20e84c 100644 --- a/openapi/linter/rules/owasp_define_error_responses_429.go +++ b/openapi/linter/rules/owasp_define_error_responses_429.go @@ -86,12 +86,13 @@ func (r *OwaspDefineErrorResponses429Rule) Run(ctx context.Context, docInfo *lin if !has429 { // Missing 429 response if rootNode := responses.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspDefineErrorResponses429, - fmt.Errorf("operation %s %s is missing 429 Too Many Requests response", method, path), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("operation %s %s is missing 429 Too Many Requests response", method, path), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspDefineErrorResponses429, + Fix: &addErrorResponseFix{responsesNode: rootNode, statusCode: "429", description: "Too Many Requests"}, + }) } continue } diff --git a/openapi/linter/rules/owasp_define_error_responses_500.go b/openapi/linter/rules/owasp_define_error_responses_500.go index d807c958..57910f2a 100644 --- a/openapi/linter/rules/owasp_define_error_responses_500.go +++ b/openapi/linter/rules/owasp_define_error_responses_500.go @@ -86,12 +86,13 @@ func (r *OwaspDefineErrorResponses500Rule) Run(ctx context.Context, docInfo *lin if !has500 { // Missing 500 response if rootNode := responses.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspDefineErrorResponses500, - fmt.Errorf("operation %s %s is missing 500 Internal Server Error response", method, path), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("operation %s %s is missing 500 Internal Server Error response", method, path), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspDefineErrorResponses500, + Fix: &addErrorResponseFix{responsesNode: rootNode, statusCode: "500", description: "Internal Server Error"}, + }) } continue } diff --git a/openapi/linter/rules/owasp_define_error_validation.go b/openapi/linter/rules/owasp_define_error_validation.go index 95a71637..a5006f16 100644 --- a/openapi/linter/rules/owasp_define_error_validation.go +++ b/openapi/linter/rules/owasp_define_error_validation.go @@ -88,12 +88,13 @@ func (r *OwaspDefineErrorValidationRule) Run(ctx context.Context, docInfo *linte if has400 == nil && has422 == nil && has4XX == nil { // Missing all validation error responses if rootNode := responses.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspDefineErrorValidation, - fmt.Errorf("operation %s %s is missing validation error response (should have 400, 422, or 4XX)", method, path), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("operation %s %s is missing validation error response (should have 400, 422, or 4XX)", method, path), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspDefineErrorValidation, + Fix: &addErrorResponseFix{responsesNode: rootNode, statusCode: "400", description: "Bad Request"}, + }) } } } diff --git a/openapi/linter/rules/owasp_integer_format.go b/openapi/linter/rules/owasp_integer_format.go index e907be4b..651a829c 100644 --- a/openapi/linter/rules/owasp_integer_format.go +++ b/openapi/linter/rules/owasp_integer_format.go @@ -71,12 +71,13 @@ func (r *OwaspIntegerFormatRule) Run(ctx context.Context, docInfo *linter.Docume format := schema.GetFormat() if format != "int32" && format != "int64" { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspIntegerFormat, - errors.New("schema of type `integer` must specify `format` as `int32` or `int64`"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("schema of type `integer` must specify `format` as `int32` or `int64`"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspIntegerFormat, + Fix: &setIntegerFormatFix{schemaNode: rootNode}, + }) } } } diff --git a/openapi/linter/rules/owasp_integer_limit.go b/openapi/linter/rules/owasp_integer_limit.go index 331304db..314a2cf8 100644 --- a/openapi/linter/rules/owasp_integer_limit.go +++ b/openapi/linter/rules/owasp_integer_limit.go @@ -84,12 +84,13 @@ func (r *OwaspIntegerLimitRule) Run(ctx context.Context, docInfo *linter.Documen if !hasMin || !hasMax { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspIntegerLimit, - errors.New("schema of type `integer` must specify `minimum` and `maximum` (or `exclusiveMinimum` and `exclusiveMaximum`)"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("schema of type `integer` must specify `minimum` and `maximum` (or `exclusiveMinimum` and `exclusiveMaximum`)"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspIntegerLimit, + Fix: &setIntegerLimitsFix{schemaNode: rootNode}, + }) } } } diff --git a/openapi/linter/rules/owasp_jwt_best_practices.go b/openapi/linter/rules/owasp_jwt_best_practices.go index 929b6c5b..dd714501 100644 --- a/openapi/linter/rules/owasp_jwt_best_practices.go +++ b/openapi/linter/rules/owasp_jwt_best_practices.go @@ -9,6 +9,7 @@ import ( "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" ) const RuleOwaspJWTBestPractices = "owasp-jwt-best-practices" @@ -84,20 +85,22 @@ func (r *OwaspJWTBestPracticesRule) Run(ctx context.Context, docInfo *linter.Doc if rootNode != nil { _, descNode, found := yml.GetMapElementNodes(ctx, rootNode, "description") if found && descNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspJWTBestPractices, - fmt.Errorf("security scheme `%s` must explicitly declare support for RFC8725 in the description", name), - descNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("security scheme `%s` must explicitly declare support for RFC8725 in the description", name), + Node: descNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspJWTBestPractices, + Fix: &appendRFC8725Fix{schemeNode: rootNode, descNode: descNode}, + }) } else { // No description field - report on the scheme itself - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspJWTBestPractices, - fmt.Errorf("security scheme `%s` must explicitly declare support for RFC8725 in the description", name), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("security scheme `%s` must explicitly declare support for RFC8725 in the description", name), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspJWTBestPractices, + Fix: &appendRFC8725Fix{schemeNode: rootNode, descNode: nil}, + }) } } } @@ -105,3 +108,37 @@ func (r *OwaspJWTBestPracticesRule) Run(ctx context.Context, docInfo *linter.Doc return errs } + +const rfc8725Suffix = " This scheme follows RFC8725 best practices." + +// appendRFC8725Fix appends an RFC8725 mention to the security scheme description. +type appendRFC8725Fix struct { + schemeNode *yaml.Node // the security scheme mapping node + descNode *yaml.Node // the existing description value node (may be nil) +} + +func (f *appendRFC8725Fix) Description() string { + return "Add RFC8725 mention to security scheme description" +} +func (f *appendRFC8725Fix) Interactive() bool { return false } +func (f *appendRFC8725Fix) Prompts() []validation.Prompt { return nil } +func (f *appendRFC8725Fix) SetInput([]string) error { return nil } +func (f *appendRFC8725Fix) Apply(doc any) error { return nil } + +func (f *appendRFC8725Fix) ApplyNode(_ *yaml.Node) error { + if f.schemeNode == nil { + return nil + } + + if f.descNode != nil { + // Append to existing description + if !strings.Contains(f.descNode.Value, "RFC8725") { + f.descNode.Value += rfc8725Suffix + } + } else { + // No description field — add one + ctx := context.Background() + yml.CreateOrUpdateMapNodeElement(ctx, "description", nil, yml.CreateStringNode(strings.TrimSpace(rfc8725Suffix)), f.schemeNode) + } + return nil +} diff --git a/openapi/linter/rules/owasp_no_additional_properties.go b/openapi/linter/rules/owasp_no_additional_properties.go index cfa972e8..ef36f75e 100644 --- a/openapi/linter/rules/owasp_no_additional_properties.go +++ b/openapi/linter/rules/owasp_no_additional_properties.go @@ -7,6 +7,8 @@ import ( "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" ) const RuleOwaspNoAdditionalProperties = "owasp-no-additional-properties" @@ -93,15 +95,44 @@ func (r *OwaspNoAdditionalPropertiesRule) Run(ctx context.Context, docInfo *lint if isViolation { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspNoAdditionalProperties, - errors.New("additionalProperties should not be set to true or define a schema - set to false or omit it"), - rootNode, - )) + // Only provide auto-fix for the boolean true case + var fix validation.Fix + if additionalProps.IsBool() { + _, valueNode, found := yml.GetMapElementNodes(ctx, rootNode, "additionalProperties") + if found && valueNode != nil { + fix = &setAdditionalPropertiesFalseFix{node: valueNode} + } + } + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("additionalProperties should not be set to true or define a schema - set to false or omit it"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspNoAdditionalProperties, + Fix: fix, + }) } } } return errs } + +// setAdditionalPropertiesFalseFix changes additionalProperties: true to additionalProperties: false. +type setAdditionalPropertiesFalseFix struct { + node *yaml.Node // the additionalProperties value node +} + +func (f *setAdditionalPropertiesFalseFix) Description() string { + return "Set additionalProperties to false" +} +func (f *setAdditionalPropertiesFalseFix) Interactive() bool { return false } +func (f *setAdditionalPropertiesFalseFix) Prompts() []validation.Prompt { return nil } +func (f *setAdditionalPropertiesFalseFix) SetInput([]string) error { return nil } +func (f *setAdditionalPropertiesFalseFix) Apply(doc any) error { return nil } + +func (f *setAdditionalPropertiesFalseFix) ApplyNode(_ *yaml.Node) error { + if f.node != nil && f.node.Kind == yaml.ScalarNode && f.node.Value == "true" { + f.node.Value = "false" + } + return nil +} diff --git a/openapi/linter/rules/owasp_rate_limit_retry_after.go b/openapi/linter/rules/owasp_rate_limit_retry_after.go index c63b2627..37ca2d84 100644 --- a/openapi/linter/rules/owasp_rate_limit_retry_after.go +++ b/openapi/linter/rules/owasp_rate_limit_retry_after.go @@ -86,15 +86,17 @@ func (r *OwaspRateLimitRetryAfterRule) Run(ctx context.Context, docInfo *linter. // Check if Retry-After header exists headers := responseObj.GetHeaders() + responseRootNode := response429.GetRootNode() if headers == nil { // No headers at all - if rootNode := response429.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspRateLimitRetryAfter, - fmt.Errorf("429 response for operation %s %s is missing Retry-After header", method, path), - rootNode, - )) + if responseRootNode != nil { + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("429 response for operation %s %s is missing Retry-After header", method, path), + Node: responseRootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspRateLimitRetryAfter, + Fix: &addRetryAfterHeaderFix{responseNode: responseRootNode}, + }) } continue } @@ -107,13 +109,14 @@ func (r *OwaspRateLimitRetryAfterRule) Run(ctx context.Context, docInfo *linter. } if !hasRetryAfter || retryAfter == nil { - if rootNode := responseObj.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspRateLimitRetryAfter, - fmt.Errorf("429 response for operation %s %s is missing Retry-After header", method, path), - rootNode, - )) + if responseRootNode != nil { + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("429 response for operation %s %s is missing Retry-After header", method, path), + Node: responseRootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspRateLimitRetryAfter, + Fix: &addRetryAfterHeaderFix{responseNode: responseRootNode}, + }) } } } diff --git a/openapi/linter/rules/owasp_security_hosts_https_oas3.go b/openapi/linter/rules/owasp_security_hosts_https_oas3.go index b86d9ece..7b4bc3a4 100644 --- a/openapi/linter/rules/owasp_security_hosts_https_oas3.go +++ b/openapi/linter/rules/owasp_security_hosts_https_oas3.go @@ -9,6 +9,7 @@ import ( "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" ) const RuleOwaspSecurityHostsHttpsOAS3 = "owasp-security-hosts-https-oas3" @@ -70,12 +71,13 @@ func (r *OwaspSecurityHostsHttpsOAS3Rule) Run(ctx context.Context, docInfo *lint if rootNode := server.GetRootNode(); rootNode != nil { _, urlValueNode, found := yml.GetMapElementNodes(ctx, rootNode, "url") if found && urlValueNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspSecurityHostsHttpsOAS3, - fmt.Errorf("server URL `%s` must use HTTPS protocol for security", url), - urlValueNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("server URL `%s` must use HTTPS protocol for security", url), + Node: urlValueNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspSecurityHostsHttpsOAS3, + Fix: &upgradeToHTTPSFix{node: urlValueNode}, + }) } } } @@ -83,3 +85,21 @@ func (r *OwaspSecurityHostsHttpsOAS3Rule) Run(ctx context.Context, docInfo *lint return errs } + +// upgradeToHTTPSFix replaces http:// with https:// in a server URL node. +type upgradeToHTTPSFix struct { + node *yaml.Node +} + +func (f *upgradeToHTTPSFix) Description() string { return "Upgrade server URL to HTTPS" } +func (f *upgradeToHTTPSFix) Interactive() bool { return false } +func (f *upgradeToHTTPSFix) Prompts() []validation.Prompt { return nil } +func (f *upgradeToHTTPSFix) SetInput([]string) error { return nil } +func (f *upgradeToHTTPSFix) Apply(doc any) error { return nil } + +func (f *upgradeToHTTPSFix) ApplyNode(_ *yaml.Node) error { + if f.node != nil && strings.HasPrefix(f.node.Value, "http://") { + f.node.Value = "https://" + strings.TrimPrefix(f.node.Value, "http://") + } + return nil +} diff --git a/openapi/linter/rules/owasp_string_limit.go b/openapi/linter/rules/owasp_string_limit.go index 198037ee..acf29fe3 100644 --- a/openapi/linter/rules/owasp_string_limit.go +++ b/openapi/linter/rules/owasp_string_limit.go @@ -75,12 +75,13 @@ func (r *OwaspStringLimitRule) Run(ctx context.Context, docInfo *linter.Document // If none of these are defined, report error if maxLength == nil && constValue == nil && len(enumValues) == 0 { if rootNode := refSchema.GetRootNode(); rootNode != nil { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleOwaspStringLimit, - errors.New("schema of type 'string' must specify maxLength, const, or enum to prevent unbounded data"), - rootNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: errors.New("schema of type 'string' must specify maxLength, const, or enum to prevent unbounded data"), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleOwaspStringLimit, + Fix: &setNumericPropertyFix{schemaNode: rootNode, property: "maxLength", label: "Maximum string length"}, + }) } } } diff --git a/openapi/linter/rules/parameter_description.go b/openapi/linter/rules/parameter_description.go index 386535ad..c0b69854 100644 --- a/openapi/linter/rules/parameter_description.go +++ b/openapi/linter/rules/parameter_description.go @@ -96,12 +96,14 @@ func (r *OAS3ParameterDescriptionRule) Run(ctx context.Context, docInfo *linter. msg = fmt.Sprintf("parameter `%s` is missing a description", paramName) } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleOAS3ParameterDescription, - fmt.Errorf("%s", msg), - errNode, - )) + paramRootNode := param.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("%s", msg), + Node: errNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleOAS3ParameterDescription, + Fix: &addDescriptionFix{targetNode: paramRootNode, targetLabel: "parameter '" + paramName + "'"}, + }) } } diff --git a/openapi/linter/rules/path_trailing_slash.go b/openapi/linter/rules/path_trailing_slash.go index 87a5aabe..6b76c5a8 100644 --- a/openapi/linter/rules/path_trailing_slash.go +++ b/openapi/linter/rules/path_trailing_slash.go @@ -8,6 +8,7 @@ import ( "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" + "gopkg.in/yaml.v3" ) const RuleStylePathTrailingSlash = "style-path-trailing-slash" @@ -51,14 +52,36 @@ func (r *PathTrailingSlashRule) Run(ctx context.Context, docInfo *linter.Documen for pathKey := range paths.All() { if strings.HasSuffix(pathKey, "/") && pathKey != "/" { node := paths.GetCore().GetMapKeyNodeOrRoot(pathKey, paths.GetRootNode()) - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStylePathTrailingSlash, - fmt.Errorf("path `%s` must not end with a trailing slash", pathKey), - node, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("path `%s` must not end with a trailing slash", pathKey), + Node: node, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStylePathTrailingSlash, + Fix: &removeTrailingSlashFix{node: node}, + }) } } return errs } + +// removeTrailingSlashFix removes the trailing slash from a path key node. +type removeTrailingSlashFix struct { + node *yaml.Node +} + +func (f *removeTrailingSlashFix) Description() string { + return "Remove trailing slash from path" +} + +func (f *removeTrailingSlashFix) Interactive() bool { return false } +func (f *removeTrailingSlashFix) Prompts() []validation.Prompt { return nil } +func (f *removeTrailingSlashFix) SetInput([]string) error { return nil } +func (f *removeTrailingSlashFix) Apply(doc any) error { return nil } + +func (f *removeTrailingSlashFix) ApplyNode(_ *yaml.Node) error { + if f.node != nil { + f.node.Value = strings.TrimRight(f.node.Value, "/") + } + return nil +} diff --git a/openapi/linter/rules/rule_fixes_test.go b/openapi/linter/rules/rule_fixes_test.go new file mode 100644 index 00000000..3968d6e0 --- /dev/null +++ b/openapi/linter/rules/rule_fixes_test.go @@ -0,0 +1,563 @@ +package rules + +import ( + "context" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/yml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// ============================================================ +// Tests for fix structs defined in individual rule files +// ============================================================ + +func TestRemoveHostTrailingSlashFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &removeHostTrailingSlashFix{} + assert.Equal(t, "Remove trailing slash from server URL", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + require.NoError(t, f.SetInput(nil)) + require.NoError(t, f.Apply(nil)) + }) + + t.Run("removes trailing slash", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com/") + f := &removeHostTrailingSlashFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com", node.Value) + }) + + t.Run("removes multiple trailing slashes", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com///") + f := &removeHostTrailingSlashFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com", node.Value) + }) + + t.Run("no trailing slash is no-op", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com") + f := &removeHostTrailingSlashFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com", node.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &removeHostTrailingSlashFix{node: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestRemoveTrailingSlashFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &removeTrailingSlashFix{} + assert.Equal(t, "Remove trailing slash from path", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("removes trailing slash from path", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("/pets/") + f := &removeTrailingSlashFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "/pets", node.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &removeTrailingSlashFix{node: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestRemoveDuplicateEnumFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &removeDuplicateEnumFix{} + assert.Equal(t, "Remove duplicate enum entries", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("removes single duplicate", func(t *testing.T) { + t.Parallel() + enumNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateStringNode("active"), + yml.CreateStringNode("inactive"), + yml.CreateStringNode("active"), // duplicate at index 2 + }, + } + f := &removeDuplicateEnumFix{enumNode: enumNode, duplicateIndices: []int{2}} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, enumNode.Content, 2) + assert.Equal(t, "active", enumNode.Content[0].Value) + assert.Equal(t, "inactive", enumNode.Content[1].Value) + }) + + t.Run("removes multiple duplicates", func(t *testing.T) { + t.Parallel() + enumNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateStringNode("a"), + yml.CreateStringNode("b"), + yml.CreateStringNode("a"), // duplicate at index 2 + yml.CreateStringNode("b"), // duplicate at index 3 + }, + } + f := &removeDuplicateEnumFix{enumNode: enumNode, duplicateIndices: []int{2, 3}} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, enumNode.Content, 2) + assert.Equal(t, "a", enumNode.Content[0].Value) + assert.Equal(t, "b", enumNode.Content[1].Value) + }) + + t.Run("nil enum node is no-op", func(t *testing.T) { + t.Parallel() + f := &removeDuplicateEnumFix{enumNode: nil, duplicateIndices: []int{0}} + require.NoError(t, f.ApplyNode(nil)) + }) + + t.Run("empty indices is no-op", func(t *testing.T) { + t.Parallel() + enumNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Content: []*yaml.Node{ + yml.CreateStringNode("a"), + }, + } + f := &removeDuplicateEnumFix{enumNode: enumNode, duplicateIndices: nil} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, enumNode.Content, 1) + }) +} + +func TestSortTagsFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &sortTagsFix{} + assert.Equal(t, "Sort tags alphabetically", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("sorts tags alphabetically", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + tagsNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("users"), + }), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("admin"), + }), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("pets"), + }), + }, + } + f := &sortTagsFix{tagsNode: tagsNode} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, tagsNode.Content, 3) + assert.Equal(t, "admin", getTagName(tagsNode.Content[0])) + assert.Equal(t, "pets", getTagName(tagsNode.Content[1])) + assert.Equal(t, "users", getTagName(tagsNode.Content[2])) + }) + + t.Run("case insensitive sort", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + tagsNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("Zebra"), + }), + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("apple"), + }), + }, + } + f := &sortTagsFix{tagsNode: tagsNode} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "apple", getTagName(tagsNode.Content[0])) + assert.Equal(t, "Zebra", getTagName(tagsNode.Content[1])) + }) + + t.Run("single tag is no-op", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + tagsNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Content: []*yaml.Node{ + yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("only"), + }), + }, + } + f := &sortTagsFix{tagsNode: tagsNode} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, tagsNode.Content, 1) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &sortTagsFix{tagsNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestAddGlobalTagFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addGlobalTagFix{tagName: "users"} + assert.Equal(t, "Add tag `users` to global tags", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + require.NoError(t, f.SetInput(nil)) + }) + + t.Run("adds tag to document", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{} + f := &addGlobalTagFix{tagName: "users"} + require.NoError(t, f.Apply(doc)) + + require.Len(t, doc.Tags, 1) + assert.Equal(t, "users", doc.Tags[0].Name) + }) + + t.Run("idempotent when tag exists", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{ + Tags: []*openapi.Tag{{Name: "users"}}, + } + f := &addGlobalTagFix{tagName: "users"} + require.NoError(t, f.Apply(doc)) + + require.Len(t, doc.Tags, 1, "should not duplicate tag") + }) + + t.Run("appends to existing tags", func(t *testing.T) { + t.Parallel() + doc := &openapi.OpenAPI{ + Tags: []*openapi.Tag{{Name: "pets"}}, + } + f := &addGlobalTagFix{tagName: "users"} + require.NoError(t, f.Apply(doc)) + + require.Len(t, doc.Tags, 2) + assert.Equal(t, "pets", doc.Tags[0].Name) + assert.Equal(t, "users", doc.Tags[1].Name) + }) + + t.Run("wrong doc type returns error", func(t *testing.T) { + t.Parallel() + f := &addGlobalTagFix{tagName: "users"} + err := f.Apply("not a doc") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected *openapi.OpenAPI") + }) +} + +func TestUpgradeToHTTPSFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &upgradeToHTTPSFix{} + assert.Equal(t, "Upgrade server URL to HTTPS", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("upgrades http to https", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("http://api.example.com") + f := &upgradeToHTTPSFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com", node.Value) + }) + + t.Run("preserves path after upgrade", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("http://api.example.com/v1/") + f := &upgradeToHTTPSFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com/v1/", node.Value) + }) + + t.Run("already https is no-op", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com") + f := &upgradeToHTTPSFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "https://api.example.com", node.Value) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &upgradeToHTTPSFix{node: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestRemoveNullableFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &removeNullableFix{} + assert.Equal(t, "Replace nullable with type array including null", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("converts scalar type to array with null", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + typeNode := yml.CreateStringNode("string") + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + typeNode, + yml.CreateStringNode("nullable"), + yml.CreateBoolNode(true), + }) + f := &removeNullableFix{schemaNode: schemaNode, typeValueNode: typeNode} + require.NoError(t, f.ApplyNode(nil)) + + // Verify nullable was removed + _, _, found := yml.GetMapElementNodes(ctx, schemaNode, "nullable") + assert.False(t, found, "nullable should be removed") + + // Verify type was converted to [string, null] + _, updatedType, found := yml.GetMapElementNodes(ctx, schemaNode, "type") + require.True(t, found) + assert.Equal(t, yaml.SequenceNode, updatedType.Kind) + require.Len(t, updatedType.Content, 2) + assert.Equal(t, "string", updatedType.Content[0].Value) + assert.Equal(t, "null", updatedType.Content[1].Value) + }) + + t.Run("appends null to existing type array", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + typeNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateStringNode("string"), + yml.CreateStringNode("integer"), + }, + } + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + typeNode, + yml.CreateStringNode("nullable"), + yml.CreateBoolNode(true), + }) + f := &removeNullableFix{schemaNode: schemaNode, typeValueNode: typeNode} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, typeNode.Content, 3) + assert.Equal(t, "string", typeNode.Content[0].Value) + assert.Equal(t, "integer", typeNode.Content[1].Value) + assert.Equal(t, "null", typeNode.Content[2].Value) + }) + + t.Run("does not duplicate null in type array", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + typeNode := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: "!!seq", + Content: []*yaml.Node{ + yml.CreateStringNode("string"), + yml.CreateStringNode("null"), + }, + } + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + typeNode, + yml.CreateStringNode("nullable"), + yml.CreateBoolNode(true), + }) + f := &removeNullableFix{schemaNode: schemaNode, typeValueNode: typeNode} + require.NoError(t, f.ApplyNode(nil)) + + require.Len(t, typeNode.Content, 2, "should not add duplicate null") + }) + + t.Run("no type field adds type null", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemaNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("nullable"), + yml.CreateBoolNode(true), + }) + f := &removeNullableFix{schemaNode: schemaNode, typeValueNode: nil} + require.NoError(t, f.ApplyNode(nil)) + + _, typeNode, found := yml.GetMapElementNodes(ctx, schemaNode, "type") + require.True(t, found, "type field should be added") + assert.Equal(t, "null", typeNode.Value) + }) + + t.Run("nil schema is no-op", func(t *testing.T) { + t.Parallel() + f := &removeNullableFix{schemaNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestAppendRFC8725Fix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &appendRFC8725Fix{} + assert.Equal(t, "Add RFC8725 mention to security scheme description", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("appends to existing description", func(t *testing.T) { + t.Parallel() + descNode := yml.CreateStringNode("OAuth2 Bearer token") + f := &appendRFC8725Fix{ + schemeNode: yml.CreateMapNode(context.Background(), nil), + descNode: descNode, + } + require.NoError(t, f.ApplyNode(nil)) + + assert.Contains(t, descNode.Value, "RFC8725") + assert.True(t, strings.HasPrefix(descNode.Value, "OAuth2 Bearer token")) + }) + + t.Run("does not duplicate RFC8725 mention", func(t *testing.T) { + t.Parallel() + descNode := yml.CreateStringNode("Already mentions RFC8725 best practices.") + f := &appendRFC8725Fix{ + schemeNode: yml.CreateMapNode(context.Background(), nil), + descNode: descNode, + } + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "Already mentions RFC8725 best practices.", descNode.Value) + }) + + t.Run("creates description when none exists", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + schemeNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode("http"), + }) + f := &appendRFC8725Fix{schemeNode: schemeNode, descNode: nil} + require.NoError(t, f.ApplyNode(nil)) + + _, desc, found := yml.GetMapElementNodes(ctx, schemeNode, "description") + require.True(t, found, "description should be created") + assert.Contains(t, desc.Value, "RFC8725") + }) + + t.Run("nil scheme is no-op", func(t *testing.T) { + t.Parallel() + f := &appendRFC8725Fix{schemeNode: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestSetAdditionalPropertiesFalseFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &setAdditionalPropertiesFalseFix{} + assert.Equal(t, "Set additionalProperties to false", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + }) + + t.Run("changes true to false", func(t *testing.T) { + t.Parallel() + node := yml.CreateBoolNode(true) + f := &setAdditionalPropertiesFalseFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "false", node.Value) + }) + + t.Run("already false is no-op", func(t *testing.T) { + t.Parallel() + node := yml.CreateBoolNode(false) + f := &setAdditionalPropertiesFalseFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, "false", node.Value) + }) + + t.Run("non-scalar node is no-op", func(t *testing.T) { + t.Parallel() + node := &yaml.Node{Kind: yaml.MappingNode} + f := &setAdditionalPropertiesFalseFix{node: node} + require.NoError(t, f.ApplyNode(nil)) + + assert.Equal(t, yaml.MappingNode, node.Kind) + }) + + t.Run("nil node is no-op", func(t *testing.T) { + t.Parallel() + f := &setAdditionalPropertiesFalseFix{node: nil} + require.NoError(t, f.ApplyNode(nil)) + }) +} diff --git a/openapi/linter/rules/tag_description.go b/openapi/linter/rules/tag_description.go index d8e28114..fd46b91c 100644 --- a/openapi/linter/rules/tag_description.go +++ b/openapi/linter/rules/tag_description.go @@ -67,12 +67,14 @@ func (r *TagDescriptionRule) Run(ctx context.Context, docInfo *linter.DocumentIn name := tag.GetName() if description == "" { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleTagDescription, - fmt.Errorf("tag `%s` must have a description", name), - tag.GetRootNode(), - )) + rootNode := tag.GetRootNode() + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("tag `%s` must have a description", name), + Node: rootNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleTagDescription, + Fix: &addDescriptionFix{targetNode: rootNode, targetLabel: "tag '" + name + "'"}, + }) } } diff --git a/openapi/linter/rules/tags_alphabetical.go b/openapi/linter/rules/tags_alphabetical.go index ca12214a..ac525435 100644 --- a/openapi/linter/rules/tags_alphabetical.go +++ b/openapi/linter/rules/tags_alphabetical.go @@ -3,11 +3,14 @@ package rules import ( "context" "fmt" + "sort" "strings" "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" ) const RuleStyleTagsAlphabetical = "style-tags-alphabetical" @@ -79,12 +82,13 @@ func (r *TagsAlphabeticalRule) Run(ctx context.Context, docInfo *linter.Document tagsNode = doc.GetRootNode() } - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleStyleTagsAlphabetical, - fmt.Errorf("tag `%s` must be placed before `%s` (alphabetical)", nextName, currentName), - tagsNode, - )) + errs = append(errs, &validation.Error{ + UnderlyingError: fmt.Errorf("tag `%s` must be placed before `%s` (alphabetical)", nextName, currentName), + Node: tagsNode, + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleStyleTagsAlphabetical, + Fix: &sortTagsFix{tagsNode: tagsNode}, + }) // Report only the first violation for deterministic behavior break } @@ -92,3 +96,38 @@ func (r *TagsAlphabeticalRule) Run(ctx context.Context, docInfo *linter.Document return errs } + +// sortTagsFix sorts the tags sequence node alphabetically by tag name. +type sortTagsFix struct { + tagsNode *yaml.Node +} + +func (f *sortTagsFix) Description() string { return "Sort tags alphabetically" } +func (f *sortTagsFix) Interactive() bool { return false } +func (f *sortTagsFix) Prompts() []validation.Prompt { return nil } +func (f *sortTagsFix) SetInput([]string) error { return nil } +func (f *sortTagsFix) Apply(doc any) error { return nil } + +func (f *sortTagsFix) ApplyNode(_ *yaml.Node) error { + if f.tagsNode == nil || f.tagsNode.Kind != yaml.SequenceNode || len(f.tagsNode.Content) < 2 { + return nil + } + sort.SliceStable(f.tagsNode.Content, func(i, j int) bool { + nameI := getTagName(f.tagsNode.Content[i]) + nameJ := getTagName(f.tagsNode.Content[j]) + return strings.ToLower(nameI) < strings.ToLower(nameJ) + }) + return nil +} + +// getTagName extracts the "name" field value from a tag mapping node. +func getTagName(node *yaml.Node) string { + if node == nil || node.Kind != yaml.MappingNode { + return "" + } + _, valueNode, found := yml.GetMapElementNodes(context.Background(), node, "name") + if !found || valueNode == nil { + return "" + } + return valueNode.Value +} diff --git a/openapi/linter/rules/unused_components.go b/openapi/linter/rules/unused_components.go index a64ce75e..020cfc58 100644 --- a/openapi/linter/rules/unused_components.go +++ b/openapi/linter/rules/unused_components.go @@ -213,7 +213,7 @@ func extractComponentPointer(ref references.Reference, docLocation string, docSe // checkUnusedComponents iterates through all component entries in the index // and flags those not in the referenced set using ToJSONPointer. -func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[string]struct{}, config *linter.RuleConfig, severity validation.Severity) []error { +func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[string]struct{}, config *linter.RuleConfig, severity validation.Severity) []error { //nolint:cyclop var errs []error // Check component schemas @@ -228,7 +228,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -244,7 +244,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -260,7 +260,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -276,7 +276,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -292,7 +292,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -308,7 +308,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -324,7 +324,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -340,7 +340,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -356,7 +356,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -372,7 +372,7 @@ func checkUnusedComponents(doc *openapi.OpenAPI, idx *openapi.Index, refs map[st continue } errNode := getComponentKeyNode(doc, node.Location) - errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity)) + errs = append(errs, createUnusedComponentError(pointer, errNode, config, severity, doc)) } } @@ -456,12 +456,62 @@ func hasUsageMarkingExtension(exts *extensions.Extensions) bool { } // createUnusedComponentError creates a validation error for an unused component. -func createUnusedComponentError(pointer string, errNode *yaml.Node, config *linter.RuleConfig, severity validation.Severity) error { +func createUnusedComponentError(pointer string, errNode *yaml.Node, config *linter.RuleConfig, severity validation.Severity, doc *openapi.OpenAPI) error { componentRef := "#" + pointer - return validation.NewValidationError( - config.GetSeverity(severity), - RuleSemanticUnusedComponent, - fmt.Errorf("`%s` is potentially unused or has been orphaned", componentRef), - errNode, - ) + + // Build a fix to remove this unused component + var fix validation.Fix + parts := strings.Split(pointer, "/") + if len(parts) >= 4 { + componentType := parts[2] + componentName := parts[3] + if mapNode := getComponentTypeMapNode(doc, componentType); mapNode != nil { + fix = &removeUnusedComponentFix{ + parentMapNode: mapNode, + componentName: componentName, + componentRef: componentRef, + } + } + } + + return &validation.Error{ + UnderlyingError: fmt.Errorf("`%s` is potentially unused or has been orphaned", componentRef), + Node: errNode, + Severity: config.GetSeverity(severity), + Rule: RuleSemanticUnusedComponent, + Fix: fix, + } +} + +// getComponentTypeMapNode returns the YAML mapping node for a given component type. +func getComponentTypeMapNode(doc *openapi.OpenAPI, componentType string) *yaml.Node { + core := doc.GetCore() + if core == nil || !core.Components.Present || core.Components.Value == nil { + return nil + } + cc := core.Components.Value + switch componentType { + case "schemas": + return cc.Schemas.ValueNode + case "parameters": + return cc.Parameters.ValueNode + case "responses": + return cc.Responses.ValueNode + case "requestBodies": + return cc.RequestBodies.ValueNode + case "headers": + return cc.Headers.ValueNode + case "examples": + return cc.Examples.ValueNode + case "links": + return cc.Links.ValueNode + case "callbacks": + return cc.Callbacks.ValueNode + case "pathItems": + return cc.PathItems.ValueNode + case "securitySchemes": + return cc.SecuritySchemes.ValueNode + default: + return nil + } } diff --git a/validation/errors.go b/validation/errors.go index d3ff660c..a37100aa 100644 --- a/validation/errors.go +++ b/validation/errors.go @@ -48,12 +48,6 @@ type Error struct { DocumentLocation string } -// Fix represents a suggested fix for a error finding -type Fix interface { - Apply(doc any) error - FixDescription() string -} - var _ error = (*Error)(nil) func (e Error) Error() string { diff --git a/validation/fix.go b/validation/fix.go new file mode 100644 index 00000000..ec2dd6b0 --- /dev/null +++ b/validation/fix.go @@ -0,0 +1,70 @@ +package validation + +import "gopkg.in/yaml.v3" + +// PromptType describes what kind of user input a fix needs. +type PromptType int + +const ( + // PromptChoice indicates the fix requires selecting from a list of options. + PromptChoice PromptType = iota + // PromptFreeText indicates the fix requires free-form text input. + PromptFreeText +) + +// Prompt describes a single piece of input a fix needs from the user. +type Prompt struct { + // Type is the kind of input needed. + Type PromptType + // Message is a human-readable description of what input is needed and why. + Message string + // Choices is the list of valid choices when Type is PromptChoice. + // Ignored for other prompt types. + Choices []string + // Default is an optional default value. + Default string +} + +// Fix represents a suggested fix for a validation finding. +// Fixes can be non-interactive (applied automatically) or interactive +// (requiring user input before application). +type Fix interface { + // Description returns a human-readable description of what the fix does. + Description() string + + // Interactive returns true if the fix requires user input before being applied. + // Non-interactive fixes can be applied directly with Apply(). + // Interactive fixes must have SetInput() called with user responses before Apply(). + Interactive() bool + + // Prompts returns the input prompts needed for this fix. + // Returns nil for non-interactive fixes. + Prompts() []Prompt + + // SetInput provides user responses for interactive fixes. + // The responses slice must correspond 1:1 with the Prompts() slice. + // Returns an error if the input is invalid. + // Calling this on a non-interactive fix is a no-op. + SetInput(responses []string) error + + // Apply applies the fix to the document. + // For interactive fixes, SetInput() must be called first. + // The doc parameter is typically *openapi.OpenAPI. + // Returns an error if the fix cannot be applied. + Apply(doc any) error +} + +// NodeFix is an optional interface for fixes that operate directly on yaml.Node +// trees rather than the high-level document model. This is useful for simple +// textual changes (renaming a key, changing a value) where going through the +// model is unnecessary. +// +// The fix engine checks for this interface first; if implemented, ApplyNode is +// called instead of Apply. +type NodeFix interface { + Fix + + // ApplyNode applies the fix directly to the YAML node tree. + // rootNode is the document root node. + ApplyNode(rootNode *yaml.Node) error +} diff --git a/validation/prompter.go b/validation/prompter.go new file mode 100644 index 00000000..e0583480 --- /dev/null +++ b/validation/prompter.go @@ -0,0 +1,21 @@ +package validation + +import "errors" + +// ErrSkipFix is a sentinel error returned by Prompter when the user chooses +// to skip a fix. Use errors.Is(err, ErrSkipFix) to check. +var ErrSkipFix = errors.New("fix skipped by user") + +// Prompter collects user input for interactive fixes. +// Implementations can be terminal-based (stdin/stdout), GUI-based, or test stubs. +type Prompter interface { + // PromptFix presents a fix to the user and collects input for its prompts. + // finding provides the error being fixed so the prompter can display context. + // fix is the fix that needs input. + // Returns the user's responses corresponding to fix.Prompts(), or an error. + // Return ErrSkipFix (or wrap it) to indicate the user chose to skip this fix. + PromptFix(finding *Error, fix Fix) ([]string, error) + + // Confirm asks the user a yes/no question. + Confirm(message string) (bool, error) +} From 7081b9005a1bfb1dcac9a1dc3517c93e9b7373b6 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Tue, 10 Feb 2026 09:42:49 +1000 Subject: [PATCH 2/7] fix: address PR review feedback for autofixer --- cmd/openapi/commands/openapi/lint.go | 34 +++++++++++++++++++++++--- cmd/openapi/commands/openapi/shared.go | 2 +- linter/fix/engine.go | 7 ++++-- linter/fix/engine_test.go | 24 +++++++++++++++--- 4 files changed, 58 insertions(+), 9 deletions(-) diff --git a/cmd/openapi/commands/openapi/lint.go b/cmd/openapi/commands/openapi/lint.go index a60ac8ef..04ca419d 100644 --- a/cmd/openapi/commands/openapi/lint.go +++ b/cmd/openapi/commands/openapi/lint.go @@ -7,10 +7,13 @@ import ( "path/filepath" "time" + "sync" + "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/linter/fix" "github.com/speakeasy-api/openapi/openapi" openapiLinter "github.com/speakeasy-api/openapi/openapi/linter" + "github.com/speakeasy-api/openapi/validation" "github.com/spf13/cobra" // Enable custom rules support @@ -216,10 +219,12 @@ func lintOpenAPI(ctx context.Context, file string) error { } func applyFixes(ctx context.Context, fixOpts fix.Options, doc *openapi.OpenAPI, output *linter.Output, cleanFile string) error { - // Create prompter for interactive mode - var prompter *fix.TerminalPrompter + // Create prompter lazily for interactive mode — only initialized when + // an interactive fix is actually encountered, avoiding unnecessary setup + // when all fixes are non-interactive. + var prompter validation.Prompter if fixOpts.Mode == fix.ModeInteractive { - prompter = fix.NewTerminalPrompter(os.Stdin, os.Stderr) + prompter = &lazyPrompter{} } engine := fix.NewEngine(fixOpts, prompter, nil) @@ -293,6 +298,29 @@ func skipReasonString(reason fix.SkipReason) string { } } +// lazyPrompter defers TerminalPrompter creation until an interactive fix is +// actually encountered, avoiding unnecessary setup when all fixes are non-interactive. +type lazyPrompter struct { + once sync.Once + prompter *fix.TerminalPrompter +} + +func (l *lazyPrompter) init() { + l.once.Do(func() { + l.prompter = fix.NewTerminalPrompter(os.Stdin, os.Stderr) + }) +} + +func (l *lazyPrompter) PromptFix(finding *validation.Error, f validation.Fix) ([]string, error) { + l.init() + return l.prompter.PromptFix(finding, f) +} + +func (l *lazyPrompter) Confirm(message string) (bool, error) { + l.init() + return l.prompter.Confirm(message) +} + func buildLintConfig() *linter.Config { config := linter.NewConfig() diff --git a/cmd/openapi/commands/openapi/shared.go b/cmd/openapi/commands/openapi/shared.go index 74c2d361..2c2a85ca 100644 --- a/cmd/openapi/commands/openapi/shared.go +++ b/cmd/openapi/commands/openapi/shared.go @@ -96,7 +96,7 @@ func (p *OpenAPIProcessor) WriteDocument(ctx context.Context, doc *openapi.OpenA return fmt.Errorf("failed to write document: %w", err) } - fmt.Printf("📄 Document written to: %s\n", cleanOutputFile) + fmt.Fprintf(os.Stderr, "📄 Document written to: %s\n", cleanOutputFile) return nil } diff --git a/linter/fix/engine.go b/linter/fix/engine.go index 6d236075..ba87fddf 100644 --- a/linter/fix/engine.go +++ b/linter/fix/engine.go @@ -86,10 +86,13 @@ func NewEngine(opts Options, prompter validation.Prompter, registry *FixRegistry } } -// conflictKey identifies a document location for conflict detection. +// conflictKey identifies a document location and rule for conflict detection. +// Including the rule allows independent fixes from different rules at the same +// YAML node to be applied without incorrectly skipping them as conflicts. type conflictKey struct { Line int Column int + Rule string } // ProcessErrors takes lint output errors and applies fixes where available. @@ -147,7 +150,7 @@ func (e *Engine) ProcessErrors(ctx context.Context, doc *openapi.OpenAPI, errs [ vErr := fe.vErr // Check for conflicts at the same location - key := conflictKey{Line: vErr.GetLineNumber(), Column: vErr.GetColumnNumber()} + key := conflictKey{Line: vErr.GetLineNumber(), Column: vErr.GetColumnNumber(), Rule: vErr.Rule} if key.Line >= 0 && modified[key] { result.Skipped = append(result.Skipped, SkippedFix{ Error: vErr, diff --git a/linter/fix/engine_test.go b/linter/fix/engine_test.go index 3c57f913..1d6c4a26 100644 --- a/linter/fix/engine_test.go +++ b/linter/fix/engine_test.go @@ -184,15 +184,15 @@ func TestEngine_DryRun(t *testing.T) { assert.False(t, f.applied, "fix should NOT have been actually applied in dry-run") } -func TestEngine_ConflictDetection(t *testing.T) { +func TestEngine_ConflictDetection_SameRule(t *testing.T) { t.Parallel() f1 := &mockFix{description: "first fix"} f2 := &mockFix{description: "second fix"} engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ - makeError("rule-a", 5, 3, "issue 1", f1), - makeError("rule-b", 5, 3, "issue 2", f2), + makeError("same-rule", 5, 3, "issue 1", f1), + makeError("same-rule", 5, 3, "issue 2", f2), }) require.NoError(t, err, "ProcessErrors should not fail") @@ -202,6 +202,24 @@ func TestEngine_ConflictDetection(t *testing.T) { assert.False(t, f2.applied, "second fix should not have been applied") } +func TestEngine_ConflictDetection_DifferentRules(t *testing.T) { + t.Parallel() + + f1 := &mockFix{description: "first fix"} + f2 := &mockFix{description: "second fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("rule-a", 5, 3, "issue 1", f1), + makeError("rule-b", 5, 3, "issue 2", f2), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 2, "should apply both fixes from different rules at the same location") + assert.Empty(t, result.Skipped, "should not skip fixes from different rules") + assert.True(t, f1.applied, "first fix should have been applied") + assert.True(t, f2.applied, "second fix should have been applied") +} + func TestEngine_FailedFix(t *testing.T) { t.Parallel() From 6c8a9f9dd0d28781572b6639b6afdf55234543f4 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Tue, 10 Feb 2026 09:44:36 +1000 Subject: [PATCH 3/7] docs: add mise ci pre-commit requirement to AGENTS.md --- AGENTS.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 573d8999..085cee3d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,6 +58,14 @@ mise test -count=1 ./... - **Race Detection**: Automatically enables race detection to catch concurrency issues - **Submodule Awareness**: Checks for and warns about uninitialized test submodules +## Pre-Commit CI Check + +**Always run `mise ci` before committing changes.** This runs the full CI pipeline locally (format, lint, test, build) and ensures your changes won't break CI. + +```bash +mise ci +``` + ## Git Commit Conventions **Always use single-line conventional commits.** Do not create multi-line commit messages. Do not add `Co-Authored-By` trailers. From 92e1b44f1040dde78defd091bd92a7029a0add46 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Tue, 10 Feb 2026 09:53:02 +1000 Subject: [PATCH 4/7] test: achieve 100% coverage for linter/fix package --- linter/fix/engine_test.go | 139 ++++++++++++++++++++++++ linter/fix/terminal_prompter_test.go | 151 +++++++++++++++++++++++++++ 2 files changed, 290 insertions(+) diff --git a/linter/fix/engine_test.go b/linter/fix/engine_test.go index 1d6c4a26..59b11981 100644 --- a/linter/fix/engine_test.go +++ b/linter/fix/engine_test.go @@ -59,6 +59,16 @@ func (f *mockNodeFix) ApplyNode(rootNode *yaml.Node) error { return nil } +// mockInteractiveFixWithSetInputErr is an interactive fix that fails on SetInput. +type mockInteractiveFixWithSetInputErr struct { + mockInteractiveFix + setInputErr error +} + +func (f *mockInteractiveFixWithSetInputErr) SetInput(responses []string) error { + return f.setInputErr +} + // mockPrompter is a test prompter that returns predefined responses. type mockPrompter struct { responses []string @@ -321,6 +331,135 @@ func TestEngine_ModeInteractive_NonInteractiveFixAppliesWithoutPrompter(t *testi assert.True(t, f.applied, "fix should have been applied") } +func TestEngine_SkipsNonValidationErrors(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + errors.New("plain error, not a validation.Error"), + makeError("test-rule", 1, 1, "real issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should only apply fix for the validation error") + assert.True(t, f.applied, "fix for validation error should have been applied") +} + +func TestEngine_ModeInteractive_PrompterError(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + } + prompter := &mockPrompter{err: errors.New("terminal closed")} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, prompter, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply fix when prompter fails") + assert.Len(t, result.Failed, 1, "should record the fix as failed") + assert.False(t, f.applied, "fix should not have been applied") +} + +func TestEngine_ModeInteractive_SetInputError(t *testing.T) { + t.Parallel() + + f := &mockInteractiveFixWithSetInputErr{ + mockInteractiveFix: mockInteractiveFix{ + description: "needs input", + prompts: []validation.Prompt{{Type: validation.PromptFreeText, Message: "enter value"}}, + }, + setInputErr: errors.New("invalid input"), + } + prompter := &mockPrompter{responses: []string{"bad value"}} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeInteractive}, prompter, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Empty(t, result.Applied, "should not apply fix when SetInput fails") + assert.Len(t, result.Failed, 1, "should record the fix as failed") +} + +func TestEngine_NodeFix_UsesApplyNode(t *testing.T) { + t.Parallel() + + f := &mockNodeFix{mockFix: mockFix{description: "node fix"}} + rootNode := &yaml.Node{Kind: yaml.MappingNode} + doc := &openapi.OpenAPI{} + doc.GetCore().SetRootNode(rootNode) + + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), doc, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "should apply the fix") + assert.True(t, f.nodeApplied, "ApplyNode should be called when root node is present") + assert.False(t, f.applied, "Apply should not be called when ApplyNode succeeds") +} + +func TestEngine_DryRun_ConflictDetection(t *testing.T) { + t.Parallel() + + f1 := &mockFix{description: "first fix"} + f2 := &mockFix{description: "second fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto, DryRun: true}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("same-rule", 5, 3, "issue 1", f1), + makeError("same-rule", 5, 3, "issue 2", f2), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + assert.Len(t, result.Applied, 1, "dry-run should record first fix as would-apply") + assert.Len(t, result.Skipped, 1, "dry-run should skip second fix as conflict") + assert.False(t, f1.applied, "fix should NOT have been actually applied in dry-run") + assert.False(t, f2.applied, "fix should NOT have been actually applied in dry-run") +} + +func TestApplyNodeFix_WithNodeFix(t *testing.T) { + t.Parallel() + + f := &mockNodeFix{mockFix: mockFix{description: "node fix"}} + rootNode := &yaml.Node{Kind: yaml.MappingNode} + doc := &openapi.OpenAPI{} + + err := fix.ApplyNodeFix(f, doc, rootNode) + require.NoError(t, err, "ApplyNodeFix should not fail") + assert.True(t, f.nodeApplied, "ApplyNode should be called") + assert.False(t, f.applied, "Apply should not be called") +} + +func TestApplyNodeFix_NilRootNode(t *testing.T) { + t.Parallel() + + f := &mockNodeFix{mockFix: mockFix{description: "node fix"}} + doc := &openapi.OpenAPI{} + + err := fix.ApplyNodeFix(f, doc, nil) + require.NoError(t, err, "ApplyNodeFix should not fail") + assert.False(t, f.nodeApplied, "ApplyNode should not be called with nil root") + assert.True(t, f.applied, "Apply should be called as fallback") +} + +func TestApplyNodeFix_RegularFix(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "regular fix"} + rootNode := &yaml.Node{Kind: yaml.MappingNode} + doc := &openapi.OpenAPI{} + + err := fix.ApplyNodeFix(f, doc, rootNode) + require.NoError(t, err, "ApplyNodeFix should not fail") + assert.True(t, f.applied, "Apply should be called for non-NodeFix") +} + func TestEngine_SortsByLocation(t *testing.T) { t.Parallel() diff --git a/linter/fix/terminal_prompter_test.go b/linter/fix/terminal_prompter_test.go index bceccf53..c56e36ef 100644 --- a/linter/fix/terminal_prompter_test.go +++ b/linter/fix/terminal_prompter_test.go @@ -5,6 +5,7 @@ import ( "errors" "strings" "testing" + "testing/iotest" "github.com/speakeasy-api/openapi/linter/fix" "github.com/speakeasy-api/openapi/validation" @@ -214,6 +215,62 @@ func TestTerminalPrompter_Choice_OutOfRangeThenValid(t *testing.T) { assert.Contains(t, output.String(), "Invalid choice", "should show invalid choice message") } +func TestTerminalPrompter_FreeText_Default(t *testing.T) { + t.Parallel() + + input := strings.NewReader("\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "add value", + prompts: []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter value", + Default: "default-value", + }, + }, + } + + responses, err := prompter.PromptFix(finding, f) + require.NoError(t, err, "PromptFix should not fail") + assert.Equal(t, []string{"default-value"}, responses, "should return default when input is empty") + assert.Contains(t, output.String(), "(default: default-value)", "should display default value") +} + +func TestTerminalPrompter_FreeText_EmptyNoDefault(t *testing.T) { + t.Parallel() + + input := strings.NewReader("\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "add value", + prompts: []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter value", + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should return error on empty input without default") + assert.ErrorIs(t, err, validation.ErrSkipFix, "should return ErrSkipFix") +} + func TestTerminalPrompter_Confirm_Yes(t *testing.T) { t.Parallel() @@ -237,3 +294,97 @@ func TestTerminalPrompter_Confirm_No(t *testing.T) { require.NoError(t, err, "Confirm should not fail") assert.False(t, result, "should return false for 'n'") } + +func TestTerminalPrompter_UnknownPromptType(t *testing.T) { + t.Parallel() + + input := strings.NewReader("\n") + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "bad prompt type", + prompts: []validation.Prompt{ + { + Type: validation.PromptType(99), + Message: "unknown", + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should fail on unknown prompt type") + assert.Contains(t, err.Error(), "unknown prompt type", "error should mention unknown prompt type") +} + +func TestTerminalPrompter_Choice_ReadError(t *testing.T) { + t.Parallel() + + input := iotest.ErrReader(errors.New("read failed")) + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "pick", + prompts: []validation.Prompt{ + { + Type: validation.PromptChoice, + Message: "Choose:", + Choices: []string{"a", "b"}, + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should fail on read error") + assert.Contains(t, err.Error(), "reading input", "error should mention reading input") +} + +func TestTerminalPrompter_FreeText_ReadError(t *testing.T) { + t.Parallel() + + input := iotest.ErrReader(errors.New("read failed")) + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + finding := &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Rule: "test-rule", + } + f := &mockInteractiveFix{ + description: "add value", + prompts: []validation.Prompt{ + { + Type: validation.PromptFreeText, + Message: "Enter value", + }, + }, + } + + _, err := prompter.PromptFix(finding, f) + require.Error(t, err, "should fail on read error") + assert.Contains(t, err.Error(), "reading input", "error should mention reading input") +} + +func TestTerminalPrompter_Confirm_ReadError(t *testing.T) { + t.Parallel() + + input := iotest.ErrReader(errors.New("read failed")) + output := &bytes.Buffer{} + prompter := fix.NewTerminalPrompter(input, output) + + _, err := prompter.Confirm("Apply fix?") + require.Error(t, err, "should fail on read error") + assert.Contains(t, err.Error(), "reading input", "error should mention reading input") +} From d0f18823227c7935344029cac1d748eb8358ce92 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Tue, 10 Feb 2026 10:12:09 +1000 Subject: [PATCH 5/7] =?UTF-8?q?test:=20improve=20coverage=20for=20linter/f?= =?UTF-8?q?ormat=20(53%=E2=86=9299%)=20and=20openapi/linter=20(86%?= =?UTF-8?q?=E2=86=9298%)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- linter/format/format_test.go | 234 ++++++++++++++++++++++++++++++++++ openapi/linter/linter_test.go | 131 +++++++++++++++++++ 2 files changed, 365 insertions(+) diff --git a/linter/format/format_test.go b/linter/format/format_test.go index 6788041a..ed42e411 100644 --- a/linter/format/format_test.go +++ b/linter/format/format_test.go @@ -235,6 +235,240 @@ func TestTextFormatter_FixableMarker(t *testing.T) { } } +func TestTextFormatter_NonValidationError(t *testing.T) { + t.Parallel() + + formatter := format.NewTextFormatter() + result, err := formatter.Format([]error{ + errors.New("something went wrong internally"), + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "internal", "should show 'internal' rule for non-validation errors") + assert.Contains(t, result, "something went wrong internally", "should show error message") + assert.Contains(t, result, "error", "should show error severity") + assert.Contains(t, result, "1 errors", "summary should count as error") +} + +func TestTextFormatter_DocumentLocation(t *testing.T) { + t.Parallel() + + formatter := format.NewTextFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("ref issue"), + Node: &yaml.Node{Line: 5, Column: 3}, + Severity: validation.SeverityWarning, + Rule: "ref-rule", + DocumentLocation: "/path/to/other.yaml", + }, + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "(document: /path/to/other.yaml)", "should include document location") +} + +func TestJSONFormatter_NonValidationError(t *testing.T) { + t.Parallel() + + formatter := format.NewJSONFormatter() + result, err := formatter.Format([]error{ + errors.New("internal failure"), + }) + require.NoError(t, err, "Format should not fail") + + var output struct { + Results []struct { + Rule string `json:"rule"` + Category string `json:"category"` + Severity string `json:"severity"` + Message string `json:"message"` + } `json:"results"` + Summary struct { + Errors int `json:"errors"` + Total int `json:"total"` + } `json:"summary"` + } + require.NoError(t, json.Unmarshal([]byte(result), &output), "should be valid JSON") + require.Len(t, output.Results, 1, "should have one result") + assert.Equal(t, "internal", output.Results[0].Rule, "should use 'internal' rule") + assert.Equal(t, "internal", output.Results[0].Category, "should use 'internal' category") + assert.Equal(t, "error", output.Results[0].Severity, "should use error severity") + assert.Equal(t, "internal failure", output.Results[0].Message, "should have error message") + assert.Equal(t, 1, output.Summary.Errors, "should count as error") + assert.Equal(t, 1, output.Summary.Total, "should count in total") +} + +func TestJSONFormatter_DocumentLocation(t *testing.T) { + t.Parallel() + + formatter := format.NewJSONFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("ref issue"), + Node: &yaml.Node{Line: 5, Column: 3}, + Severity: validation.SeverityWarning, + Rule: "ref-rule", + DocumentLocation: "/path/to/other.yaml", + }, + }) + require.NoError(t, err, "Format should not fail") + + var output struct { + Results []struct { + Document string `json:"document"` + } `json:"results"` + } + require.NoError(t, json.Unmarshal([]byte(result), &output), "should be valid JSON") + require.Len(t, output.Results, 1, "should have one result") + assert.Equal(t, "/path/to/other.yaml", output.Results[0].Document, "should include document location") +} + +func TestJSONFormatter_HintSeverity(t *testing.T) { + t.Parallel() + + formatter := format.NewJSONFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("hint message"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityHint, + Rule: "hint-rule", + }, + }) + require.NoError(t, err, "Format should not fail") + + var output struct { + Summary struct { + Hints int `json:"hints"` + Total int `json:"total"` + } `json:"summary"` + } + require.NoError(t, json.Unmarshal([]byte(result), &output), "should be valid JSON") + assert.Equal(t, 1, output.Summary.Hints, "should count hint") + assert.Equal(t, 1, output.Summary.Total, "should count in total") +} + +func TestSummaryFormatter_Empty(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{}) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "0 problems", "should show zero problems") + assert.Contains(t, result, "Rule", "should contain table header") +} + +func TestSummaryFormatter_SingleRule(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("error 1"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityError, + Rule: "owasp-rate-limit", + }, + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "owasp-rate-limit", "should show rule name") + assert.Contains(t, result, "owasp", "should extract category from rule prefix") + assert.Contains(t, result, "1 problems", "should show problem count") + assert.Contains(t, result, "1 errors", "should count errors") +} + +func TestSummaryFormatter_MultipleSeverities(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("err"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityError, + Rule: "err-rule", + }, + &validation.Error{ + UnderlyingError: errors.New("warn"), + Node: &yaml.Node{Line: 2, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "warn-rule", + }, + &validation.Error{ + UnderlyingError: errors.New("hint"), + Node: &yaml.Node{Line: 3, Column: 1}, + Severity: validation.SeverityHint, + Rule: "hint-rule", + }, + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "3 problems", "should count all problems") + assert.Contains(t, result, "1 errors", "should count errors") + assert.Contains(t, result, "1 warnings", "should count warnings") + assert.Contains(t, result, "1 hints", "should count hints") + assert.Contains(t, result, "3 rules", "should count rules") +} + +func TestSummaryFormatter_AggregatesSameRule(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("first"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "oas-description", + }, + &validation.Error{ + UnderlyingError: errors.New("second"), + Node: &yaml.Node{Line: 2, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "oas-description", + }, + &validation.Error{ + UnderlyingError: errors.New("other"), + Node: &yaml.Node{Line: 3, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "oas-other", + }, + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "2 rules", "should count distinct rules") + + // The rule with count 2 should appear first (sorted by count descending) + descIdx := strings.Index(result, "oas-description") + otherIdx := strings.Index(result, "oas-other") + assert.Greater(t, otherIdx, descIdx, "higher-count rule should appear before lower-count rule") +} + +func TestSummaryFormatter_NonValidationError(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{ + errors.New("plain error"), + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "internal", "should bucket non-validation errors as 'internal'") + assert.Contains(t, result, "1 errors", "should count as error") +} + +func TestSummaryFormatter_RuleWithoutCategory(t *testing.T) { + t.Parallel() + + formatter := format.NewSummaryFormatter() + result, err := formatter.Format([]error{ + &validation.Error{ + UnderlyingError: errors.New("issue"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityWarning, + Rule: "norule", + }, + }) + require.NoError(t, err, "Format should not fail") + assert.Contains(t, result, "unknown", "should use 'unknown' category for rules without a dash") +} + func TestJSONFormatter_FixMetadata(t *testing.T) { t.Parallel() diff --git a/openapi/linter/linter_test.go b/openapi/linter/linter_test.go index 7e5793d8..0f383326 100644 --- a/openapi/linter/linter_test.go +++ b/openapi/linter/linter_test.go @@ -8,14 +8,19 @@ import ( "strings" "testing" + "regexp" + "sync/atomic" + "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" openapiLinter "github.com/speakeasy-api/openapi/openapi/linter" "github.com/speakeasy-api/openapi/openapi/linter/rules" "github.com/speakeasy-api/openapi/pointer" "github.com/speakeasy-api/openapi/references" + "github.com/speakeasy-api/openapi/validation" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) type mockVirtualFS struct { @@ -455,6 +460,132 @@ paths: }, documentErrors) } +func TestLinter_FilterErrors(t *testing.T) { + t.Parallel() + + disabled := true + config := &linter.Config{ + Extends: []string{"all"}, + Rules: []linter.RuleEntry{ + { + ID: rules.RuleSemanticPathParams, + Match: regexp.MustCompile(".*userId.*"), + Disabled: &disabled, + }, + }, + } + lntr, err := openapiLinter.NewLinter(config) + require.NoError(t, err) + + errs := []error{ + &validation.Error{ + UnderlyingError: errors.New("userId is missing"), + Node: &yaml.Node{Line: 1, Column: 1}, + Severity: validation.SeverityError, + Rule: rules.RuleSemanticPathParams, + }, + &validation.Error{ + UnderlyingError: errors.New("orderId is missing"), + Node: &yaml.Node{Line: 2, Column: 1}, + Severity: validation.SeverityError, + Rule: rules.RuleSemanticPathParams, + }, + } + + filtered := lntr.FilterErrors(errs) + // The "userId" error matches the disable pattern, so it should be removed + // The "orderId" error does NOT match, so it passes through + require.Len(t, filtered, 1, "should filter out the matching error") + assert.Contains(t, filtered[0].Error(), "orderId", "should keep the non-matching error") +} + +// testLoaderCalled tracks whether the test custom rule loader was invoked. +// Protected by the global customRuleLoadersMu inside RegisterCustomRuleLoader. +var testLoaderCalled int32 + +func init() { + // Register a single test loader that branches on a sentinel path. + // This avoids race conditions from registering loaders inside parallel tests. + openapiLinter.RegisterCustomRuleLoader(func(config *linter.CustomRulesConfig) ([]linter.RuleRunner[*openapi.OpenAPI], error) { + for _, p := range config.Paths { + if p == "__test_error_path__" { + return nil, errors.New("failed to compile rules") + } + } + atomic.AddInt32(&testLoaderCalled, 1) + return nil, nil + }) +} + +func TestNewLinter_CustomRuleLoader(t *testing.T) { + t.Parallel() + + before := atomic.LoadInt32(&testLoaderCalled) + + config := &linter.Config{ + CustomRules: &linter.CustomRulesConfig{ + Paths: []string{"./rules/*.ts"}, + }, + } + lntr, err := openapiLinter.NewLinter(config) + require.NoError(t, err, "NewLinter should not fail") + assert.NotNil(t, lntr, "linter should be created") + + after := atomic.LoadInt32(&testLoaderCalled) + assert.Greater(t, after, before, "custom rule loader should have been called") +} + +func TestNewLinter_CustomRuleLoaderError(t *testing.T) { + t.Parallel() + + config := &linter.Config{ + CustomRules: &linter.CustomRulesConfig{ + Paths: []string{"__test_error_path__"}, + }, + } + _, err := openapiLinter.NewLinter(config) + require.Error(t, err, "NewLinter should fail when custom rule loader fails") + assert.Contains(t, err.Error(), "loading custom rules", "error should mention loading custom rules") +} + +func TestLinter_Lint_WithResolveOptions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + yamlInput := ` +openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users: + get: + operationId: getUsers + responses: + '200': + description: ok +` + + doc, _, err := openapi.Unmarshal(ctx, strings.NewReader(yamlInput)) + require.NoError(t, err) + + config := &linter.Config{ + Extends: []string{}, + } + lntr, err := openapiLinter.NewLinter(config) + require.NoError(t, err) + + docInfo := linter.NewDocumentInfo(doc, "/spec/openapi.yaml") + output, err := lntr.Lint(ctx, docInfo, nil, &linter.LintOptions{ + ResolveOptions: &references.ResolveOptions{ + DisableExternalRefs: true, + SkipValidation: true, + }, + }) + require.NoError(t, err, "Lint should not fail with resolve options") + assert.NotNil(t, output, "should return output") +} + func TestNewLinter_WithoutDefaultRules(t *testing.T) { t.Parallel() ctx := t.Context() From ee57e4ff7964024fb89d6e4c85e092b7a7d5ad1e Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Tue, 10 Feb 2026 11:08:27 +1000 Subject: [PATCH 6/7] test: cover FixAvailable() for all 32 fixable rules --- openapi/linter/rules/rule_metadata_test.go | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/openapi/linter/rules/rule_metadata_test.go b/openapi/linter/rules/rule_metadata_test.go index cc0333e8..61a47594 100644 --- a/openapi/linter/rules/rule_metadata_test.go +++ b/openapi/linter/rules/rule_metadata_test.go @@ -14,6 +14,11 @@ type howToFixer interface { HowToFix() string } +// fixAvailabler is the interface satisfied by rules that advertise auto-fix support. +type fixAvailabler interface { + FixAvailable() bool +} + // allRules returns every built-in rule instance. func allRules() []linter.RuleRunner[*openapi.OpenAPI] { return []linter.RuleRunner[*openapi.OpenAPI]{ @@ -82,6 +87,26 @@ func allRules() []linter.RuleRunner[*openapi.OpenAPI] { } } +func TestAllRules_FixAvailable(t *testing.T) { + t.Parallel() + + fixableCount := 0 + + for _, rule := range allRules() { + if fa, ok := rule.(fixAvailabler); ok { + fixableCount++ + + t.Run(rule.ID(), func(t *testing.T) { + t.Parallel() + + assert.True(t, fa.FixAvailable(), "rule %s should report fix available", rule.ID()) + }) + } + } + + assert.Equal(t, 32, fixableCount, "expected 32 rules to implement FixAvailable") +} + func TestAllRules_MetadataPopulated(t *testing.T) { t.Parallel() From 6eb0a12dfa560359499e0fc34cdcb64a2f029dbf Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Wed, 11 Feb 2026 20:16:09 +1000 Subject: [PATCH 7/7] missing fixes --- cmd/openapi/commands/openapi/lint.go | 5 +- linter/fix/engine.go | 29 ++- linter/fix/engine_test.go | 65 +++++ openapi/linter/rules/fix_available.go | 1 + openapi/linter/rules/fix_helpers.go | 76 ++++++ openapi/linter/rules/fix_integration_test.go | 77 ++++++ openapi/linter/rules/host_trailing_slash.go | 7 + .../rules/owasp_security_hosts_https_oas3.go | 7 + openapi/linter/rules/path_params.go | 33 ++- openapi/linter/rules/path_trailing_slash.go | 7 + openapi/linter/rules/rule_fixes_test.go | 239 ++++++++++++++++++ openapi/linter/rules/rule_metadata_test.go | 2 +- validation/fix.go | 7 + 13 files changed, 543 insertions(+), 12 deletions(-) diff --git a/cmd/openapi/commands/openapi/lint.go b/cmd/openapi/commands/openapi/lint.go index 04ca419d..2ac486f9 100644 --- a/cmd/openapi/commands/openapi/lint.go +++ b/cmd/openapi/commands/openapi/lint.go @@ -227,7 +227,7 @@ func applyFixes(ctx context.Context, fixOpts fix.Options, doc *openapi.OpenAPI, prompter = &lazyPrompter{} } - engine := fix.NewEngine(fixOpts, prompter, nil) + engine := fix.NewEngine(fixOpts, prompter, fix.NewFixRegistry()) result, err := engine.ProcessErrors(ctx, doc, output.Results) if err != nil { return fmt.Errorf("fix processing failed: %w", err) @@ -263,6 +263,9 @@ func reportFixResults(result *fix.Result, dryRun bool) { fmt.Fprintf(os.Stderr, " [%d:%d] %s - %s\n", af.Error.GetLineNumber(), af.Error.GetColumnNumber(), af.Error.Rule, af.Fix.Description()) + if af.Before != "" || af.After != "" { + fmt.Fprintf(os.Stderr, " %s -> %s\n", af.Before, af.After) + } } } diff --git a/linter/fix/engine.go b/linter/fix/engine.go index ba87fddf..7431f70e 100644 --- a/linter/fix/engine.go +++ b/linter/fix/engine.go @@ -45,8 +45,10 @@ const ( // AppliedFix records a successfully applied fix. type AppliedFix struct { - Error *validation.Error - Fix validation.Fix + Error *validation.Error + Fix validation.Fix + Before string // populated from ChangeDescriber if implemented + After string // populated from ChangeDescriber if implemented } // SkippedFix records a fix that was skipped. @@ -97,6 +99,17 @@ type conflictKey struct { // ProcessErrors takes lint output errors and applies fixes where available. // The doc is modified in-place by successful fixes. +// +// Pipeline ordering: +// 1. Fixable errors are collected from both Error.Fix fields and the FixRegistry. +// 2. Errors are sorted by document location (line, column ascending) so fixes +// are applied in first-in-document-order. This ensures deterministic results. +// 3. Conflict detection: the key is {line, column, rule}. If two errors from +// the same rule share a location, only the first (by sort order) is applied. +// Different rules CAN independently fix the same location. +// 4. Interactive fixes are skipped in ModeAuto or when no prompter is available. +// 5. In dry-run mode, fixes are recorded without modifying the document but +// conflict detection still operates. func (e *Engine) ProcessErrors(ctx context.Context, doc *openapi.OpenAPI, errs []error) (*Result, error) { if e.opts.Mode == ModeNone { return &Result{}, nil @@ -172,7 +185,7 @@ func (e *Engine) ProcessErrors(ctx context.Context, doc *openapi.OpenAPI, errs [ // Dry-run: record what would happen without applying if e.opts.DryRun { - result.Applied = append(result.Applied, AppliedFix{Error: vErr, Fix: fix}) + result.Applied = append(result.Applied, makeAppliedFix(vErr, fix)) if key.Line >= 0 { modified[key] = true } @@ -230,12 +243,20 @@ func (e *Engine) ProcessErrors(ctx context.Context, doc *openapi.OpenAPI, errs [ modified[key] = true } - result.Applied = append(result.Applied, AppliedFix{Error: vErr, Fix: fix}) + result.Applied = append(result.Applied, makeAppliedFix(vErr, fix)) } return result, nil } +func makeAppliedFix(vErr *validation.Error, fix validation.Fix) AppliedFix { + af := AppliedFix{Error: vErr, Fix: fix} + if cd, ok := fix.(validation.ChangeDescriber); ok { + af.Before, af.After = cd.DescribeChange() + } + return af +} + // ApplyNodeFix is a helper that applies a NodeFix if the fix implements the interface, // otherwise falls back to Apply. func ApplyNodeFix(fix validation.Fix, doc *openapi.OpenAPI, rootNode *yaml.Node) error { diff --git a/linter/fix/engine_test.go b/linter/fix/engine_test.go index 59b11981..32f23f80 100644 --- a/linter/fix/engine_test.go +++ b/linter/fix/engine_test.go @@ -460,6 +460,71 @@ func TestApplyNodeFix_RegularFix(t *testing.T) { assert.True(t, f.applied, "Apply should be called for non-NodeFix") } +// mockChangeDescriberFix is a fix that implements ChangeDescriber. +type mockChangeDescriberFix struct { + mockFix + before string + after string +} + +func (f *mockChangeDescriberFix) DescribeChange() (string, string) { + return f.before, f.after +} + +func TestEngine_ChangeDescriber_PopulatesBeforeAfter(t *testing.T) { + t.Parallel() + + f := &mockChangeDescriberFix{ + mockFix: mockFix{description: "trim slash"}, + before: "/users/", + after: "/users", + } + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "trailing slash", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + require.Len(t, result.Applied, 1, "should apply the fix") + assert.Equal(t, "/users/", result.Applied[0].Before, "should populate Before from ChangeDescriber") + assert.Equal(t, "/users", result.Applied[0].After, "should populate After from ChangeDescriber") +} + +func TestEngine_ChangeDescriber_DryRun(t *testing.T) { + t.Parallel() + + f := &mockChangeDescriberFix{ + mockFix: mockFix{description: "upgrade https"}, + before: "http://api.example.com", + after: "https://api.example.com", + } + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto, DryRun: true}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "use https", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + require.Len(t, result.Applied, 1, "should record the fix as would-apply") + assert.Equal(t, "http://api.example.com", result.Applied[0].Before, "dry-run should populate Before") + assert.Equal(t, "https://api.example.com", result.Applied[0].After, "dry-run should populate After") + assert.False(t, f.applied, "fix should NOT have been actually applied in dry-run") +} + +func TestEngine_NoChangeDescriber_EmptyBeforeAfter(t *testing.T) { + t.Parallel() + + f := &mockFix{description: "simple fix"} + engine := fix.NewEngine(fix.Options{Mode: fix.ModeAuto}, nil, nil) + result, err := engine.ProcessErrors(t.Context(), &openapi.OpenAPI{}, []error{ + makeError("test-rule", 1, 1, "issue", f), + }) + + require.NoError(t, err, "ProcessErrors should not fail") + require.Len(t, result.Applied, 1, "should apply the fix") + assert.Empty(t, result.Applied[0].Before, "Before should be empty without ChangeDescriber") + assert.Empty(t, result.Applied[0].After, "After should be empty without ChangeDescriber") +} + func TestEngine_SortsByLocation(t *testing.T) { t.Parallel() diff --git a/openapi/linter/rules/fix_available.go b/openapi/linter/rules/fix_available.go index 6455ae3e..7935ae95 100644 --- a/openapi/linter/rules/fix_available.go +++ b/openapi/linter/rules/fix_available.go @@ -35,3 +35,4 @@ func (r *OwaspArrayLimitRule) FixAvailable() bool { return func (r *OwaspIntegerLimitRule) FixAvailable() bool { return true } func (r *OwaspAdditionalPropertiesConstrainedRule) FixAvailable() bool { return true } func (r *UnusedComponentRule) FixAvailable() bool { return true } +func (r *PathParamsRule) FixAvailable() bool { return true } diff --git a/openapi/linter/rules/fix_helpers.go b/openapi/linter/rules/fix_helpers.go index 509f9c86..08d8dcf7 100644 --- a/openapi/linter/rules/fix_helpers.go +++ b/openapi/linter/rules/fix_helpers.go @@ -572,3 +572,79 @@ func (f *removeUnusedComponentFix) ApplyNode(_ *yaml.Node) error { yml.DeleteMapNodeElement(ctx, f.componentName, f.parentMapNode) return nil } + +// addPathParameterFix adds a missing path parameter definition to an operation. +type addPathParameterFix struct { + operationNode *yaml.Node // the operation mapping node + paramName string // e.g. "userId" + schemaType string // "integer" or "string" + schemaFormat string // e.g. "uuid" or "" +} + +func (f *addPathParameterFix) Description() string { + desc := "Add missing path parameter '" + f.paramName + "'" + if f.schemaFormat != "" { + desc += " (type: " + f.schemaType + ", format: " + f.schemaFormat + ")" + } else { + desc += " (type: " + f.schemaType + ")" + } + return desc +} +func (f *addPathParameterFix) Interactive() bool { return false } +func (f *addPathParameterFix) Prompts() []validation.Prompt { return nil } +func (f *addPathParameterFix) SetInput([]string) error { return nil } +func (f *addPathParameterFix) Apply(doc any) error { return nil } + +func (f *addPathParameterFix) ApplyNode(_ *yaml.Node) error { + if f.operationNode == nil || f.operationNode.Kind != yaml.MappingNode { + return nil + } + + ctx := context.Background() + + // Get or create parameters sequence + _, paramsNode, found := yml.GetMapElementNodes(ctx, f.operationNode, "parameters") + if !found || paramsNode == nil || paramsNode.Kind != yaml.SequenceNode { + paramsNode = &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"} + yml.CreateOrUpdateMapNodeElement(ctx, "parameters", nil, paramsNode, f.operationNode) + } + + // Idempotency: check if parameter already exists + for _, elem := range paramsNode.Content { + if elem.Kind != yaml.MappingNode { + continue + } + _, nameNode, nameFound := yml.GetMapElementNodes(ctx, elem, "name") + _, inNode, inFound := yml.GetMapElementNodes(ctx, elem, "in") + if nameFound && inFound && nameNode.Value == f.paramName && inNode.Value == "path" { + return nil // already exists + } + } + + // Build schema node + schemaContent := []*yaml.Node{ + yml.CreateStringNode("type"), + yml.CreateStringNode(f.schemaType), + } + if f.schemaFormat != "" { + schemaContent = append(schemaContent, + yml.CreateStringNode("format"), + yml.CreateStringNode(f.schemaFormat)) + } + schemaNode := yml.CreateMapNode(ctx, schemaContent) + + // Build parameter node + paramNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode(f.paramName), + yml.CreateStringNode("in"), + yml.CreateStringNode("path"), + yml.CreateStringNode("required"), + yml.CreateBoolNode(true), + yml.CreateStringNode("schema"), + schemaNode, + }) + + paramsNode.Content = append(paramsNode.Content, paramNode) + return nil +} diff --git a/openapi/linter/rules/fix_integration_test.go b/openapi/linter/rules/fix_integration_test.go index d81e204f..9bf7a844 100644 --- a/openapi/linter/rules/fix_integration_test.go +++ b/openapi/linter/rules/fix_integration_test.go @@ -437,6 +437,83 @@ components: assert.Empty(t, errs2, "fix should resolve the array limit violation") } +func TestFixIntegration_MissingPathParam(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: + /users/{userId}: + get: + operationId: getUser + responses: + "200": + description: OK +`) + + rule := &PathParamsRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.Len(t, errs, 1, "should detect missing userId param") + + // Extract and apply fix + nodeFix := extractNodeFix(t, errs) + var valErr0 *validation.Error + require.ErrorAs(t, errs[0], &valErr0) + assert.False(t, valErr0.Fix.Interactive(), "should be non-interactive") + require.NoError(t, nodeFix.ApplyNode(nil)) + + // Re-parse and verify violation is resolved + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fix should resolve the missing path param violation") +} + +func TestFixIntegration_MissingPathParam_TypeInference(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := parseOpenAPIDoc(t, ` +openapi: "3.1.0" +info: + title: Test + version: "1.0" +paths: + /users/{userId}/sessions/{sessionUuid}: + get: + operationId: getSession + responses: + "200": + description: OK +`) + + rule := &PathParamsRule{} + docInfo := buildDocInfo(t, doc) + errs := rule.Run(ctx, docInfo, &linter.RuleConfig{}) + require.Len(t, errs, 2, "should detect both missing params") + + // Apply all fixes + for _, err := range errs { + var valErr *validation.Error + require.ErrorAs(t, err, &valErr) + require.NotNil(t, valErr.Fix, "should have a fix") + nodeFix, ok := valErr.Fix.(validation.NodeFix) + require.True(t, ok) + require.NoError(t, nodeFix.ApplyNode(nil)) + } + + // Re-parse and verify + doc2 := remarshalAndParse(t, doc) + docInfo2 := buildDocInfo(t, doc2) + errs2 := rule.Run(ctx, docInfo2, &linter.RuleConfig{}) + assert.Empty(t, errs2, "fixes should resolve all missing path param violations") +} + // ============================================================ // Fix engine integration test // ============================================================ diff --git a/openapi/linter/rules/host_trailing_slash.go b/openapi/linter/rules/host_trailing_slash.go index 469aa41e..0c44a857 100644 --- a/openapi/linter/rules/host_trailing_slash.go +++ b/openapi/linter/rules/host_trailing_slash.go @@ -94,6 +94,13 @@ func (f *removeHostTrailingSlashFix) Prompts() []validation.Prompt { return nil func (f *removeHostTrailingSlashFix) SetInput([]string) error { return nil } func (f *removeHostTrailingSlashFix) Apply(doc any) error { return nil } +func (f *removeHostTrailingSlashFix) DescribeChange() (string, string) { + if f.node == nil { + return "", "" + } + return f.node.Value, strings.TrimRight(f.node.Value, "/") +} + func (f *removeHostTrailingSlashFix) ApplyNode(_ *yaml.Node) error { if f.node != nil { f.node.Value = strings.TrimRight(f.node.Value, "/") diff --git a/openapi/linter/rules/owasp_security_hosts_https_oas3.go b/openapi/linter/rules/owasp_security_hosts_https_oas3.go index 7b4bc3a4..b19a0a3f 100644 --- a/openapi/linter/rules/owasp_security_hosts_https_oas3.go +++ b/openapi/linter/rules/owasp_security_hosts_https_oas3.go @@ -97,6 +97,13 @@ func (f *upgradeToHTTPSFix) Prompts() []validation.Prompt { return nil } func (f *upgradeToHTTPSFix) SetInput([]string) error { return nil } func (f *upgradeToHTTPSFix) Apply(doc any) error { return nil } +func (f *upgradeToHTTPSFix) DescribeChange() (string, string) { + if f.node == nil || !strings.HasPrefix(f.node.Value, "http://") { + return "", "" + } + return f.node.Value, "https://" + strings.TrimPrefix(f.node.Value, "http://") +} + func (f *upgradeToHTTPSFix) ApplyNode(_ *yaml.Node) error { if f.node != nil && strings.HasPrefix(f.node.Value, "http://") { f.node.Value = "https://" + strings.TrimPrefix(f.node.Value, "http://") diff --git a/openapi/linter/rules/path_params.go b/openapi/linter/rules/path_params.go index 459679e5..f03dcec2 100644 --- a/openapi/linter/rules/path_params.go +++ b/openapi/linter/rules/path_params.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "regexp" + "strings" "github.com/speakeasy-api/openapi/linter" "github.com/speakeasy-api/openapi/openapi" @@ -91,12 +92,19 @@ func (r *PathParamsRule) Run(ctx context.Context, docInfo *linter.DocumentInfo[* // 1. All template params must be in effectiveParams for _, tmplParam := range templateParams { if _, ok := effectiveParams[tmplParam]; !ok { - errs = append(errs, validation.NewValidationError( - config.GetSeverity(r.DefaultSeverity()), - RuleSemanticPathParams, - fmt.Errorf("path parameter `{%s}` is not defined in operation parameters", tmplParam), - op.GetRootNode(), - )) + schemaType, schemaFormat := inferPathParamType(tmplParam) + errs = append(errs, &validation.Error{ + Severity: config.GetSeverity(r.DefaultSeverity()), + Rule: RuleSemanticPathParams, + UnderlyingError: fmt.Errorf("path parameter `{%s}` is not defined in operation parameters", tmplParam), + Node: op.GetRootNode(), + Fix: &addPathParameterFix{ + operationNode: op.GetRootNode(), + paramName: tmplParam, + schemaType: schemaType, + schemaFormat: schemaFormat, + }, + }) } } @@ -184,3 +192,16 @@ func mergeParameters(base, override map[string]bool) map[string]bool { } return result } + +// inferPathParamType guesses the schema type for a path parameter based on naming conventions. +func inferPathParamType(name string) (schemaType, format string) { + lower := strings.ToLower(name) + switch { + case strings.Contains(lower, "uuid") || strings.Contains(lower, "guid"): + return "string", "uuid" + case strings.HasSuffix(lower, "id"): + return "integer", "" + default: + return "string", "" + } +} diff --git a/openapi/linter/rules/path_trailing_slash.go b/openapi/linter/rules/path_trailing_slash.go index 6b76c5a8..e758ef27 100644 --- a/openapi/linter/rules/path_trailing_slash.go +++ b/openapi/linter/rules/path_trailing_slash.go @@ -79,6 +79,13 @@ func (f *removeTrailingSlashFix) Prompts() []validation.Prompt { return nil } func (f *removeTrailingSlashFix) SetInput([]string) error { return nil } func (f *removeTrailingSlashFix) Apply(doc any) error { return nil } +func (f *removeTrailingSlashFix) DescribeChange() (string, string) { + if f.node == nil { + return "", "" + } + return f.node.Value, strings.TrimRight(f.node.Value, "/") +} + func (f *removeTrailingSlashFix) ApplyNode(_ *yaml.Node) error { if f.node != nil { f.node.Value = strings.TrimRight(f.node.Value, "/") diff --git a/openapi/linter/rules/rule_fixes_test.go b/openapi/linter/rules/rule_fixes_test.go index 3968d6e0..92fd988e 100644 --- a/openapi/linter/rules/rule_fixes_test.go +++ b/openapi/linter/rules/rule_fixes_test.go @@ -61,6 +61,23 @@ func TestRemoveHostTrailingSlashFix(t *testing.T) { f := &removeHostTrailingSlashFix{node: nil} require.NoError(t, f.ApplyNode(nil)) }) + + t.Run("describe change", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com/") + f := &removeHostTrailingSlashFix{node: node} + before, after := f.DescribeChange() + assert.Equal(t, "https://api.example.com/", before, "before should be original value") + assert.Equal(t, "https://api.example.com", after, "after should have slash removed") + }) + + t.Run("describe change nil node", func(t *testing.T) { + t.Parallel() + f := &removeHostTrailingSlashFix{node: nil} + before, after := f.DescribeChange() + assert.Empty(t, before, "before should be empty for nil node") + assert.Empty(t, after, "after should be empty for nil node") + }) } func TestRemoveTrailingSlashFix(t *testing.T) { @@ -88,6 +105,23 @@ func TestRemoveTrailingSlashFix(t *testing.T) { f := &removeTrailingSlashFix{node: nil} require.NoError(t, f.ApplyNode(nil)) }) + + t.Run("describe change", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("/pets/") + f := &removeTrailingSlashFix{node: node} + before, after := f.DescribeChange() + assert.Equal(t, "/pets/", before, "before should be original value") + assert.Equal(t, "/pets", after, "after should have slash removed") + }) + + t.Run("describe change nil node", func(t *testing.T) { + t.Parallel() + f := &removeTrailingSlashFix{node: nil} + before, after := f.DescribeChange() + assert.Empty(t, before, "before should be empty for nil node") + assert.Empty(t, after, "after should be empty for nil node") + }) } func TestRemoveDuplicateEnumFix(t *testing.T) { @@ -349,6 +383,32 @@ func TestUpgradeToHTTPSFix(t *testing.T) { f := &upgradeToHTTPSFix{node: nil} require.NoError(t, f.ApplyNode(nil)) }) + + t.Run("describe change", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("http://api.example.com") + f := &upgradeToHTTPSFix{node: node} + before, after := f.DescribeChange() + assert.Equal(t, "http://api.example.com", before, "before should be original value") + assert.Equal(t, "https://api.example.com", after, "after should be upgraded to HTTPS") + }) + + t.Run("describe change nil node", func(t *testing.T) { + t.Parallel() + f := &upgradeToHTTPSFix{node: nil} + before, after := f.DescribeChange() + assert.Empty(t, before, "before should be empty for nil node") + assert.Empty(t, after, "after should be empty for nil node") + }) + + t.Run("describe change already https", func(t *testing.T) { + t.Parallel() + node := yml.CreateStringNode("https://api.example.com") + f := &upgradeToHTTPSFix{node: node} + before, after := f.DescribeChange() + assert.Empty(t, before, "before should be empty when already HTTPS") + assert.Empty(t, after, "after should be empty when already HTTPS") + }) } func TestRemoveNullableFix(t *testing.T) { @@ -561,3 +621,182 @@ func TestSetAdditionalPropertiesFalseFix(t *testing.T) { require.NoError(t, f.ApplyNode(nil)) }) } + +func TestAddPathParameterFix(t *testing.T) { + t.Parallel() + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + f := &addPathParameterFix{paramName: "userId", schemaType: "integer"} + assert.Equal(t, "Add missing path parameter 'userId' (type: integer)", f.Description()) + assert.False(t, f.Interactive()) + assert.Nil(t, f.Prompts()) + require.NoError(t, f.SetInput(nil)) + require.NoError(t, f.Apply(nil)) + }) + + t.Run("metadata with format", func(t *testing.T) { + t.Parallel() + f := &addPathParameterFix{paramName: "requestUuid", schemaType: "string", schemaFormat: "uuid"} + assert.Equal(t, "Add missing path parameter 'requestUuid' (type: string, format: uuid)", f.Description()) + }) + + t.Run("adds param to existing parameters sequence", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Create an operation node with existing parameters + paramsSeq := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"} + opNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("parameters"), + paramsSeq, + }) + + f := &addPathParameterFix{ + operationNode: opNode, + paramName: "userId", + schemaType: "integer", + } + require.NoError(t, f.ApplyNode(nil)) + + // Verify parameter was added + _, updatedParams, found := yml.GetMapElementNodes(ctx, opNode, "parameters") + require.True(t, found, "parameters should exist") + require.Equal(t, yaml.SequenceNode, updatedParams.Kind) + assert.Len(t, updatedParams.Content, 1, "should have one parameter") + + // Verify parameter content + param := updatedParams.Content[0] + _, nameNode, _ := yml.GetMapElementNodes(ctx, param, "name") + _, inNode, _ := yml.GetMapElementNodes(ctx, param, "in") + _, reqNode, _ := yml.GetMapElementNodes(ctx, param, "required") + assert.Equal(t, "userId", nameNode.Value) + assert.Equal(t, "path", inNode.Value) + assert.Equal(t, "true", reqNode.Value) + + _, schemaNode, _ := yml.GetMapElementNodes(ctx, param, "schema") + require.NotNil(t, schemaNode) + _, typeNode, _ := yml.GetMapElementNodes(ctx, schemaNode, "type") + assert.Equal(t, "integer", typeNode.Value) + }) + + t.Run("creates parameters sequence when missing", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Create an operation node without parameters + opNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("summary"), + yml.CreateStringNode("Get user"), + }) + + f := &addPathParameterFix{ + operationNode: opNode, + paramName: "userId", + schemaType: "string", + } + require.NoError(t, f.ApplyNode(nil)) + + // Verify parameters was created + _, updatedParams, found := yml.GetMapElementNodes(ctx, opNode, "parameters") + require.True(t, found, "parameters should be created") + require.Equal(t, yaml.SequenceNode, updatedParams.Kind) + assert.Len(t, updatedParams.Content, 1, "should have one parameter") + }) + + t.Run("includes format when specified", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + opNode := yml.CreateMapNode(ctx, nil) + + f := &addPathParameterFix{ + operationNode: opNode, + paramName: "requestUuid", + schemaType: "string", + schemaFormat: "uuid", + } + require.NoError(t, f.ApplyNode(nil)) + + _, paramsNode, _ := yml.GetMapElementNodes(ctx, opNode, "parameters") + param := paramsNode.Content[0] + _, schemaNode, _ := yml.GetMapElementNodes(ctx, param, "schema") + _, formatNode, found := yml.GetMapElementNodes(ctx, schemaNode, "format") + require.True(t, found, "format should be present") + assert.Equal(t, "uuid", formatNode.Value) + }) + + t.Run("idempotent - does not add duplicate", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Pre-populate with existing path param + existingParam := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("name"), + yml.CreateStringNode("userId"), + yml.CreateStringNode("in"), + yml.CreateStringNode("path"), + }) + paramsSeq := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq", Content: []*yaml.Node{existingParam}} + opNode := yml.CreateMapNode(ctx, []*yaml.Node{ + yml.CreateStringNode("parameters"), + paramsSeq, + }) + + f := &addPathParameterFix{ + operationNode: opNode, + paramName: "userId", + schemaType: "integer", + } + require.NoError(t, f.ApplyNode(nil)) + + _, updatedParams, _ := yml.GetMapElementNodes(ctx, opNode, "parameters") + assert.Len(t, updatedParams.Content, 1, "should still have one parameter (no duplicate)") + }) + + t.Run("nil operation node is no-op", func(t *testing.T) { + t.Parallel() + f := &addPathParameterFix{operationNode: nil, paramName: "userId", schemaType: "string"} + require.NoError(t, f.ApplyNode(nil)) + }) + + t.Run("non-mapping operation node is no-op", func(t *testing.T) { + t.Parallel() + f := &addPathParameterFix{ + operationNode: &yaml.Node{Kind: yaml.SequenceNode}, + paramName: "userId", + schemaType: "string", + } + require.NoError(t, f.ApplyNode(nil)) + }) +} + +func TestInferPathParamType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + paramName string + expectedType string + expectedFormat string + }{ + {name: "userId inferred as integer", paramName: "userId", expectedType: "integer", expectedFormat: ""}, + {name: "postId inferred as integer", paramName: "postId", expectedType: "integer", expectedFormat: ""}, + {name: "orgid inferred as integer", paramName: "orgid", expectedType: "integer", expectedFormat: ""}, + {name: "requestUuid inferred as string uuid", paramName: "requestUuid", expectedType: "string", expectedFormat: "uuid"}, + {name: "sessionGuid inferred as string uuid", paramName: "sessionGuid", expectedType: "string", expectedFormat: "uuid"}, + {name: "UUID uppercase inferred as string uuid", paramName: "UUID", expectedType: "string", expectedFormat: "uuid"}, + {name: "name inferred as string", paramName: "name", expectedType: "string", expectedFormat: ""}, + {name: "slug inferred as string", paramName: "slug", expectedType: "string", expectedFormat: ""}, + {name: "version inferred as string", paramName: "version", expectedType: "string", expectedFormat: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + schemaType, format := inferPathParamType(tt.paramName) + assert.Equal(t, tt.expectedType, schemaType, "schema type should match") + assert.Equal(t, tt.expectedFormat, format, "format should match") + }) + } +} diff --git a/openapi/linter/rules/rule_metadata_test.go b/openapi/linter/rules/rule_metadata_test.go index 61a47594..3eb9499c 100644 --- a/openapi/linter/rules/rule_metadata_test.go +++ b/openapi/linter/rules/rule_metadata_test.go @@ -104,7 +104,7 @@ func TestAllRules_FixAvailable(t *testing.T) { } } - assert.Equal(t, 32, fixableCount, "expected 32 rules to implement FixAvailable") + assert.Equal(t, 33, fixableCount, "expected 33 rules to implement FixAvailable") } func TestAllRules_MetadataPopulated(t *testing.T) { diff --git a/validation/fix.go b/validation/fix.go index ec2dd6b0..0364e3d7 100644 --- a/validation/fix.go +++ b/validation/fix.go @@ -54,6 +54,13 @@ type Fix interface { Apply(doc any) error } +// ChangeDescriber is an optional interface that fixes can implement to provide +// human-readable before/after descriptions of what the fix changes. This enables +// richer dry-run and reporting output. +type ChangeDescriber interface { + DescribeChange() (before, after string) +} + // NodeFix is an optional interface for fixes that operate directly on yaml.Node // trees rather than the high-level document model. This is useful for simple // textual changes (renaming a key, changing a value) where going through the