|
1 | 1 | from copy import deepcopy |
| 2 | +from typing import Any |
2 | 3 | from typing import Optional |
3 | 4 | from typing import cast |
4 | 5 |
|
5 | 6 | from jsonschema._format import FormatChecker |
6 | 7 | from jsonschema.protocols import Validator |
| 8 | +from jsonschema.validators import validator_for |
7 | 9 | from jsonschema_path import SchemaPath |
8 | 10 |
|
9 | 11 | from openapi_core.validation.schemas._validators import ( |
10 | 12 | build_enforce_properties_required_validator, |
11 | 13 | ) |
| 14 | +from openapi_core.validation.schemas._validators import ( |
| 15 | + build_forbid_unspecified_additional_properties_validator, |
| 16 | +) |
12 | 17 | from openapi_core.validation.schemas.datatypes import FormatValidatorsDict |
13 | 18 | from openapi_core.validation.schemas.validators import SchemaValidator |
14 | 19 |
|
15 | 20 |
|
16 | 21 | class SchemaValidatorsFactory: |
17 | 22 | def __init__( |
18 | 23 | self, |
19 | | - schema_validator_class: type[Validator], |
20 | | - strict_schema_validator_class: Optional[type[Validator]] = None, |
| 24 | + schema_validator_cls: type[Validator], |
21 | 25 | format_checker: Optional[FormatChecker] = None, |
22 | 26 | ): |
23 | | - self.schema_validator_class = schema_validator_class |
24 | | - self.strict_schema_validator_class = strict_schema_validator_class |
| 27 | + self.schema_validator_cls = schema_validator_cls |
25 | 28 | if format_checker is None: |
26 | | - format_checker = self.schema_validator_class.FORMAT_CHECKER |
| 29 | + format_checker = self.schema_validator_cls.FORMAT_CHECKER |
27 | 30 | assert format_checker is not None |
28 | 31 | self.format_checker = format_checker |
29 | 32 |
|
| 33 | + def get_validator_cls( |
| 34 | + self, spec: SchemaPath, schema: SchemaPath |
| 35 | + ) -> type[Validator]: |
| 36 | + return self.schema_validator_cls |
| 37 | + |
30 | 38 | def get_format_checker( |
31 | 39 | self, |
32 | 40 | format_validators: Optional[FormatValidatorsDict] = None, |
@@ -57,34 +65,90 @@ def _add_validators( |
57 | 65 |
|
58 | 66 | def create( |
59 | 67 | self, |
| 68 | + spec: SchemaPath, |
60 | 69 | schema: SchemaPath, |
61 | 70 | format_validators: Optional[FormatValidatorsDict] = None, |
62 | 71 | extra_format_validators: Optional[FormatValidatorsDict] = None, |
63 | 72 | forbid_unspecified_additional_properties: bool = False, |
64 | 73 | enforce_properties_required: bool = False, |
65 | 74 | ) -> SchemaValidator: |
66 | | - validator_class: type[Validator] = self.schema_validator_class |
| 75 | + validator_cls: type[Validator] = self.get_validator_cls(spec, schema) |
| 76 | + if enforce_properties_required: |
| 77 | + validator_cls = build_enforce_properties_required_validator( |
| 78 | + validator_cls |
| 79 | + ) |
67 | 80 | if forbid_unspecified_additional_properties: |
68 | | - if self.strict_schema_validator_class is None: |
69 | | - raise ValueError( |
70 | | - "Strict additional properties validation is not supported " |
71 | | - "by this factory." |
| 81 | + validator_cls = ( |
| 82 | + build_forbid_unspecified_additional_properties_validator( |
| 83 | + validator_cls |
72 | 84 | ) |
73 | | - validator_class = self.strict_schema_validator_class |
74 | | - |
75 | | - if enforce_properties_required: |
76 | | - validator_class = build_enforce_properties_required_validator( |
77 | | - validator_class |
78 | 85 | ) |
79 | 86 |
|
80 | 87 | format_checker = self.get_format_checker( |
81 | 88 | format_validators, extra_format_validators |
82 | 89 | ) |
83 | 90 | with schema.resolve() as resolved: |
84 | | - jsonschema_validator = validator_class( |
| 91 | + jsonschema_validator = validator_cls( |
85 | 92 | resolved.contents, |
86 | 93 | _resolver=resolved.resolver, |
87 | 94 | format_checker=format_checker, |
88 | 95 | ) |
89 | 96 |
|
90 | 97 | return SchemaValidator(schema, jsonschema_validator) |
| 98 | + |
| 99 | + |
| 100 | +class DialectSchemaValidatorsFactory(SchemaValidatorsFactory): |
| 101 | + def __init__( |
| 102 | + self, |
| 103 | + schema_validator_cls: type[Validator], |
| 104 | + default_jsonschema_dialect_id: str, |
| 105 | + format_checker: Optional[FormatChecker] = None, |
| 106 | + ): |
| 107 | + super().__init__(schema_validator_cls, format_checker) |
| 108 | + self.default_jsonschema_dialect_id = default_jsonschema_dialect_id |
| 109 | + |
| 110 | + self._validator_classes_by_dialect: dict[ |
| 111 | + str, type[Validator] | None |
| 112 | + ] = {} |
| 113 | + |
| 114 | + def get_validator_cls( |
| 115 | + self, spec: SchemaPath, schema: SchemaPath |
| 116 | + ) -> type[Validator]: |
| 117 | + dialect_id = self._get_dialect_id(spec, schema) |
| 118 | + |
| 119 | + validator_cls = self._get_validator_class_for_dialect(dialect_id) |
| 120 | + if validator_cls is None: |
| 121 | + raise ValueError(f"Unknown JSON Schema dialect: {dialect_id!r}") |
| 122 | + |
| 123 | + return validator_cls |
| 124 | + |
| 125 | + def _get_dialect_id( |
| 126 | + self, |
| 127 | + spec: SchemaPath, |
| 128 | + schema: SchemaPath, |
| 129 | + ) -> str: |
| 130 | + try: |
| 131 | + return (schema / "$schema").read_str() |
| 132 | + except KeyError: |
| 133 | + return self._get_default_jsonschema_dialect_id(spec) |
| 134 | + |
| 135 | + def _get_default_jsonschema_dialect_id(self, spec: SchemaPath) -> str: |
| 136 | + return (spec / "jsonSchemaDialect").read_str( |
| 137 | + default=self.default_jsonschema_dialect_id |
| 138 | + ) |
| 139 | + |
| 140 | + def _get_validator_class_for_dialect( |
| 141 | + self, dialect_id: str |
| 142 | + ) -> type[Validator] | None: |
| 143 | + if dialect_id in self._validator_classes_by_dialect: |
| 144 | + return self._validator_classes_by_dialect[dialect_id] |
| 145 | + |
| 146 | + validator_cls = cast( |
| 147 | + type[Validator] | None, |
| 148 | + validator_for( |
| 149 | + {"$schema": dialect_id}, |
| 150 | + default=cast(Any, None), |
| 151 | + ), |
| 152 | + ) |
| 153 | + self._validator_classes_by_dialect[dialect_id] = validator_cls |
| 154 | + return validator_cls |
0 commit comments