From c70781d498b1b6ff4202372303f50af25fa86cf6 Mon Sep 17 00:00:00 2001 From: Ezeudoh Tochukwu Date: Tue, 16 Dec 2025 22:22:10 +0100 Subject: [PATCH 1/3] refactor: use reflect module for controller metadata management Replace direct attribute storage with reflect-based metadata for APIController instances and route functions. This centralizes metadata handling and enables better isolation in tests via reflect.context(). --- ninja_extra/constants.py | 11 + ninja_extra/context.py | 5 + ninja_extra/controllers/__init__.py | 16 - ninja_extra/controllers/base.py | 312 +++++++------- ninja_extra/controllers/model/builder.py | 22 +- ninja_extra/controllers/model/endpoints.py | 19 +- ninja_extra/controllers/registry.py | 57 ++- ninja_extra/controllers/route/__init__.py | 12 +- .../controllers/route/route_functions.py | 33 +- ninja_extra/controllers/utils.py | 15 + ninja_extra/helper.py | 15 +- ninja_extra/main.py | 66 ++- ninja_extra/ordering/operation.py | 4 +- ninja_extra/pagination/operations.py | 17 +- ninja_extra/reflect/__init__.py | 9 + ninja_extra/reflect/_reflect.py | 390 ++++++++++++++++++ ninja_extra/reflect/utils.py | 113 +++++ ninja_extra/searching/operations.py | 4 +- ninja_extra/testing/client.py | 31 +- tests/conftest.py | 31 ++ tests/test_api_instance.py | 15 +- tests/test_controller.py | 77 ++-- tests/test_controller_registry.py | 36 +- tests/test_deprecation_warnings.py | 12 - .../test_model_async_controller_operation.py | 5 +- tests/test_operation.py | 29 +- tests/test_ordering.py | 40 +- tests/test_pagination.py | 7 +- tests/test_reflect.py | 186 +++++++++ tests/test_route.py | 124 +++--- tests/test_searching.py | 48 ++- .../test_throttle_controller.py | 3 +- 32 files changed, 1318 insertions(+), 446 deletions(-) create mode 100644 ninja_extra/controllers/utils.py create mode 100644 ninja_extra/reflect/__init__.py create mode 100644 ninja_extra/reflect/_reflect.py create mode 100644 ninja_extra/reflect/utils.py create mode 100644 tests/test_reflect.py diff --git a/ninja_extra/constants.py b/ninja_extra/constants.py index 6031c65d..3c5752a0 100644 --- a/ninja_extra/constants.py +++ b/ninja_extra/constants.py @@ -18,6 +18,17 @@ THROTTLED_OBJECTS = "__throttled_objects__" ROUTE_FUNCTION = "__route_function__" +CONTROLLER_OPERATION_HANDLER_KEY = "CONTROLLER_OPERATION_HANDLER" +CONTROLLER_WATERMARK = "CONTROLLER_WATERMARK" +ROUTE_OBJECT = "ROUTE_OBJECT" +ROUTE_OBJECT_FUNCTION = "ROUTE_OBJECT_FUNCTION" +OPERATION_ENDPOINT_KEY = "OPERATION_ENDPOINT_KEY" +API_CONTROLLER_INSTANCE = "API_CONTROLLER_INSTANCE" +ORDERATOR_OBJECT = "ORDERATOR_WATERMARK" +SEARCH_OPERATOR_OBJECT = "SEARCH_OPERATOR_OBJECT" +PAGINATOR_OBJECT = "PAGINATOR_WATERMARK" +NINJA_EXTRA_API_CONTROLLER_REGISTERED_KEY = "NINJA_EXTRA_API_CONTROLLER_REGISTERED_KEY" + ROUTE_CONTEXT_VAR: contextvars.ContextVar[t.Optional["RouteContext"]] = ( contextvars.ContextVar("ROUTE_CONTEXT_VAR") ) diff --git a/ninja_extra/context.py b/ninja_extra/context.py index 975027f7..76b5aab3 100644 --- a/ninja_extra/context.py +++ b/ninja_extra/context.py @@ -58,6 +58,11 @@ def __init__( self._view_signature = view_signature self._has_computed_route_parameters = False + @property + def api(self) -> "NinjaExtraAPI": + assert self._api, "API instance is not set in RouteContext" + return self._api + @property def has_computed_route_parameters(self) -> bool: return self._has_computed_route_parameters diff --git a/ninja_extra/controllers/__init__.py b/ninja_extra/controllers/__init__.py index f8b718dd..b62d5f41 100644 --- a/ninja_extra/controllers/__init__.py +++ b/ninja_extra/controllers/__init__.py @@ -1,6 +1,3 @@ -import typing as t -import warnings - from .base import ControllerBase, ModelControllerBase, api_controller from .model import ( ModelAsyncEndpointFactory, @@ -49,16 +46,3 @@ "ModelEndpointFactory", "ModelAsyncEndpointFactory", ] - - -def __getattr__(name: str) -> t.Any: - if name == "RouteContext": - warnings.warn( - "RouteContext is deprecated and will be removed in a future version.", - DeprecationWarning, - stacklevel=2, - ) - from ninja_extra.context import RouteContext - - return RouteContext - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ninja_extra/controllers/base.py b/ninja_extra/controllers/base.py index 2ec563f0..99408d1c 100644 --- a/ninja_extra/controllers/base.py +++ b/ninja_extra/controllers/base.py @@ -4,40 +4,31 @@ import inspect import re +import typing as t import uuid import warnings -from abc import ABC -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, - overload, -) from django.db.models import Model, QuerySet from django.http import HttpResponse from django.urls import URLPattern, URLResolver, include from django.urls import path as django_path from injector import inject, is_decorated_with_inject -from ninja import NinjaAPI, Router +from ninja import Router from ninja.constants import NOT_SET, NOT_SET_TYPE from ninja.security.base import AuthBase from ninja.signature import is_async from ninja.throttling import BaseThrottle from ninja.utils import normalize_path -from ninja_extra.constants import ROUTE_FUNCTION, THROTTLED_FUNCTION, THROTTLED_OBJECTS +from ninja_extra.constants import ( + API_CONTROLLER_INSTANCE, + CONTROLLER_WATERMARK, + NINJA_EXTRA_API_CONTROLLER_REGISTERED_KEY, + OPERATION_ENDPOINT_KEY, + ROUTE_OBJECT, + THROTTLED_FUNCTION, + THROTTLED_OBJECTS, +) from ninja_extra.context import RouteContext from ninja_extra.exceptions import APIException, NotFound, PermissionDenied from ninja_extra.helper import get_function_name @@ -48,6 +39,7 @@ BasePermission, BasePermissionType, ) +from ninja_extra.reflect import reflect from ninja_extra.shortcuts import ( aget_object_or_exception, aget_object_or_none, @@ -57,57 +49,51 @@ ) from .model import ModelConfig, ModelControllerBuilder, ModelService -from .registry import ControllerRegistry +from .registry import controller_registry from .route.route_functions import AsyncRouteFunction, RouteFunction -if TYPE_CHECKING: # pragma: no cover +if t.TYPE_CHECKING: # pragma: no cover from ninja_extra import NinjaExtraAPI from ninja_extra.controllers.model import ModelConfig + from ninja_extra.controllers.route import Route -T = TypeVar("T") - +T = t.TypeVar("T") -class MissingAPIControllerDecoratorException(Exception): - pass - -def get_route_functions(cls: Type) -> Iterable[RouteFunction]: +def get_route_functions( + klass: t.Type, + api_controller_instance: "APIController", +) -> t.Iterator[RouteFunction]: """ - Get all route functions from a controller class. - This function will recursively search for route functions in the base classes of the controller class - in order that they are defined. + Get all route functions from a class, creating RouteFunction instances from decorated methods. - Args: - cls (Type): The controller class. + This function scans a class for methods decorated with route decorators (e.g., @http_get, @http_post) + and yields RouteFunction or AsyncRouteFunction instances for each. - Returns: - Iterable[RouteFunction]: An iterable of route functions. + :param klass: The class to scan for route functions + :param api_controller_instance: The APIController instance associated with the class + :return: An iterator of RouteFunction instances """ - bases = inspect.getmro(cls) - for base_cls in reversed(bases): - if base_cls not in [ControllerBase, ABC, object]: - for method in base_cls.__dict__.values(): - if hasattr(method, ROUTE_FUNCTION): - yield getattr(method, ROUTE_FUNCTION) - - -def get_all_controller_route_function( - controller: Union[Type["ControllerBase"], Type], -) -> List[RouteFunction]: # pragma: no cover - route_functions: List[RouteFunction] = [] - for item in dir(controller): - attr = getattr(controller, item) - if isinstance(attr, RouteFunction): - route_functions.append(attr) - return route_functions + for _, method in inspect.getmembers(klass, predicate=inspect.isfunction): + if hasattr(method, OPERATION_ENDPOINT_KEY): + route_obj: "Route" = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, method + ) + if route_obj.is_async: + yield AsyncRouteFunction( + route_obj, api_controller=api_controller_instance + ) + else: + yield RouteFunction(route_obj, api_controller=api_controller_instance) def compute_api_route_function( - base_cls: Type, api_controller_instance: "APIController" + base_cls: t.Type, api_controller_instance: "APIController" ) -> None: - for cls_route_function in get_route_functions(base_cls): - cls_route_function.api_controller = api_controller_instance + controller_routes = list(get_route_functions(base_cls, api_controller_instance)) + controller_routes.reverse() + for cls_route_function in controller_routes: api_controller_instance.add_controller_route_function(cls_route_function) @@ -138,30 +124,15 @@ def some_method_name(self): ``` """ - # `_api_controller` a reference to APIController instance - _api_controller: Optional["APIController"] = None - - # `api` a reference to NinjaExtraAPI on APIController registration - api: Optional[NinjaAPI] = None - # `context` variable will change based on the route function called on the APIController # that way we can get some specific items things that belong the route function during execution - context: Optional["RouteContext"] = None - throttling_classes: List[Type["BaseThrottle"]] = [] - throttling_init_kwargs: Optional[Dict[Any, Any]] = None - - @classmethod - def get_api_controller(cls) -> "APIController": - if not cls._api_controller: - raise MissingAPIControllerDecoratorException( - "api_controller not found. " - "Did you forget to use the `api_controller` decorator" - ) - return cls._api_controller + context: t.Optional["RouteContext"] = None + throttling_classes: t.List[t.Type["BaseThrottle"]] = [] + throttling_init_kwargs: t.Optional[t.Dict[t.Any, t.Any]] = None @classmethod def permission_denied( - cls, permission: Union[BasePermission, AsyncBasePermission] + cls, permission: t.Union[BasePermission, AsyncBasePermission] ) -> None: """ This method is called when the permission check fails. By default, it raises an exception. @@ -172,11 +143,11 @@ def permission_denied( def get_object_or_exception( self, - klass: Union[Type[Model], QuerySet], - error_message: Optional[str] = None, - exception: Type[APIException] = NotFound, - **kwargs: Any, - ) -> Any: + klass: t.Union[t.Type[Model], QuerySet], + error_message: t.Optional[str] = None, + exception: t.Type[APIException] = NotFound, + **kwargs: t.Any, + ) -> t.Any: obj = get_object_or_exception( klass=klass, error_message=error_message, exception=exception, **kwargs ) @@ -185,11 +156,11 @@ def get_object_or_exception( async def aget_object_or_exception( self, - klass: Union[Type[Model], QuerySet], - error_message: Optional[str] = None, - exception: Type[APIException] = NotFound, - **kwargs: Any, - ) -> Any: + klass: t.Union[t.Type[Model], QuerySet], + error_message: t.Optional[str] = None, + exception: t.Type[APIException] = NotFound, + **kwargs: t.Any, + ) -> t.Any: obj = await aget_object_or_exception( klass=klass, error_message=error_message, exception=exception, **kwargs ) @@ -197,29 +168,31 @@ async def aget_object_or_exception( return obj def get_object_or_none( - self, klass: Union[Type[Model], QuerySet], **kwargs: Any - ) -> Optional[Any]: + self, klass: t.Union[t.Type[Model], QuerySet], **kwargs: t.Any + ) -> t.Optional[t.Any]: obj = get_object_or_none(klass=klass, **kwargs) if obj: self.check_object_permissions(obj) return obj async def aget_object_or_none( - self, klass: Union[Type[Model], QuerySet], **kwargs: Any - ) -> Optional[Any]: + self, klass: t.Union[t.Type[Model], QuerySet], **kwargs: t.Any + ) -> t.Optional[t.Any]: obj = await aget_object_or_none(klass=klass, **kwargs) if obj: await self.async_check_object_permissions(obj) return obj - def _get_permissions(self) -> Iterable[Union[BasePermission, AsyncBasePermission]]: + def _get_permissions( + self, + ) -> t.Iterable[t.Union[BasePermission, AsyncBasePermission]]: """ Instantiates and returns the list of permissions that this view requires. """ assert self.context for permission_class in self.context.permission_classes: - permission_instance: Union[BasePermission, AsyncBasePermission] = ( + permission_instance: t.Union[BasePermission, AsyncBasePermission] = ( permission_class # type: ignore[assignment] ) if isinstance(permission_class, type) and issubclass( @@ -244,7 +217,7 @@ def check_permissions(self) -> None: ): self.permission_denied(permission) - def check_object_permissions(self, obj: Union[Any, Model]) -> None: + def check_object_permissions(self, obj: t.Union[t.Any, Model]) -> None: """ Check if the request should be permitted for a given object. Raises an appropriate exception if the request is not permitted. @@ -285,7 +258,7 @@ async def async_check_permissions(self) -> None: if not has_permission: self.permission_denied(permission) - async def async_check_object_permissions(self, obj: Union[Any, Model]) -> None: + async def async_check_object_permissions(self, obj: t.Union[t.Any, Model]) -> None: """ Asynchronous version of check_object_permissions. Check if the request should be permitted for a given object, using async permission checks when available. @@ -311,14 +284,14 @@ async def async_check_object_permissions(self, obj: Union[Any, Model]) -> None: self.permission_denied(permission) def create_response( - self, message: Any, status_code: int = 200, **kwargs: Any + self, message: t.Any, status_code: int = 200, **kwargs: t.Any ) -> HttpResponse: - assert self.api and self.context and self.context.request - content = self.api.renderer.render( + assert self.context and self.context.request + content = self.context.api.renderer.render( self.context.request, message, response_status=status_code ) content_type = "{}; charset={}".format( - self.api.renderer.media_type, self.api.renderer.charset + self.context.api.renderer.media_type, self.context.api.renderer.charset ) return HttpResponse( content, status=status_code, content_type=content_type, **kwargs @@ -342,17 +315,17 @@ class SomeController(ControllerBase): ``` """ - service_type: Type[ModelService] = ModelService + service_type: t.Type[ModelService] = ModelService def __init__(self, service: ModelService): self.service = service - model_config: Optional["ModelConfig"] = None + model_config: t.Optional["ModelConfig"] = None -ControllerClassType = TypeVar( +ControllerClassType = t.TypeVar( "ControllerClassType", - bound=Union[Type[ControllerBase], Type], + bound=t.Union[t.Type[ControllerBase], t.Type], ) @@ -398,19 +371,19 @@ def __init__( self, prefix: str, *, - auth: Any = NOT_SET, - throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, - tags: Union[Optional[List[str]], str] = None, - permissions: Optional[List[BasePermissionType]] = None, + auth: t.Any = NOT_SET, + throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + tags: t.Union[t.Optional[t.List[str]], str] = None, + permissions: t.Optional[t.List[BasePermissionType]] = None, auto_import: bool = True, - urls_namespace: Optional[str] = None, + urls_namespace: t.Optional[str] = None, use_unique_op_id: bool = True, ) -> None: self.prefix = prefix # Optional controller-level URL namespace. Applied to all route paths. self.urls_namespace = urls_namespace or None # `auth` primarily defines APIController route function global authentication method. - self.auth: Optional[AuthBase] = auth + self.auth: t.Optional[AuthBase] = auth self.tags = tags # type: ignore self.throttle = throttle @@ -418,15 +391,13 @@ def __init__( self.auto_import: bool = auto_import # set to false and it would be ignored when api.auto_discover is called # `controller_class` target class that the APIController wraps - self._controller_class: Optional[Type["ControllerBase"]] = None + self._controller_class: t.Optional[t.Type["ControllerBase"]] = None # `_path_operations` a converted dict of APIController route function used by Django-Ninja library - self._path_operations: Dict[str, PathView] = {} - self._controller_class_route_functions: Dict[str, RouteFunction] = {} + self._path_operations: t.Dict[str, PathView] = {} + self._controller_class_route_functions: t.Dict[str, RouteFunction] = {} # `permission_classes` a collection of BasePermission Types # a fallback if route functions has no permissions definition - self.permission_classes: List[BasePermissionType] = permissions or [AllowAny] - # `registered` prevents controllers from being register twice or exist in two different `api` instances - self.registered: bool = False + self.permission_classes: t.List[BasePermissionType] = permissions or [AllowAny] self._prefix_has_route_param = False @@ -435,7 +406,7 @@ def __init__( self.has_auth_async = False if auth and auth is not NOT_SET: - auth_callbacks = isinstance(auth, Sequence) and auth or [auth] + auth_callbacks = isinstance(auth, t.Sequence) and auth or [auth] for _auth in auth_callbacks: _call_back = _auth if inspect.isfunction(_auth) else _auth.__call__ if is_async(_call_back): @@ -452,22 +423,22 @@ def __init__( ) @property - def prefix_route_params(self) -> Dict[str, str]: + def prefix_route_params(self) -> t.Dict[str, str]: return self._prefix_route_params @property - def controller_class(self) -> Type["ControllerBase"]: + def controller_class(self) -> t.Type["ControllerBase"]: assert self._controller_class, "Controller Class is not available" return self._controller_class @property - def tags(self) -> Optional[List[str]]: + def tags(self) -> t.Optional[t.List[str]]: # `tags` is a property for grouping endpoint in Swagger API docs return self._tags @tags.setter - def tags(self, value: Union[str, List[str], None]) -> None: - tag: Optional[List[str]] = cast(Optional[List[str]], value) + def tags(self, value: t.Union[str, t.List[str], None]) -> None: + tag: t.Optional[t.List[str]] = t.cast(t.Optional[t.List[str]], value) if tag and isinstance(value, str): tag = [value] self._tags = tag @@ -476,17 +447,21 @@ def __call__(self, cls: ControllerClassType) -> ControllerClassType: self.auto_import = getattr(cls, "auto_import", self.auto_import) if not issubclass(cls, ControllerBase): # We force the cls to inherit from `ControllerBase` by creating another type. - cls = type(cls.__name__, (ControllerBase, cls), {"_api_controller": self}) # type:ignore[assignment] - else: - cls._api_controller = self + cls = type(cls.__name__, (ControllerBase, cls), {}) # type:ignore[assignment] + + if reflect.has_metadata(API_CONTROLLER_INSTANCE, cls): + raise Exception("Controller is already decorated with @api_controller") + + reflect.define_metadata(API_CONTROLLER_INSTANCE, self, cls) + reflect.define_metadata(CONTROLLER_WATERMARK, True, cls) assert isinstance(cls.throttling_classes, (list, tuple)), ( f"Controller[{cls.__name__}].throttling_class must be a list or tuple" ) - throttling_objects: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = ( - NOT_SET - ) + throttling_objects: t.Union[ + BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE + ] = NOT_SET if self.throttle is not NOT_SET: throttling_objects = self.throttle @@ -537,19 +512,30 @@ def __call__(self, cls: ControllerClassType) -> ControllerClassType: if not is_decorated_with_inject(cls.__init__): fail_silently(inject, constructor_or_class=cls) - ControllerRegistry().add_controller(cls) + controller_registry.add_controller(cls) return cls @property - def path_operations(self) -> Dict[str, PathView]: + def path_operations(self) -> t.Dict[str, PathView]: return self._path_operations def set_api_instance(self, api: "NinjaExtraAPI") -> None: - self.controller_class.api = api + reflect.define_metadata( + NINJA_EXTRA_API_CONTROLLER_REGISTERED_KEY, {id(api)}, self + ) for path_view in self.path_operations.values(): - path_view.set_api_instance(api, cast(Router, self)) + path_view.set_api_instance(api, t.cast(Router, self)) + + def is_registered(self, api: "NinjaExtraAPI") -> bool: + keys = ( + reflect.get_metadata(NINJA_EXTRA_API_CONTROLLER_REGISTERED_KEY, self) + or set() + ) + if id(api) in keys: + return True + return False - def build_routers(self) -> List[Tuple[str, "APIController"]]: + def build_routers(self) -> t.List[t.Tuple[str, "APIController"]]: prefix = self.prefix if self._prefix_has_route_param: prefix = "" @@ -560,8 +546,8 @@ def add_controller_route_function(self, route_function: RouteFunction) -> None: get_function_name(route_function.route.view_func) ] = route_function - def urls_paths(self, prefix: str) -> Iterator[Union[URLPattern, URLResolver]]: - namespaced_patterns: List[URLPattern] = [] + def urls_paths(self, prefix: str) -> t.Iterator[t.Union[URLPattern, URLResolver]]: + namespaced_patterns: t.List[URLPattern] = [] for path, path_view in self.path_operations.items(): path = path.replace("{", "<").replace("}", ">") @@ -571,7 +557,7 @@ def urls_paths(self, prefix: str) -> Iterator[Union[URLPattern, URLResolver]]: route = route.lstrip("/") for op in path_view.operations: - op = cast(Operation, op) + op = t.cast(Operation, op) view = path_view.get_view() pattern = django_path(route, view, name=op.url_name) @@ -630,24 +616,24 @@ def _add_operation_from_route_function(self, route_function: RouteFunction) -> N def add_api_operation( self, path: str, - methods: List[str], - view_func: Callable, + methods: t.List[str], + view_func: t.Callable, *, - auth: Any = NOT_SET, - throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, - response: Any = NOT_SET, - operation_id: Optional[str] = None, - summary: Optional[str] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None, - deprecated: Optional[bool] = None, + auth: t.Any = NOT_SET, + throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + response: t.Any = NOT_SET, + operation_id: t.Optional[str] = None, + summary: t.Optional[str] = None, + description: t.Optional[str] = None, + tags: t.Optional[t.List[str]] = None, + deprecated: t.Optional[bool] = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - url_name: Optional[str] = None, + url_name: t.Optional[str] = None, include_in_schema: bool = True, - openapi_extra: Optional[Dict[str, Any]] = None, + openapi_extra: t.Optional[t.Dict[str, t.Any]] = None, ) -> Operation: auth = self.auth if auth == NOT_SET else auth @@ -681,39 +667,41 @@ def add_api_operation( return operation -@overload +@t.overload def api_controller( - prefix_or_class: Union[ControllerClassType, Type[T]], -) -> Union[Type[ControllerBase], Type[T]]: # pragma: no cover + prefix_or_class: t.Union[ControllerClassType, t.Type[T]], +) -> t.Union[t.Type[ControllerBase], t.Type[T]]: # pragma: no cover ... -@overload +@t.overload def api_controller( prefix_or_class: str = "", - auth: Any = NOT_SET, - throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, - tags: Union[Optional[List[str]], str] = None, - permissions: Optional[List[BasePermissionType]] = None, + auth: t.Any = NOT_SET, + throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + tags: t.Union[t.Optional[t.List[str]], str] = None, + permissions: t.Optional[t.List[BasePermissionType]] = None, auto_import: bool = True, - urls_namespace: Optional[str] = None, + urls_namespace: t.Optional[str] = None, use_unique_op_id: bool = True, -) -> Callable[ - [Union[Type, Type[T]]], Union[Type[ControllerBase], Type[T]] +) -> t.Callable[ + [t.Union[t.Type, t.Type[T]]], t.Union[t.Type[ControllerBase], t.Type[T]] ]: # pragma: no cover ... def api_controller( - prefix_or_class: Union[str, ControllerClassType] = "", - auth: Any = NOT_SET, - throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, - tags: Union[Optional[List[str]], str] = None, - permissions: Optional[List[BasePermissionType]] = None, + prefix_or_class: t.Union[str, ControllerClassType] = "", + auth: t.Any = NOT_SET, + throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + tags: t.Union[t.Optional[t.List[str]], str] = None, + permissions: t.Optional[t.List[BasePermissionType]] = None, auto_import: bool = True, - urls_namespace: Optional[str] = None, + urls_namespace: t.Optional[str] = None, use_unique_op_id: bool = True, -) -> Union[ControllerClassType, Callable[[ControllerClassType], ControllerClassType]]: +) -> t.Union[ + ControllerClassType, t.Callable[[ControllerClassType], ControllerClassType] +]: if isinstance(prefix_or_class, type): return APIController( prefix="", diff --git a/ninja_extra/controllers/model/builder.py b/ninja_extra/controllers/model/builder.py index 3f5dff22..b63efd81 100644 --- a/ninja_extra/controllers/model/builder.py +++ b/ninja_extra/controllers/model/builder.py @@ -2,7 +2,12 @@ from ninja.orm.fields import TYPES -from ninja_extra.constants import ROUTE_FUNCTION +from ninja_extra.constants import ROUTE_OBJECT +from ninja_extra.controllers.route.route_functions import ( + AsyncRouteFunction, + RouteFunction, +) +from ninja_extra.reflect import reflect from .endpoints import ( ModelAsyncEndpointFactory, @@ -13,6 +18,7 @@ if t.TYPE_CHECKING: from ninja_extra.controllers.base import APIController, ModelControllerBase + from ninja_extra.controllers.route import Route class ModelControllerBuilder: @@ -51,8 +57,18 @@ def __init__( ) def _add_to_controller(self, func: t.Callable) -> None: - route_function = getattr(func, ROUTE_FUNCTION) - route_function.api_controller = self._api_controller_instance + route_obj: "Route" = t.cast( + "Route", reflect.get_metadata_or_raise_exception(ROUTE_OBJECT, func) + ) + route_function: t.Union[RouteFunction, AsyncRouteFunction] + if route_obj.is_async: + route_function = AsyncRouteFunction( + route_obj, api_controller=self._api_controller_instance + ) + else: + route_function = RouteFunction( + route_obj, api_controller=self._api_controller_instance + ) self._api_controller_instance.add_controller_route_function(route_function) def _register_create_endpoint(self) -> None: diff --git a/ninja_extra/controllers/model/endpoints.py b/ninja_extra/controllers/model/endpoints.py index 96dacb2c..ce81c124 100644 --- a/ninja_extra/controllers/model/endpoints.py +++ b/ninja_extra/controllers/model/endpoints.py @@ -17,6 +17,7 @@ PathResolverOperation, ) from ninja_extra.controllers.route import route +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.exceptions import NotFound from ninja_extra.pagination import ( PageNumberPaginationExtra, @@ -123,7 +124,8 @@ def create( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) @@ -191,7 +193,8 @@ def update( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) update_item = _path_resolver( path, @@ -261,7 +264,8 @@ def patch( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) patch_item = _path_resolver( path, @@ -329,7 +333,8 @@ def find_one( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) get_item = _path_resolver( path, @@ -401,7 +406,8 @@ def list( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) list_items = _path_resolver( path, @@ -493,7 +499,8 @@ def delete( """ def _setup(model_controller_type: t.Type["ModelControllerBase"]) -> t.Callable: - api_controller = model_controller_type.get_api_controller() + api_controller = get_api_controller(model_controller_type) + assert api_controller is not None, "API controller is required" working_path = cls._clean_path(path) delete_item = _path_resolver( path, diff --git a/ninja_extra/controllers/registry.py b/ninja_extra/controllers/registry.py index b1316137..f22bd327 100644 --- a/ninja_extra/controllers/registry.py +++ b/ninja_extra/controllers/registry.py @@ -1,37 +1,56 @@ -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Dict, Optional, Type, cast + +from ninja_extra.constants import API_CONTROLLER_INSTANCE +from ninja_extra.reflect import reflect if TYPE_CHECKING: # pragma: no cover - from ninja_extra.controllers.base import ControllerBase # pragma: no cover + from ninja_extra.controllers.base import ( + APIController, + ControllerBase, + ) # pragma: no cover -class ControllerBorg: - _shared_state_: Dict[str, Dict[str, Type["ControllerBase"]]] = {"controllers": {}} +class ControllerRegistry: + KEY = "CONTROLLER_REGISTRY" def __init__(self) -> None: - self.__dict__ = self._shared_state_ + reflect.define_metadata(self.KEY, {}, self.__class__) def add_controller(self, controller: Type["ControllerBase"]) -> None: - if ( - hasattr(controller, "get_api_controller") - and controller.get_api_controller().auto_import - ): - self._shared_state_["controllers"].update({str(controller): controller}) + api_controller_raw = reflect.get_metadata(API_CONTROLLER_INSTANCE, controller) + if not api_controller_raw: + return + api_controller: "APIController" = cast("APIController", api_controller_raw) + if not api_controller.auto_import: + return + reflect.define_metadata(self.KEY, {str(controller): controller}, self.__class__) def remove_controller( self, controller: Type["ControllerBase"] ) -> Optional[Type["ControllerBase"]]: - if str(controller) in self._shared_state_["controllers"]: - return self._shared_state_["controllers"].pop(str(controller)) + controllers = reflect.get_metadata(self.KEY, self.__class__) + + if controllers and str(controller) in controllers: + removed_controller: Type["ControllerBase"] = cast( + Type["ControllerBase"], controllers[str(controller)] + ) + del controllers[str(controller)] + + reflect.delete_metadata(self.KEY, self.__class__) + reflect.define_metadata(self.KEY, controllers, self.__class__) + + return removed_controller return None def clear_controller(self) -> None: - self._shared_state_["controllers"] = {} + reflect.delete_metadata(self.KEY, self.__class__) + reflect.define_metadata(self.KEY, {}, self.__class__) - @classmethod - def get_controllers(cls) -> Dict[str, Type["ControllerBase"]]: - return cls._shared_state_.get("controllers", {}) + def get_controllers(self) -> Dict[str, Type["ControllerBase"]]: + controllers = reflect.get_metadata(self.KEY, self.__class__) + return ( + cast(Dict[str, Type["ControllerBase"]], controllers) if controllers else {} + ) -class ControllerRegistry(ControllerBorg): - def __init__(self) -> None: - ControllerBorg.__init__(self) +controller_registry = ControllerRegistry() diff --git a/ninja_extra/controllers/route/__init__.py b/ninja_extra/controllers/route/__init__.py index a582176e..fbe8aa70 100644 --- a/ninja_extra/controllers/route/__init__.py +++ b/ninja_extra/controllers/route/__init__.py @@ -8,17 +8,17 @@ from ninja_extra.constants import ( DELETE, GET, + OPERATION_ENDPOINT_KEY, PATCH, POST, PUT, - ROUTE_FUNCTION, ROUTE_METHODS, + ROUTE_OBJECT, ) from ninja_extra.permissions import BasePermission +from ninja_extra.reflect import reflect from ninja_extra.schemas import RouteParameter -from .route_functions import AsyncRouteFunction, RouteFunction - class RouteInvalidParameterException(Exception): pass @@ -170,11 +170,9 @@ def _create_route_function( openapi_extra=openapi_extra, throttle=throttle, ) - route_function_class = RouteFunction - if route_obj.is_async: - route_function_class = AsyncRouteFunction - setattr(view_func, ROUTE_FUNCTION, route_function_class(route=route_obj)) + reflect.define_metadata(ROUTE_OBJECT, route_obj, view_func) + setattr(view_func, OPERATION_ENDPOINT_KEY, True) return view_func @classmethod diff --git a/ninja_extra/controllers/route/route_functions.py b/ninja_extra/controllers/route/route_functions.py index b115acf1..74fc12b6 100644 --- a/ninja_extra/controllers/route/route_functions.py +++ b/ninja_extra/controllers/route/route_functions.py @@ -5,12 +5,15 @@ from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, cast from django.http import HttpRequest, HttpResponse +from typing_extensions import deprecated +from ninja_extra.constants import ROUTE_OBJECT_FUNCTION from ninja_extra.context import ( RouteContext, get_route_execution_context, ) from ninja_extra.dependency_resolver import get_injector, service_resolver +from ninja_extra.reflect import reflect if TYPE_CHECKING: # pragma: no cover from ninja_extra.controllers.base import APIController, ControllerBase @@ -27,15 +30,15 @@ def __init__( class RouteFunction(object): - def __init__( - self, route: "Route", api_controller: Optional["APIController"] = None - ): + def __init__(self, route: "Route", api_controller: "APIController"): self.route = route self.operation: Optional["Operation"] = None self.has_request_param = False - self.api_controller = api_controller + self._api_controller = api_controller self.as_view = wraps(route.view_func)(self.get_view_function()) self._resolve_api_func_signature_(self.as_view) + # Store route function metadata + reflect.define_metadata(ROUTE_OBJECT_FUNCTION, self, route) def __call__( self, @@ -66,7 +69,14 @@ def _get_required_api_func_signature(self) -> Tuple: self.has_request_param = True return sig_inspect, sig_parameter - def get_api_controller(self) -> "APIController": + @property + def api_controller(self) -> "APIController": + return self._api_controller + + @deprecated( + "get_api_controller() is deprecated, use api_controller property instead." + ) + def get_api_controller(self) -> "APIController": # pragma: no cover assert self.api_controller, "APIController is required" return self.api_controller @@ -125,12 +135,11 @@ def _get_controller_instance(self) -> "ControllerBase": from ninja_extra.controllers.base import ModelControllerBase injector = get_injector() - _api_controller = self.get_api_controller() additional_kwargs = {} - if issubclass(_api_controller.controller_class, ModelControllerBase): + if issubclass(self.api_controller.controller_class, ModelControllerBase): controller_klass = cast( - ModelControllerBase, _api_controller.controller_class + ModelControllerBase, self.api_controller.controller_class ) # make sure model_config is not None if controller_klass.model_config is not None: @@ -141,7 +150,7 @@ def _get_controller_instance(self) -> "ControllerBase": additional_kwargs.update({"service": service}) controller_instance = injector.create_object( - _api_controller.controller_class, additional_kwargs=additional_kwargs + self.api_controller.controller_class, additional_kwargs=additional_kwargs ) return controller_instance @@ -155,11 +164,10 @@ def get_route_execution_context( DeprecationWarning, stacklevel=2, ) - _api_controller = self.get_api_controller() init_kwargs = { "permission_classes": self.route.permissions - or _api_controller.permission_classes, + or self.api_controller.permission_classes, "request": request, "kwargs": kwargs, "args": args, @@ -241,11 +249,10 @@ async def __call__( *args: Any, **kwargs: Any, ) -> Any: - _api_controller = self.get_api_controller() context = get_route_execution_context( request, temporal_response, - self.route.permissions or _api_controller.permission_classes, # type:ignore[arg-type] + self.route.permissions or self.api_controller.permission_classes, # type:ignore[arg-type] *args, **kwargs, ) diff --git a/ninja_extra/controllers/utils.py b/ninja_extra/controllers/utils.py new file mode 100644 index 00000000..8fa3e4f1 --- /dev/null +++ b/ninja_extra/controllers/utils.py @@ -0,0 +1,15 @@ +import typing as t + +from ninja_extra.constants import API_CONTROLLER_INSTANCE +from ninja_extra.reflect import reflect + +if t.TYPE_CHECKING: # pragma: no cover + from .base import APIController, ControllerBase + + +def get_api_controller( + controller_class: t.Type["ControllerBase"], +) -> t.Optional["APIController"]: + return t.cast( + "APIController", reflect.get_metadata(API_CONTROLLER_INSTANCE, controller_class) + ) diff --git a/ninja_extra/helper.py b/ninja_extra/helper.py index f4b4de90..affd0d9a 100644 --- a/ninja_extra/helper.py +++ b/ninja_extra/helper.py @@ -1,10 +1,8 @@ import inspect import typing as t -from ninja_extra.constants import ROUTE_FUNCTION - if t.TYPE_CHECKING: # pragma: no cover - from ninja_extra.controllers import RouteFunction + pass def get_function_name(func_class: t.Any) -> str: @@ -13,8 +11,9 @@ def get_function_name(func_class: t.Any) -> str: return str(func_class.__class__.__name__) -@t.no_type_check -def get_route_function(func: t.Callable) -> t.Optional["RouteFunction"]: - if hasattr(func, ROUTE_FUNCTION): - return func.__dict__[ROUTE_FUNCTION] - return None # pragma: no cover +# TODO: Add deprecation warning +# @t.no_type_check +# def get_route_function(func: t.Callable) -> t.Optional["RouteFunction"]: +# if hasattr(func, ROUTE_FUNCTION): +# return func.__dict__[ROUTE_FUNCTION] +# return None # pragma: no cover diff --git a/ninja_extra/main.py b/ninja_extra/main.py index 5d666d5b..7c7e9d38 100644 --- a/ninja_extra/main.py +++ b/ninja_extra/main.py @@ -1,17 +1,6 @@ +import typing as t import warnings from importlib import import_module -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) from django.core.exceptions import ImproperlyConfigured from django.http import HttpRequest, HttpResponse @@ -27,13 +16,16 @@ from ninja_extra import exceptions, router from ninja_extra.compatible import NOT_SET_TYPE +from ninja_extra.constants import API_CONTROLLER_INSTANCE from ninja_extra.controllers.base import APIController, ControllerBase -from ninja_extra.controllers.registry import ControllerRegistry +from ninja_extra.controllers.registry import controller_registry __all__ = [ "NinjaExtraAPI", ] +from ninja_extra.reflect import reflect + class NinjaExtraAPI(NinjaAPI): def __init__( @@ -42,20 +34,22 @@ def __init__( title: str = "NinjaExtraAPI", version: str = "1.0.0", description: str = "", - openapi_url: Optional[str] = "/openapi.json", + openapi_url: t.Optional[str] = "/openapi.json", docs: DocsBase = Swagger(), - docs_url: Optional[str] = "/docs", - docs_decorator: Optional[Callable[[TCallable], TCallable]] = None, - servers: Optional[List[DictStrAny]] = None, - urls_namespace: Optional[str] = None, + docs_url: t.Optional[str] = "/docs", + docs_decorator: t.Optional[t.Callable[[TCallable], TCallable]] = None, + servers: t.Optional[t.List[DictStrAny]] = None, + urls_namespace: t.Optional[str] = None, csrf: bool = False, - auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET, - throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, - renderer: Optional[BaseRenderer] = None, - parser: Optional[Parser] = None, - openapi_extra: Optional[Dict[str, Any]] = None, + auth: t.Optional[ + t.Union[t.Sequence[t.Callable], t.Callable, NOT_SET_TYPE] + ] = NOT_SET, + throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + renderer: t.Optional[BaseRenderer] = None, + parser: t.Optional[Parser] = None, + openapi_extra: t.Optional[t.Dict[str, t.Any]] = None, app_name: str = "ninja", - **kwargs: Any, + **kwargs: t.Any, ) -> None: # add a warning if there csrf is True if csrf: @@ -88,14 +82,14 @@ def __init__( ) self.app_name = app_name self.exception_handler(exceptions.APIException)(self.api_exception_handler) - self._routers: List[Tuple[str, router.Router]] = [] # type: ignore + self._routers: t.List[t.Tuple[str, router.Router]] = [] # type: ignore self.default_router = router.Router() self.add_router("", self.default_router) def api_exception_handler( self, request: HttpRequest, exc: exceptions.APIException ) -> HttpResponse: - headers: Dict = {} + headers: t.Dict = {} if isinstance(exc, exceptions.Throttled): headers["Retry-After"] = "%d" % float(exc.wait or 0.0) @@ -111,7 +105,7 @@ def api_exception_handler( return response @property - def urls(self) -> Tuple[List[Union[URLResolver, URLPattern]], str, str]: + def urls(self) -> t.Tuple[t.List[t.Union[URLResolver, URLPattern]], str, str]: _url_tuple = super().urls return ( _url_tuple[0], @@ -120,23 +114,27 @@ def urls(self) -> Tuple[List[Union[URLResolver, URLPattern]], str, str]: ) def register_controllers( - self, *controllers: Union[Type[ControllerBase], Type, str] + self, *controllers: t.Union[t.Type[ControllerBase], t.Type, str] ) -> None: for controller in controllers: if isinstance(controller, str): - controller = cast( - Union[Type[ControllerBase], Type], import_string(controller) + controller = t.cast( + t.Union[t.Type[ControllerBase], t.Type], import_string(controller) ) if not issubclass(controller, ControllerBase): raise ImproperlyConfigured( f"{controller.__class__.__name__} class is not a controller" ) - api_controller: APIController = controller.get_api_controller() - if not api_controller.registered: + api_controller = t.cast( + APIController, + reflect.get_metadata_or_raise_exception( + API_CONTROLLER_INSTANCE, controller + ), + ) + if not api_controller.is_registered(self): self._routers.extend(api_controller.build_routers()) # type: ignore api_controller.set_api_instance(self) - api_controller.registered = True def auto_discover_controllers(self) -> None: from django.apps import apps @@ -154,7 +152,7 @@ def auto_discover_controllers(self) -> None: mod_path = "%s.%s" % (app_module.name, module) import_module(mod_path) self.register_controllers( - *ControllerRegistry.get_controllers().values() + *controller_registry.get_controllers().values() ) except ImportError as ex: # pragma: no cover raise ex diff --git a/ninja_extra/ordering/operation.py b/ninja_extra/ordering/operation.py index 8ee93bc3..f9181dbb 100644 --- a/ninja_extra/ordering/operation.py +++ b/ninja_extra/ordering/operation.py @@ -5,7 +5,9 @@ from asgiref.sync import sync_to_async from django.http import HttpRequest +from ninja_extra.constants import ORDERATOR_OBJECT from ninja_extra.interfaces.ordering import OrderingBase +from ninja_extra.reflect import reflect from ninja_extra.shortcuts import add_ninja_contribute_args logger = logging.getLogger() @@ -36,7 +38,7 @@ def __init__( self.orderator.InputSource, ), ) - orderator_view.orderator_operation = self # type:ignore[attr-defined] + reflect.define_metadata(ORDERATOR_OBJECT, self, orderator_view) @property def view_func_has_kwargs(self) -> bool: # pragma: no cover diff --git a/ninja_extra/pagination/operations.py b/ninja_extra/pagination/operations.py index 0df40ce0..0ce80c50 100644 --- a/ninja_extra/pagination/operations.py +++ b/ninja_extra/pagination/operations.py @@ -14,7 +14,9 @@ from ninja import FilterSchema, Query from ninja.pagination import AsyncPaginationBase, PaginationBase +from ninja_extra.constants import PAGINATOR_OBJECT from ninja_extra.context import RouteContext +from ninja_extra.reflect import reflect from ninja_extra.shortcuts import add_ninja_contribute_args if TYPE_CHECKING: # pragma: no cover @@ -58,7 +60,7 @@ def __init__( Query(...), ), ) - paginator_view.paginator_operation = self # type:ignore[attr-defined] + reflect.define_metadata(PAGINATOR_OBJECT, self, paginator_view) @property def view_func_has_kwargs(self) -> bool: # pragma: no cover @@ -147,12 +149,15 @@ async def as_view( params = dict(kw) params["request"] = request is_supported_async_orm = django.VERSION >= (4, 2) - paginate_queryset = ( - self.paginator.apaginate_queryset - if isinstance(self.paginator, AsyncPaginationBase) + if ( + isinstance(self.paginator, AsyncPaginationBase) and is_supported_async_orm - else cast(Callable, sync_to_async(self.paginator.paginate_queryset)) - ) + ): + paginate_queryset = self.paginator.apaginate_queryset + else: + paginate_queryset = cast( + Callable, sync_to_async(self.paginator.paginate_queryset) + ) return await paginate_queryset(items, **params) return as_view diff --git a/ninja_extra/reflect/__init__.py b/ninja_extra/reflect/__init__.py new file mode 100644 index 00000000..4060e2c1 --- /dev/null +++ b/ninja_extra/reflect/__init__.py @@ -0,0 +1,9 @@ +""" +Ellar Reflect: A module for managing metadata on callables and types. +Provides tools to attach, retrieve, and manage metadata for dependency injection and other framework features. +""" + +from ._reflect import reflect +from .utils import ensure_target, fail_silently, transfer_metadata + +__all__ = ["reflect", "ensure_target", "transfer_metadata", "fail_silently"] diff --git a/ninja_extra/reflect/_reflect.py b/ninja_extra/reflect/_reflect.py new file mode 100644 index 00000000..4627f5dd --- /dev/null +++ b/ninja_extra/reflect/_reflect.py @@ -0,0 +1,390 @@ +import logging +import typing as t +import weakref +from contextlib import asynccontextmanager, contextmanager +from weakref import WeakKeyDictionary, WeakValueDictionary + +from .utils import ensure_target, fail_silently, get_original_target + +logger = logging.getLogger("ellar") + + +def _try_hash(item: t.Any) -> bool: + """ + Try to hash an item. + + :param item: The item to try and hash. + :return: True if the item is hashable, False otherwise. + """ + try: + hash(item), weakref.ref(item) + return True + except TypeError: + return False + + +class _Hashable: + """ + A wrapper class to make unhashable items hashable by using their ID and string representation. + """ + + def __init__(self, item_id: int, item_repr: str) -> None: + self.item_id = item_id + self.item_repr = item_repr + # self._item_repr = item_repr + + def __hash__(self) -> int: + # Combine the hash values of the attributes + attrs = self.item_id, self.item_repr + return hash(attrs) + + def __eq__(self, other: t.Any) -> bool: + # Check if another object is equal based on attributes + if isinstance(other, _Hashable): + return self.item_id == other.item_id + return False + + def __repr__(self) -> str: + return self.item_repr + + @classmethod + def force_hash(cls, item: t.Any) -> t.Union[t.Any, "_Hashable"]: + """ + Force an item to be hashable. If it's already hashable, return it. + If not, return a _Hashable wrapper or retrieve an existing one. + + :param item: The item to hash. + :return: The item or its _Hashable wrapper. + """ + if not _try_hash(item): + hashable = fail_silently( + lambda: reflect._un_hashable[hash((id(item), repr(item)))] + ) + if hashable: + return hashable + + new_target = cls(item_id=id(item), item_repr=repr(item)) + return reflect.add_un_hashable_type(new_target) + return item + + +def _get_actual_target( + target: t.Any, +) -> t.Any: + """ + Get the actual target for metadata operations. + Resolves proxies and ensures the target is hashable. + + :param target: The target to resolve. + :return: The resolved, hashable target. + """ + target = get_original_target(target) + return _Hashable.force_hash(ensure_target(target)) + + +class _Reflect: + """ + Metadata manager class for storage and retrieval of metadata associated with types and callables. + Use `reflect` instance for all operations. + """ + + __slots__ = ("_meta_data",) + + _un_hashable: t.Dict[int, _Hashable] = {} + _data_type_update_callbacks: t.MutableMapping[t.Type, t.Callable] = ( + WeakValueDictionary() + ) + + def __init__(self) -> None: + self._meta_data: t.MutableMapping[t.Union[t.Type, t.Callable], t.Dict] = ( + WeakKeyDictionary() + ) + + def add_type_update_callback(self, type_: t.Type, func: t.Callable) -> None: + """ + Register a callback to handle updates for a specific metadata type. + + :param type_: The type of the metadata value. + :param func: The call back function to handle the update. + """ + self._data_type_update_callbacks[type_] = func + + def add_un_hashable_type(self, value: _Hashable) -> _Hashable: + """ + Store an unhashable item wrapper. + + :param value: The _Hashable wrapper. + :return: The stored _Hashable wrapper. + """ + self._un_hashable[hash(value)] = value + return value + + def _default_update_callback( + self, existing_value: t.Any, new_value: t.Any + ) -> t.Any: + return new_value + + def define_metadata( + self, + metadata_key: str, + metadata_value: t.Any, + target: t.Any, + ) -> t.Any: + """ + Define metadata for a target. + + :param metadata_key: The key for the metadata. + :param metadata_value: The value of the metadata. + :param target: The target object to associate the metadata with. + :return: The value returned by type update callback or new value. + """ + if target is None: + raise Exception("`target` is not a valid type") + # if ( + # not isinstance(target, type) + # and not callable(target) + # and not ismethod(target) + # or target is None + # ): + # raise Exception("`target` is not a valid type") + + target_metadata = self._get_or_create_metadata(target, create=True) + if target_metadata is not None: + existing = target_metadata.get(metadata_key) + if existing is not None: + update_callback: t.Callable[[t.Any, t.Any], t.Any] = ( + self._data_type_update_callbacks.get( + type(existing), self._default_update_callback + ) + ) + metadata_value = update_callback(existing, metadata_value) + target_metadata[metadata_key] = metadata_value + + def metadata(self, metadata_key: str, metadata_value: t.Any) -> t.Any: + """ + Decorator to define metadata on a class or function. + + :param metadata_key: The key for the metadata. + :param metadata_value: The value of the metadata. + :return: A decorator function. + """ + + def _wrapper(target: t.Any) -> t.Any: + self.define_metadata(metadata_key, metadata_value, target) + return target + + return _wrapper + + def has_metadata(self, metadata_key: str, target: t.Any) -> bool: + """ + Check if metadata key exists for a target. + + :param metadata_key: The key to check. + :param target: The target object. + :return: True if metadata key exists, False otherwise. + """ + _target_actual = _get_actual_target(target) + target_metadata = self._meta_data.get(_target_actual) or {} + + return metadata_key in target_metadata + + def get_metadata(self, metadata_key: str, target: t.Any) -> t.Optional[t.Any]: + """ + Retrieve metadata value for a target. + + :param metadata_key: The key to retrieve. + :param target: The target object. + :return: The metadata value or None if not found. + """ + _target_actual = _get_actual_target(target) + target_metadata = self._meta_data.get(_target_actual) or {} + + value = target_metadata.get(metadata_key) + if isinstance(value, (list, set, tuple, dict)): + # return immutable value + return type(value)(value) + return value + + def get_metadata_search_safe(self, metadata_key: str, target: t.Any) -> t.Any: + """ + Retrieve metadata value safely. Raises KeyError if key is not found in the target's metadata. + This behaves like `dict[key]`. + + :param metadata_key: The key to retrieve. + :param target: The target object. + :return: The metadata value. + """ + _target_actual = _get_actual_target(target) + meta = self._meta_data[_target_actual] + + value = meta[metadata_key] + if isinstance(value, (list, set, tuple, dict)): + # return immutable value + return type(value)(value) + return value + + def get_metadata_or_raise_exception( + self, metadata_key: str, target: t.Any + ) -> t.Any: + """ + Retrieve metadata or raise an Exception if not found. + + :param metadata_key: The key to retrieve. + :param target: The target object. + :return: The metadata value. + :raises Exception: If metadata key is not found. + """ + value = self.get_metadata(metadata_key=metadata_key, target=target) + if value is not None: + return value + raise Exception("MetaData Key not Found") + + def get_metadata_keys(self, target: t.Any) -> t.KeysView[t.Any]: + """ + Get all metadata keys for a target. + + :param target: The target object. + :return: A view of the metadata keys. + """ + _target_actual = _get_actual_target(target) + target_metadata = self._meta_data.get(_target_actual) or {} + + return target_metadata.keys() + + def get_all_metadata(self, target: t.Any) -> t.Dict: + """ + Get all metadata for a target as a dictionary. + + :param target: The target object. + :return: A dictionary containing all metadata. + """ + _target_actual = _get_actual_target(target) + target_metadata = self._meta_data.get(_target_actual) or {} + return type(target_metadata)(target_metadata) + + def delete_all_metadata(self, target: t.Any) -> None: + """ + Delete all metadata for a target. + + :param target: The target object. + """ + _target = _get_actual_target(target) + if _target in self._meta_data: + self._meta_data.pop(_target) + + def delete_metadata(self, metadata_key: str, target: t.Any) -> t.Any: + """ + Delete a specific metadata key for a target. + + :param metadata_key: The key to delete. + :param target: The target object. + :return: The deleted value or None. + """ + _target_actual = _get_actual_target(target) + target_metadata = self._meta_data.get(_target_actual) or {} + + if target_metadata and metadata_key in target_metadata: + value = target_metadata.pop(metadata_key) + if isinstance(value, (list, set, tuple, dict)): + # return immutable value + return type(value)(value) + return value + + def _get_or_create_metadata( + self, target: t.Any, create: bool = False + ) -> t.Optional[t.Dict]: + _target = _get_actual_target(target) + if _target in self._meta_data: + return self._meta_data[_target] + + if create: + self._meta_data[_target] = {} + return self._meta_data[_target] + return None + + def _clone_meta_data( + self, + ) -> t.MutableMapping[t.Union[t.Type, t.Callable], t.Dict]: + _meta_data: t.MutableMapping[t.Union[t.Type, t.Callable], t.Dict] = ( + WeakKeyDictionary() + ) + for k, v in self._meta_data.items(): + _meta_data[k] = dict(v) + return _meta_data + + @asynccontextmanager + async def async_context(self) -> t.AsyncGenerator[None, None]: + """ + Async context manager that isolates metadata changes within the context. + Metadata changes made inside the context are discarded after exit. + """ + cached_meta_data = self._clone_meta_data() + yield + reflect._meta_data.clear() + reflect._meta_data = WeakKeyDictionary(dict=cached_meta_data) + + @contextmanager + def context(self) -> t.Generator: + """ + Context manager that isolates metadata changes within the context. + Metadata changes made inside the context are discarded after exit. + """ + cached_meta_data = self._clone_meta_data() + yield + reflect._meta_data.clear() + reflect._meta_data = WeakKeyDictionary(dict=cached_meta_data) + + +def _list_update(existing_value: t.Any, new_value: t.Any) -> t.Any: + """ + Update callback for list/tuple types. Concatenates the new value to the existing value. + + :param existing_value: The existing list or tuple. + :param new_value: The new list or tuple. + :return: The concatenated list or tuple. + """ + if isinstance(existing_value, (list, tuple)) and isinstance( + new_value, (list, tuple) + ): + return existing_value + type(existing_value)(new_value) # type: ignore + return new_value + + +def _set_update(existing_value: t.Any, new_value: t.Any) -> t.Any: + """ + Update callback for set types. Unions the new value with the existing value. + + :param existing_value: The existing set. + :param new_value: The new set. + :return: The union of the sets. + """ + if isinstance(existing_value, set) and isinstance(new_value, set): + existing_combined = list(existing_value) + list(new_value) + return type(existing_value)(existing_combined) + return new_value + + +def _dict_update(existing_value: t.Any, new_value: t.Any) -> t.Any: + """ + Update callback for dict types. Updates the existing dictionary with new values. + + :param existing_value: The existing dictionary. + :param new_value: The new dictionary. + :return: The updated dictionary. + """ + if isinstance( + existing_value, (dict, WeakKeyDictionary, WeakValueDictionary) + ) and isinstance(new_value, (dict, WeakKeyDictionary, WeakValueDictionary)): + existing_value.update(new_value) + return type(existing_value)(existing_value) + return new_value + + +reflect = _Reflect() + +reflect.add_type_update_callback(tuple, _list_update) +reflect.add_type_update_callback(list, _list_update) +reflect.add_type_update_callback(set, _set_update) +reflect.add_type_update_callback(dict, _dict_update) +reflect.add_type_update_callback(WeakKeyDictionary, _dict_update) +reflect.add_type_update_callback(WeakValueDictionary, _dict_update) diff --git a/ninja_extra/reflect/utils.py b/ninja_extra/reflect/utils.py new file mode 100644 index 00000000..e655c2f4 --- /dev/null +++ b/ninja_extra/reflect/utils.py @@ -0,0 +1,113 @@ +import functools +import inspect +import logging +import typing as t + +logger = logging.getLogger("ellar") + + +def ensure_target(target: t.Union[t.Type, t.Callable]) -> t.Union[t.Type, t.Callable]: + """ + Ensure the target is a class or a function, unwrapping methods to their underlying functions. + + :param target: The target object (class, function, or method). + :return: The class or function. + """ + res = target + if inspect.ismethod(res): + res = res.__func__ + return res + + +def is_decorated_with_partial(func_or_class: t.Any) -> bool: + """ + Check if the object is decorated with `functools.partial`. + + :param func_or_class: The object to check. + :return: True if decorated with partial, False otherwise. + """ + return isinstance(func_or_class, functools.partial) + + +def is_decorated_with_wraps(func_or_class: t.Any) -> bool: + """ + Check if the object is decorated with `functools.wraps`. + + :param func_or_class: The object to check. + :return: True if decorated with wraps, False otherwise. + """ + return hasattr(func_or_class, "__wrapped__") + + +def get_original_target(func_or_class: t.Any) -> t.Any: + """ + Unwrap the object to find the original target, getting past partials and wraps. + + :param func_or_class: The object to unwrap. + :return: The original underlying object. + """ + while True: + if is_decorated_with_partial(func_or_class): + func_or_class = func_or_class.func + elif is_decorated_with_wraps(func_or_class): + func_or_class = func_or_class.__wrapped__ + else: + return func_or_class + + +def transfer_metadata( + old_target: t.Any, new_target: t.Any, clean_up: bool = False +) -> None: + """ + Transfer metadata from one target to another. + + :param old_target: The source target. + :param new_target: The destination target. + :param clean_up: If True, delete metadata from the old target after transfer. + """ + from ._reflect import reflect + + meta = reflect.get_all_metadata(old_target) + for k, v in meta.items(): + reflect.define_metadata(k, v, new_target) + + if clean_up: + reflect.delete_all_metadata(old_target) + + +@t.no_type_check +def fail_silently(func: t.Callable, *args: t.Any, **kwargs: t.Any) -> t.Optional[t.Any]: + """ + Execute a function and return None if an exception occurs, logging the error blindly. + + :param func: The function to execute. + :param args: Positional arguments for the function. + :param kwargs: Keyword arguments for the function. + :return: The result of the function or None if an exception occurred. + """ + try: + return func(*args, **kwargs) + except Exception as ex: # pragma: no cover + logger.debug( + f"Calling {func} with args: {args} kw: {kwargs} failed\nException: {ex}" + ) + return None + + +class AnnotationToValue(type): + keys: t.List[str] + + @t.no_type_check + def __new__(mcls, name, bases, namespace): + cls = super().__new__(mcls, name, bases, namespace) + annotations = {} + for base in reversed(bases): # pragma: no cover + annotations.update(getattr(base, "__annotations__", {})) + annotations.update(namespace.get("__annotations__", {})) + cls.keys = [] + for k, v in annotations.items(): + if type(v) is type(str): + value = str(k).lower() + setattr(cls, k, value) + cls.keys.append(value) + return cls diff --git a/ninja_extra/searching/operations.py b/ninja_extra/searching/operations.py index 0f140d27..56cf833b 100644 --- a/ninja_extra/searching/operations.py +++ b/ninja_extra/searching/operations.py @@ -3,7 +3,9 @@ from asgiref.sync import sync_to_async +from ninja_extra.constants import SEARCH_OPERATOR_OBJECT from ninja_extra.interfaces.searching import SearchingBase +from ninja_extra.reflect import reflect from ninja_extra.shortcuts import add_ninja_contribute_args if t.TYPE_CHECKING: # pragma: no cover @@ -32,7 +34,7 @@ def __init__( self.searcherator.InputSource, ), ) - searcherator_view.searcherator_operation = self # type:ignore[attr-defined] + reflect.define_metadata(SEARCH_OPERATOR_OBJECT, self, searcherator_view) @property def view_func_has_kwargs(self) -> bool: # pragma: no cover diff --git a/ninja_extra/testing/client.py b/ninja_extra/testing/client.py index 2b313e02..bcce2e2f 100644 --- a/ninja_extra/testing/client.py +++ b/ninja_extra/testing/client.py @@ -1,27 +1,50 @@ from json import dumps as json_dumps -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, cast from unittest.mock import Mock from urllib.parse import urlencode +from django.urls import Resolver404 from ninja import NinjaAPI, Router from ninja.responses import NinjaJSONEncoder from ninja.testing.client import NinjaClientBase, NinjaResponse from ninja_extra import ControllerBase, NinjaExtraAPI +from ninja_extra.constants import CONTROLLER_WATERMARK +from ninja_extra.controllers.utils import get_api_controller +from ninja_extra.reflect import reflect class NinjaExtraClientBase(NinjaClientBase): def __init__( self, router_or_app: Union[NinjaAPI, Router, Type[ControllerBase]], **kw: Any ) -> None: - if hasattr(router_or_app, "get_api_controller"): + if reflect.has_metadata(CONTROLLER_WATERMARK, cast(Any, router_or_app)): api = NinjaExtraAPI(**kw) - controller_ninja_api_controller = router_or_app.get_api_controller() + controller_type = cast(Type[ControllerBase], router_or_app) + controller_ninja_api_controller = get_api_controller(controller_type) assert controller_ninja_api_controller + controller_ninja_api_controller.set_api_instance(api) self._urls_cache = list(controller_ninja_api_controller.urls_paths("")) + router_or_app = api - super(NinjaExtraClientBase, self).__init__(router_or_app) + super(NinjaExtraClientBase, self).__init__( + cast(Union[NinjaAPI, Router], router_or_app) + ) + + def _resolve( + self, method: str, path: str, data: Dict, request_params: Any + ) -> Tuple[Callable, Mock, Dict]: + url_path = path.split("?")[0].lstrip("/") + for url in self.urls: + try: + match = url.resolve(url_path) + except Resolver404: + continue + if match: + request = self._build_request(method, path, data, request_params) + return match.func, request, match.kwargs + raise Exception(f'Cannot resolve "{path}"') def request( self, diff --git a/tests/conftest.py b/tests/conftest.py index 506abcfd..8fab4d87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,9 @@ import os +import typing as t +from uuid import uuid4 import django +import pytest def pytest_configure(config): @@ -61,3 +64,31 @@ def pytest_configure(config): ) django.setup() + + +@pytest.fixture +def reflect_context(): + from ninja_extra.reflect import reflect + + with reflect.context(): + yield reflect + + +@pytest.fixture +def random_type(): + return type(f"Random{uuid4().hex[:6]}", (), {}) + + +@pytest.fixture +def get_route_function(): + from ninja_extra.controllers.route.route_functions import RouteFunction + from ninja_extra.reflect import reflect + + def _wrap(func: t.Callable) -> RouteFunction: + route_object = reflect.get_metadata_or_raise_exception("ROUTE_OBJECT", func) + route_function = reflect.get_metadata_or_raise_exception( + "ROUTE_OBJECT_FUNCTION", route_object + ) + return route_function + + return _wrap diff --git a/tests/test_api_instance.py b/tests/test_api_instance.py index 092900d4..c35c93f0 100644 --- a/tests/test_api_instance.py +++ b/tests/test_api_instance.py @@ -5,7 +5,8 @@ from ninja.testing import TestClient from ninja_extra import NinjaExtraAPI, api_controller, http_get -from ninja_extra.controllers.registry import ControllerRegistry +from ninja_extra.controllers.registry import controller_registry +from ninja_extra.controllers.utils import get_api_controller @api_controller @@ -38,7 +39,7 @@ def test_api_instance(): def test_api_auto_discover_controller(): ninja_extra_api = NinjaExtraAPI() - assert str(SomeAPIController) in ControllerRegistry.get_controllers() + assert str(SomeAPIController) in controller_registry.get_controllers() with mock.patch.object( ninja_extra_api, "register_controllers" @@ -48,17 +49,17 @@ def test_api_auto_discover_controller(): assert ( "" - in ControllerRegistry.get_controllers() + in controller_registry.get_controllers() ) @api_controller class SomeAPI2Controller: auto_import = False - assert str(SomeAPI2Controller) not in ControllerRegistry.get_controllers() + assert str(SomeAPI2Controller) not in controller_registry.get_controllers() -def test_api_register_controller_works(): +def test_api_register_controller_works(reflect_context): @api_controller("/another") class AnotherAPIController: @http_get("/example") @@ -67,10 +68,10 @@ def example(self): ninja_extra_api = NinjaExtraAPI() assert len(ninja_extra_api._routers) == 1 - assert not AnotherAPIController.get_api_controller().registered + assert not get_api_controller(AnotherAPIController).is_registered(ninja_extra_api) ninja_extra_api.register_controllers(AnotherAPIController) - assert AnotherAPIController.get_api_controller().registered + assert get_api_controller(AnotherAPIController).is_registered(ninja_extra_api) assert len(ninja_extra_api._routers) == 2 assert "/another" in dict(ninja_extra_api._routers) diff --git a/tests/test_controller.py b/tests/test_controller.py index 2904de83..4f121759 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -12,13 +12,13 @@ http_get, testing, ) -from ninja_extra.controllers import ControllerBase, RouteContext, RouteFunction +from ninja_extra.context import RouteContext +from ninja_extra.controllers import ControllerBase, RouteFunction from ninja_extra.controllers.base import ( APIController, - MissingAPIControllerDecoratorException, get_route_functions, ) -from ninja_extra.helper import get_route_function +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.permissions.common import AllowAny from .utils import AsyncFakeAuth, FakeAuth @@ -87,7 +87,7 @@ def test_api_controller_as_decorator(self): controller_type = api_controller("prefix", tags="new_tag", auth=FakeAuth())( type("Any", (), {}) ) - api_controller_instance = controller_type.get_api_controller() + api_controller_instance = get_api_controller(controller_type) assert not api_controller_instance.has_auth_async assert not api_controller_instance._prefix_has_route_param @@ -95,29 +95,28 @@ def test_api_controller_as_decorator(self): assert api_controller_instance.tags == ["new_tag"] assert api_controller_instance.permission_classes == [AllowAny] - controller_type = api_controller()(controller_type) - api_controller_instance = controller_type.get_api_controller() + controller_type = api_controller()(type("Any2", (), {})) + api_controller_instance = get_api_controller(controller_type) assert api_controller_instance.prefix == "" - assert api_controller_instance.tags == ["any"] + assert api_controller_instance.tags == ["any2"] assert "ninja_extra.controllers.base" in SomeController.__module__ assert "tests.test_controller" in Some2Controller.__module__ - assert Some2Controller.get_api_controller() + assert get_api_controller(Some2Controller) def test_controller_get_api_controller_raise_exception(self): class BController(ControllerBase): pass - with pytest.raises(MissingAPIControllerDecoratorException): - BController.get_api_controller() + assert get_api_controller(BController) is None - def test_api_controller_prefix_with_parameter(self): + def test_api_controller_prefix_with_parameter(self, reflect_context): @api_controller("/{int:organisation_id}") class UsersController: @http_get("") def example_with_id_response(self, organisation_id: int): return {"organisation_id": organisation_id} - _api_controller: APIController = UsersController.get_api_controller() + _api_controller: APIController = get_api_controller(UsersController) assert _api_controller._prefix_has_route_param client = testing.TestClient(UsersController) @@ -128,27 +127,26 @@ def example_with_id_response(self, organisation_id: int): def test_controller_should_have_preset_properties(self): api = NinjaExtraAPI() - _api_controller = SomeController.get_api_controller() + _api_controller = get_api_controller(SomeController) assert _api_controller.tags == ["some"] assert _api_controller._path_operations == {} assert _api_controller.permission_classes == [AllowAny] - assert SomeController.api is None - assert _api_controller.registered is False + assert _api_controller.is_registered(api) is False assert ControllerBase in SomeController.__bases__ api.register_controllers(SomeController) - assert _api_controller.registered + assert _api_controller.is_registered(api) def test_controller_should_wrap_with_inject(self): assert not hasattr(SomeController.__init__, "__bindings__") assert hasattr(SomeControllerWithInject.__init__, "__bindings__") def test_controller_should_have_path_operation_list(self): - _api_controller = SomeControllerWithRoute.get_api_controller() + _api_controller = get_api_controller(SomeControllerWithRoute) assert len(_api_controller._path_operations) == 5 - route_function: RouteFunction = get_route_function( - SomeControllerWithRoute().example + route_function: RouteFunction = ( + _api_controller._controller_class_route_functions.get("example") ) path_view = _api_controller._path_operations.get(str(route_function)) assert path_view, "route doesn't exist in controller" @@ -159,15 +157,16 @@ def test_controller_should_have_path_operation_list(self): assert operation.operation_id == route_function.route.route_params.operation_id def test_controller_should_append_unique_op_id_to_operation_id(self): - _api_controller = SomeControllerWithSingleRoute.get_api_controller() + _api_controller = get_api_controller(SomeControllerWithSingleRoute) controller_name = ( str(_api_controller.controller_class.__name__) .lower() .replace("controller", "") ) - route_view_func_name: RouteFunction = get_route_function( - SomeControllerWithRoute().example - ).route.view_func.__name__ + route_function: RouteFunction = ( + _api_controller._controller_class_route_functions.get("example") + ) + route_view_func_name: RouteFunction = route_function.route.view_func.__name__ operation_id = ( _api_controller._path_operations.get("/example").operations[0].operation_id @@ -179,15 +178,16 @@ def test_controller_should_append_unique_op_id_to_operation_id(self): assert len(op_id_postfix) == 8 def test_controller_should_not_add_unique_suffix_following_params(self): - _api_controller = SomeControllerWithoutUniqueSuffix.get_api_controller() + _api_controller = get_api_controller(SomeControllerWithoutUniqueSuffix) controller_name = ( str(_api_controller.controller_class.__name__) .lower() .replace("controller", "") ) - route_view_func_name: RouteFunction = get_route_function( - SomeControllerWithRoute().example - ).route.view_func.__name__ + route_function: RouteFunction = ( + _api_controller._controller_class_route_functions.get("example") + ) + route_view_func_name: RouteFunction = route_function.route.view_func.__name__ operation_id = ( _api_controller._path_operations.get("/example").operations[0].operation_id @@ -196,19 +196,21 @@ def test_controller_should_not_add_unique_suffix_following_params(self): assert operation_id == f"{controller_name}_{route_view_func_name}" def test_get_route_function_should_return_instance_route_definitions(self): - for route_definition in get_route_functions(SomeControllerWithRoute): + for route_definition in get_route_functions(SomeControllerWithRoute, Mock()): assert isinstance(route_definition, RouteFunction) - def test_compute_api_route_function_works(self): + def test_compute_api_route_function_works(self, reflect_context): @api_controller() class AnyClassTypeWithRoute: @http_get("/example") def example(self): pass - api_controller_instance = AnyClassTypeWithRoute.get_api_controller() + api_controller_instance = get_api_controller(AnyClassTypeWithRoute) assert len(api_controller_instance.path_operations) == 1 - route_function = get_route_function(AnyClassTypeWithRoute().example) + route_function: RouteFunction = ( + api_controller_instance._controller_class_route_functions.get("example") + ) path_view = api_controller_instance.path_operations.get(str(route_function)) assert path_view @@ -322,12 +324,11 @@ async def test_controller_base_aget_object_or_none_works(self): def test_controller_registration_through_string(): - assert DisableAutoImportController.get_api_controller().registered is False - api = NinjaExtraAPI() + assert get_api_controller(DisableAutoImportController).is_registered(api) is False api.register_controllers("tests.test_controller.DisableAutoImportController") - assert DisableAutoImportController.get_api_controller().registered + assert get_api_controller(DisableAutoImportController).is_registered(api) @pytest.mark.skipif(django.VERSION < (3, 1), reason="requires django 3.1 or higher") @@ -352,10 +353,12 @@ class AsyncRouteInControllerWithAsyncAuth: async def example(self): pass - example_route_function = get_route_function( - AsyncRouteInControllerWithAsyncAuth().example + api_controller_instance = get_api_controller(AsyncRouteInControllerWithAsyncAuth) + example_route_function = ( + api_controller_instance._controller_class_route_functions.get("example") ) - assert AsyncRouteInControllerWithAsyncAuth.get_api_controller().has_auth_async + + assert api_controller_instance.has_auth_async assert isinstance( example_route_function.operation.auth_callbacks[0], AsyncFakeAuth, diff --git a/tests/test_controller_registry.py b/tests/test_controller_registry.py index c01b26fc..db466d50 100644 --- a/tests/test_controller_registry.py +++ b/tests/test_controller_registry.py @@ -12,35 +12,51 @@ class AutoImportTrueControllerSample(ControllerBase): auto_import = True -def test_can_not_add_controller_for_auto_false(): +def test_can_not_add_controller_for_auto_false(reflect_context): registry = ControllerRegistry() registry.clear_controller() + registry.add_controller(AutoImportFalseControllerSample) - assert str(AutoImportFalseControllerSample) not in registry.controllers + controllers = registry.get_controllers() + + assert str(AutoImportFalseControllerSample) not in controllers -def test_can_add_controller_for_auto_true(): +def test_can_add_controller_for_auto_true(reflect_context): registry = ControllerRegistry() registry.clear_controller() + registry.add_controller(AutoImportTrueControllerSample) - assert str(AutoImportTrueControllerSample) in registry.controllers + controllers = registry.get_controllers() + assert str(AutoImportTrueControllerSample) in controllers -def test_remove_controller_works(): + +def test_remove_controller_works(reflect_context): registry = ControllerRegistry() registry.clear_controller() registry.add_controller(AutoImportTrueControllerSample) - assert str(AutoImportTrueControllerSample) in registry.controllers + controllers = registry.get_controllers() + + assert str(AutoImportTrueControllerSample) in controllers result = registry.remove_controller(AutoImportTrueControllerSample) + assert result - assert str(AutoImportTrueControllerSample) not in registry.controllers + controllers = registry.get_controllers() + + assert str(AutoImportTrueControllerSample) not in controllers assert registry.remove_controller(AutoImportTrueControllerSample) is None -def test_clear_registry_works(): +def test_clear_registry_works(reflect_context): registry = ControllerRegistry() registry.add_controller(AutoImportTrueControllerSample) - assert str(AutoImportTrueControllerSample) in registry.controllers + + controllers = registry.get_controllers() + assert str(AutoImportTrueControllerSample) in controllers + registry.clear_controller() - assert len(registry.controllers) == 0 + controllers = registry.get_controllers() + + assert len(controllers) == 0 diff --git a/tests/test_deprecation_warnings.py b/tests/test_deprecation_warnings.py index adb16512..5a18560a 100644 --- a/tests/test_deprecation_warnings.py +++ b/tests/test_deprecation_warnings.py @@ -1,17 +1,5 @@ import warnings -import pytest - - -def test_route_context_deprecation(): - with pytest.warns( - DeprecationWarning, - match="RouteContext is deprecated and will be removed in a future version.", - ): - from ninja_extra.controllers import RouteContext - - assert RouteContext is not None # Verify we can still access it - def test_no_warning_for_other_imports(): with warnings.catch_warnings(): diff --git a/tests/test_model_controller/test_model_async_controller_operation.py b/tests/test_model_controller/test_model_async_controller_operation.py index 6ee7068c..51039c91 100644 --- a/tests/test_model_controller/test_model_async_controller_operation.py +++ b/tests/test_model_controller/test_model_async_controller_operation.py @@ -2,6 +2,7 @@ from asgiref.sync import sync_to_async from ninja_extra.controllers.base import APIController +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.testing import TestAsyncClient from ..models import Event @@ -288,8 +289,8 @@ async def test_api_controller_prefix_with_parameter(): "end_date": "2020-01-02", "title": "test-prefix", } - _api_controller: APIController = ( - AsyncEventModelControllerWithPrefix.get_api_controller() + _api_controller: APIController = get_api_controller( + AsyncEventModelControllerWithPrefix ) assert _api_controller._prefix_has_route_param diff --git a/tests/test_operation.py b/tests/test_operation.py index 78943329..003dbe33 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -5,9 +5,11 @@ from ninja import Body, Schema from ninja_extra import api_controller, http_delete, http_get, http_post, route, status +from ninja_extra.constants import ROUTE_OBJECT from ninja_extra.controllers import AsyncRouteFunction, RouteFunction -from ninja_extra.helper import get_route_function +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.operation import AsyncOperation, Operation +from ninja_extra.reflect import reflect from ninja_extra.testing import TestAsyncClient, TestClient from .utils import AsyncFakeAuth, FakeAuth, mock_log_call @@ -46,12 +48,12 @@ def test_route_operation_execution_should_log_execution(self): @pytest.mark.skipif(django.VERSION < (3, 1), reason="requires django 3.1 or higher") -def test_operation_auth_configs(): +def test_operation_auth_configs(reflect_context): @api_controller("prefix", tags="any_Tag") class AController: pass - api_controller_instance = AController.get_api_controller() + api_controller_instance = get_api_controller(AController) async def async_endpoint(self, request): pass @@ -63,7 +65,10 @@ def sync_endpoint(self, request): async_auth_http_get = route.get("/example/async", auth=[AsyncFakeAuth()]) sync_auth_http_get(async_endpoint) - async_route_function = get_route_function(async_endpoint) + route_obj = reflect.get_metadata_or_raise_exception(ROUTE_OBJECT, async_endpoint) + async_route_function = AsyncRouteFunction( + route_obj, api_controller=api_controller_instance + ) assert isinstance(async_route_function, AsyncRouteFunction) api_controller_instance._add_operation_from_route_function(async_route_function) @@ -72,14 +77,24 @@ def sync_endpoint(self, request): assert isinstance(async_route_function.operation, AsyncOperation) sync_auth_http_get(sync_endpoint) - sync_route_function = get_route_function(sync_endpoint) + sync_route_obj = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, sync_endpoint + ) + sync_route_function = RouteFunction( + sync_route_obj, api_controller=api_controller_instance + ) api_controller_instance._add_operation_from_route_function(sync_route_function) assert isinstance(sync_route_function.operation, Operation) assert isinstance(sync_route_function, RouteFunction) with pytest.raises(Exception) as ex: new_sync_endpoint = async_auth_http_get(sync_endpoint) - new_sync_route_function = get_route_function(new_sync_endpoint) + new_sync_route_obj = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, new_sync_endpoint + ) + new_sync_route_function = RouteFunction( + new_sync_route_obj, api_controller=api_controller_instance + ) api_controller_instance._add_operation_from_route_function( new_sync_route_function ) @@ -118,7 +133,7 @@ async def test_async_route_operation_execution_should_log_execution(self): await client.get("/example_exception") -def test_controller_operation_order(): +def test_controller_operation_order(reflect_context): class InputSchema(Schema): name: str age: int diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 597e301e..faa94f7b 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -1,4 +1,3 @@ -import inspect import operator from typing import List @@ -7,13 +6,15 @@ from ninja import Schema from ninja_extra import NinjaExtraAPI, api_controller, route -from ninja_extra.constants import ROUTE_FUNCTION +from ninja_extra.constants import ORDERATOR_OBJECT +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.ordering import ( OrderatorOperation, Ordering, OrderingBase, ordering, ) +from ninja_extra.reflect import reflect from ninja_extra.testing import TestAsyncClient, TestClient from .models import Category @@ -106,16 +107,17 @@ def items_10(self): @pytest.mark.django_db class TestOrdering: def test_orderator_operation_used(self): - some_api_route_functions = dict( - inspect.getmembers( - SomeAPIController, lambda member: hasattr(member, ROUTE_FUNCTION) - ) - ) + api_controller_instance = get_api_controller(SomeAPIController) has_kwargs = ("items_3", "items_4") found_route_functions = False - for name, route_function in some_api_route_functions.items(): - assert hasattr(route_function, "orderator_operation") - orderator_operation = route_function.orderator_operation + for ( + name, + route_function, + ) in api_controller_instance._controller_class_route_functions.items(): + assert reflect.has_metadata(ORDERATOR_OBJECT, route_function.as_view) + orderator_operation = reflect.get_metadata( + ORDERATOR_OBJECT, route_function.as_view + ) assert isinstance(orderator_operation, OrderatorOperation) if name in has_kwargs: assert orderator_operation.view_func_has_kwargs @@ -303,17 +305,17 @@ async def items_10(self): client = TestAsyncClient(AsyncSomeAPIController) async def test_orderator_operation_used(self): - some_api_route_functions = dict( - inspect.getmembers( - self.AsyncSomeAPIController, - lambda member: hasattr(member, ROUTE_FUNCTION), - ) - ) + api_controller_instance = get_api_controller(self.AsyncSomeAPIController) has_kwargs = ("items_3", "items_4") found_route_functions = False - for name, route_function in some_api_route_functions.items(): - assert hasattr(route_function, "orderator_operation") - orderator_operation = route_function.orderator_operation + for ( + name, + route_function, + ) in api_controller_instance._controller_class_route_functions.items(): + assert reflect.has_metadata(ORDERATOR_OBJECT, route_function.as_view) + orderator_operation = reflect.get_metadata( + ORDERATOR_OBJECT, route_function.as_view + ) assert isinstance(orderator_operation, OrderatorOperation) if name in has_kwargs: assert orderator_operation.view_func_has_kwargs diff --git a/tests/test_pagination.py b/tests/test_pagination.py index a0417545..7b471edc 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -6,6 +6,7 @@ from ninja import FilterSchema, NinjaAPI, Schema from ninja_extra import NinjaExtraAPI, api_controller, route +from ninja_extra.constants import PAGINATOR_OBJECT from ninja_extra.controllers import RouteFunction from ninja_extra.pagination import ( AsyncPaginatorOperation, @@ -336,7 +337,7 @@ def test_case7(self): assert response.status_code == 404 assert response.json() == {"message": "Not Found"} - def test_filter_schema_integration(self): + def test_filter_schema_integration(self, reflect_context): """Test that filter_schema is properly integrated with paginate decorator""" # Test the paginate decorator directly with filter_schema @@ -347,8 +348,8 @@ def test_view(request): return ITEMS # Check that the decorated function has paginator_operation attribute - assert hasattr(test_view, "paginator_operation") - paginator_operation = test_view.paginator_operation + assert reflect_context.has_metadata(PAGINATOR_OBJECT, test_view) + paginator_operation = reflect_context.get_metadata(PAGINATOR_OBJECT, test_view) # Verify filter_schema is stored correctly assert isinstance(paginator_operation, PaginatorOperation) diff --git a/tests/test_reflect.py b/tests/test_reflect.py new file mode 100644 index 00000000..24461b63 --- /dev/null +++ b/tests/test_reflect.py @@ -0,0 +1,186 @@ +import pytest + +from ninja_extra.reflect import reflect + + +def test_define_metadata_creates_attribute_dict(random_type): + key = "FrameworkName" + reflect.define_metadata(key, "Ellar", random_type) + # assert hasattr(random_type, REFLECT_TYPE) + assert reflect.get_metadata(key, random_type) == "Ellar" + assert list(reflect.get_metadata_keys(random_type)) == [key] + + +@pytest.mark.parametrize("immutable_type", ["FrameworkName", 23, (33, 45), {34, 45}]) +def test_reflect_works_with_immutable_types(immutable_type, reflect_context): + key = "FrameworkName" + reflect.define_metadata(key, "Ellar", immutable_type) + # assert hasattr(random_type, REFLECT_TYPE) + assert reflect.get_metadata(key, immutable_type) == "Ellar" + assert list(reflect.get_metadata_keys(immutable_type)) == [key] + + +def test_define_metadata_without_default(random_type): + key = "FrameworkName" + + reflect.define_metadata(key, "Ellar", random_type) + assert reflect.get_metadata(key, random_type) == "Ellar" + reflect.define_metadata(key, "Starlette", random_type) + assert reflect.get_metadata(key, random_type) == "Starlette" + + +def test_define_metadata_with_existing_tuple(random_type): + reflect.define_metadata("B", ("EllarB",), random_type) + assert reflect.get_metadata("B", random_type) == ("EllarB",) + + reflect.define_metadata("B", ("AnotherEllar",), random_type) + reflect.define_metadata("B", ("AnotherEllarC",), random_type) + assert reflect.get_metadata("B", random_type) == ( + "EllarB", + "AnotherEllar", + "AnotherEllarC", + ) + + +def test_get_all_metadata(random_type): + reflect.define_metadata("B", ("EllarB",), random_type) + assert reflect.get_metadata("B", random_type) == ("EllarB",) + + reflect.define_metadata("B", ("AnotherEllar",), random_type) + data = reflect.get_all_metadata(random_type) + assert data == {"B": ("EllarB", "AnotherEllar")} + + +def test_delete_all_metadata(random_type): + reflect.define_metadata("D", ("EllarD",), random_type) + + reflect.define_metadata("B", ("AnotherEllar",), random_type) + data = reflect.get_all_metadata(random_type) + assert data == {"B": ("AnotherEllar",), "D": ("EllarD",)} + + reflect.delete_all_metadata(random_type) + assert reflect.get_metadata("D", random_type) is None + + +def test_define_metadata_with_existing_list(random_type): + reflect.define_metadata("B", ["Ellar"], random_type) + assert reflect.get_metadata("B", random_type) == ["Ellar"] + + reflect.define_metadata("B", ["AnotherEllar"], random_type) + reflect.define_metadata("B", ["AnotherEllarD"], random_type) + assert reflect.get_metadata("B", random_type) == [ + "Ellar", + "AnotherEllar", + "AnotherEllarD", + ] + + +def test_define_metadata_with_existing_dict(random_type): + reflect.define_metadata("C", {"C": "EllarC"}, random_type) + assert reflect.get_metadata("C", random_type) == {"C": "EllarC"} + + reflect.define_metadata("C", {"D": "EllarD"}, random_type) + assert reflect.get_metadata("C", random_type) == { + "D": "EllarD", + "C": "EllarC", + } + + +def test_define_metadata_with_existing_set(random_type): + reflect.define_metadata("A", {"EllarA"}, random_type) + reflect.define_metadata("A", {"AnotherEllar"}, random_type) + assert reflect.get_metadata("A", random_type) == {"AnotherEllar", "EllarA"} + + +def test_reflect_meta_decorator(): + @reflect.metadata("defined_key", "chioma") + @reflect.metadata("defined_key_b", "jessy") + def function_a(): + """ignore""" + + assert reflect.get_metadata("defined_key", function_a) == "chioma" + assert reflect.get_metadata("defined_key_b", function_a) == "jessy" + assert list(reflect.get_metadata_keys(function_a)) == [ + "defined_key_b", + "defined_key", + ] + + +def test_reflect_has_metadata_works(): + @reflect.metadata("defined_key", "jessy") + def function_new(): + """ignore""" + + assert reflect.has_metadata("defined_key", function_new) + assert reflect.has_metadata("defined_key_b", function_new) is False + + +def test_reflect_get_metadata_or_raise_exception(): + @reflect.metadata("defined_key_b", "jessy") + def function_new(): + """ignore""" + + assert ( + reflect.get_metadata_or_raise_exception("defined_key_b", function_new) + == "jessy" + ) + + with pytest.raises(Exception, match="MetaData Key not Found"): + reflect.get_metadata_or_raise_exception("defined_key", function_new) + + +def test_delete_metadata_works(): + @reflect.metadata("defined_key_b", "jessy") + def function_new(): + """ignore""" + + reflect.delete_metadata("defined_key_b", function_new) + assert reflect.has_metadata("defined_key_b", function_new) is False + + +def test_reflect_context_works(): + @reflect.metadata("defined_key_b", "jessy") + @reflect.metadata("defined_key", "clara") + def function_new(): + """ignore""" + + with reflect.context(): + reflect.define_metadata("defined_key_c", "Eadwin", function_new) + reflect.define_metadata("defined_key_d", "Dakolo", function_new) + + assert reflect.has_metadata("defined_key_b", function_new) + assert reflect.has_metadata("defined_key", function_new) + assert reflect.has_metadata("defined_key_c", function_new) + assert reflect.has_metadata("defined_key_d", function_new) + + assert reflect.has_metadata("defined_key_c", function_new) is False + assert reflect.has_metadata("defined_key_d", function_new) is False + + +@pytest.mark.asyncio +async def test_reflect_async_context_works(): + @reflect.metadata("defined_key_b", "jessy") + @reflect.metadata("defined_key", "clara") + def function_new(): + """ignore""" + + async with reflect.async_context(): + reflect.define_metadata("defined_key_c", "Eadwin", function_new) + reflect.define_metadata("defined_key_d", "Dakolo", function_new) + + assert reflect.has_metadata("defined_key_b", function_new) + assert reflect.has_metadata("defined_key", function_new) + assert reflect.has_metadata("defined_key_c", function_new) + assert reflect.has_metadata("defined_key_d", function_new) + + assert reflect.has_metadata("defined_key_c", function_new) is False + assert reflect.has_metadata("defined_key_d", function_new) is False + + +def test_define_metadata_raise_exception(): + with pytest.raises(Exception, match="`target` is not a valid type"): + reflect.define_metadata("defined_key_c", "Eadwin", None) + + +def test_define_metadata_overrides_existing_collection_of_different_type(): + pass diff --git a/tests/test_route.py b/tests/test_route.py index 8022aab4..e7f2c722 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -7,6 +7,7 @@ from ninja.constants import NOT_SET from ninja_extra import api_controller, permissions, route +from ninja_extra.constants import ROUTE_OBJECT from ninja_extra.context import ( RouteContext, get_route_execution_context, @@ -16,10 +17,11 @@ RouteFunction, RouteInvalidParameterException, ) -from ninja_extra.controllers.base import get_all_controller_route_function +from ninja_extra.controllers.base import get_route_functions +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.exceptions import PermissionDenied -from ninja_extra.helper import get_route_function from ninja_extra.permissions import AllowAny +from ninja_extra.reflect import reflect from .schemas import UserSchema from .utils import FakeAuth @@ -90,13 +92,13 @@ class TestControllerRoute: def test_api_controller_builds_accurate_operations_list( self, path, operation_count ): - api_controller_instance = SomeTestController.get_api_controller() + api_controller_instance = get_api_controller(SomeTestController) path_view = api_controller_instance.path_operations.get(path) assert len(path_view.operations) == operation_count def test_controller_route_should_have_an_operation(self): - for route_func in get_all_controller_route_function(SomeTestController): - path_view = SomeTestController.get_api_controller().path_operations.get( + for route_func in get_route_functions(SomeTestController, None): + path_view = get_api_controller(SomeTestController).path_operations.get( str(route_func) ) operations = list( @@ -111,7 +113,7 @@ def test_controller_route_should_have_an_operation(self): assert operations[0].operation_id == "example_post_operation_id" assert route_func.route.route_params.methods == operations[0].methods - def test_controller_route_should_right_view_func_type(self): + def test_controller_route_should_right_view_func_type(self, get_route_function): controller = SomeTestController() route_function = get_route_function(controller.example) assert isinstance(route_function, RouteFunction) @@ -119,7 +121,9 @@ def test_controller_route_should_right_view_func_type(self): assert hasattr(route_function.as_view, "get_route_function") assert route_function.as_view.get_route_function() == route_function - def test_controller_route_should_use_userschema_as_response(self): + def test_controller_route_should_use_userschema_as_response( + self, get_route_function + ): controller = SomeTestController() route_function = get_route_function(controller.example) assert route_function.route.route_params.response == NOT_SET @@ -154,28 +158,36 @@ def example_list_create(self, ex_id: str): assert "Invalid response configuration" in str(ex) - def test_route_response_parameters_computed_correctly(self): + def test_route_response_parameters_computed_correctly( + self, get_route_function, reflect_context + ): unique_response = [{302: Schema}, (401, Schema)] non_unique_response = [ {201: Schema}, ] # Id status_code == 201 so it should be replaced by the dict response - @route.get("/example/list", response=unique_response) - def example_unique_response(self, ex_id: str): - pass + @api_controller + class ExampleController: + @route.get("/example/list", response=unique_response) + def example_unique_response(self, ex_id: str): + pass - @route.get("/example/list", response=non_unique_response) - def example_non_unique_response(self, ex_id: str): - pass + @route.get("/example/list", response=non_unique_response) + def example_non_unique_response(self, ex_id: str): + pass assert ( - len(get_route_function(example_unique_response).route.route_params.response) + len( + get_route_function( + ExampleController().example_unique_response + ).route.route_params.response + ) == 2 ) assert ( len( get_route_function( - example_non_unique_response + ExampleController().example_non_unique_response ).route.route_params.response ) == 1 @@ -216,7 +228,9 @@ def example_non_unique_response(self, ex_id: str): ), ], ) - def test_route_generates_required_route_definitions(self, func, methods, kwargs): + def test_route_generates_required_route_definitions( + self, func, methods, kwargs, get_route_function + ): def view_func(request): pass @@ -226,15 +240,15 @@ def view_func(request): if func == "generic" else route_method("/", **kwargs) )(view_func) - route_function = get_route_function(view_func) - assert route_function.route.route_params.methods == methods + route_object = reflect.get_metadata_or_raise_exception(ROUTE_OBJECT, view_func) + assert route_object.route_params.methods == methods for k, v in kwargs.items(): - assert getattr(route_function.route.route_params, k) == v + assert getattr(route_object.route_params, k) == v @pytest.mark.skipif(django.VERSION < (3, 1), reason="requires django 3.1 or higher") @pytest.mark.asyncio -async def test_async_route_function(): +async def test_async_route_function(reflect_context, get_route_function): @api_controller() class AsyncSomeTestController(SomeTestController): @route.get("/example_async") @@ -279,40 +293,39 @@ async def async_api_func(self): def test_get_required_api_func_signature_return_filtered_signature(self): route.get("")(self.api_func) - route_function = get_route_function(self.api_func) + route_object = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, self.api_func + ) + route_function = RouteFunction(route_object, None) + assert not route_function.has_request_param sig_inspect, sig_parameter = route_function._get_required_api_func_signature() assert len(sig_parameter) == 0 route.get("")(self.api_func_with_has_request_param) - route_function = get_route_function(self.api_func_with_has_request_param) + route_object = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, self.api_func_with_has_request_param + ) + route_function = RouteFunction(route_object, None) assert route_function.has_request_param sig_inspect, sig_parameter = route_function._get_required_api_func_signature() assert len(sig_parameter) == 0 route.get("")(self.api_func_with_param) - route_function = get_route_function(self.api_func_with_param) + route_object = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, self.api_func_with_param + ) + route_function = RouteFunction(route_object, None) sig_inspect, sig_parameter = route_function._get_required_api_func_signature() assert len(sig_parameter) == 1 assert str(sig_parameter[0]).replace(" ", "") == "example_id:str" - def test_from_route_returns_route_function_instance(self): - route.get("")(self.api_func) - route_function = get_route_function(self.api_func) - assert isinstance(route_function, RouteFunction) - - route.get("")(self.async_api_func) - route_function = get_route_function(self.async_api_func) - assert isinstance(route_function, AsyncRouteFunction) - def test_get_route_execution_context(self): route.get("")(self.api_func) - route_function = get_route_function(self.api_func) - with pytest.raises(AssertionError): - route_function.get_route_execution_context( - anonymous_request, "arg1", "arg2", extra="extra" - ) - route_function.api_controller = Mock() + route_object = reflect.get_metadata_or_raise_exception( + ROUTE_OBJECT, self.api_func + ) + route_function = RouteFunction(route_object, Mock()) route_function.api_controller.permission_classes = [AllowAny] route_context = route_function.get_route_execution_context( @@ -323,14 +336,21 @@ def test_get_route_execution_context(self): for key in expected_keywords: assert getattr(route_context, key) - def test_get_controller_instance_return_controller_instance(self): - route_function: RouteFunction = get_route_function(SomeTestController().example) + def test_get_controller_instance_return_controller_instance( + self, get_route_function + ): + api_controller_instance = get_api_controller(SomeTestController) + route_function: RouteFunction = ( + api_controller_instance._controller_class_route_functions.get("example") + ) controller_instance = route_function._get_controller_instance() assert isinstance(controller_instance, SomeTestController) assert isinstance(controller_instance, SomeTestController) assert controller_instance.context is None - def test_process_view_function_result_return_tuple_or_input(self): + def test_process_view_function_result_return_tuple_or_input( + self, get_route_function + ): route_function: RouteFunction = get_route_function(SomeTestController().example) mock_result = {"detail": "Some Message", "status_code": 302} response = route_function._process_view_function_result(mock_result) @@ -359,7 +379,9 @@ def get_real_user_request(cls): _request.user = user return _request - def test_permission_controller_example_allow_any_auth_is_none(self): + def test_permission_controller_example_allow_any_auth_is_none( + self, get_route_function + ): example_allow_any_route_function = get_route_function( self.controller.example_allow_any ) @@ -370,7 +392,9 @@ def test_permission_controller_example_allow_any_auth_is_none(self): assert response == {"message": "OK"} assert response == self.controller.example_allow_any() - def test_route_is_protected_by_global_controller_permission(self): + def test_route_is_protected_by_global_controller_permission( + self, get_route_function + ): example_route_function = get_route_function(self.controller.example) with pytest.raises(PermissionDenied) as pex: example_route_function(anonymous_request) @@ -378,20 +402,24 @@ def test_route_is_protected_by_global_controller_permission(self): pex.value.detail ) - def test_route_protected_by_global_controller_permission_works(self): + def test_route_protected_by_global_controller_permission_works( + self, get_route_function + ): example_route_function = get_route_function(self.controller.example) request = self.get_real_user_request() response = example_route_function(request) assert response == {"message": "OK"} - def test_route_is_protected_by_its_permissions_paramater(self): + def test_route_is_protected_by_its_permissions_paramater(self, get_route_function): example_allow_any_route_function = get_route_function( self.controller.example_allow_any ) response = example_allow_any_route_function(anonymous_request) assert response == {"message": "OK"} - def test_route_prep_controller_route_execution_context_works(self): + def test_route_prep_controller_route_execution_context_works( + self, get_route_function + ): route_function: RouteFunction = get_route_function(SomeTestController().example) context = get_route_execution_context(request=anonymous_request) with route_function._prep_controller_route_execution( @@ -402,7 +430,7 @@ def test_route_prep_controller_route_execution_context_works(self): assert ctx.controller_instance.context is None def test_route_prep_controller_route_execution_context_cleans_controller_after_route_execution( - self, + self, get_route_function ): route_function: RouteFunction = get_route_function(SomeTestController().example) context = get_route_execution_context(request=anonymous_request) diff --git a/tests/test_searching.py b/tests/test_searching.py index 4b4ac994..c222841f 100644 --- a/tests/test_searching.py +++ b/tests/test_searching.py @@ -1,4 +1,3 @@ -import inspect import operator from typing import List @@ -7,7 +6,9 @@ from ninja import Schema from ninja_extra import NinjaExtraAPI, api_controller, route -from ninja_extra.constants import ROUTE_FUNCTION +from ninja_extra.constants import SEARCH_OPERATOR_OBJECT +from ninja_extra.controllers.utils import get_api_controller +from ninja_extra.reflect import reflect from ninja_extra.searching import ( AsyncSearcheratorOperation, SearcheratorOperation, @@ -87,18 +88,21 @@ def items_6(self, **kwargs): @pytest.mark.django_db class TestSearch: def test_Search_operation_used(self): - some_api_route_functions = dict( - inspect.getmembers( - SomeAPIController, lambda member: hasattr(member, ROUTE_FUNCTION) - ) - ) has_kwargs = ("items_3", "items_4") - found_route_functions = False + api_controller_instance = get_api_controller(SomeAPIController) - for name, route_function in some_api_route_functions.items(): - assert hasattr(route_function, "searcherator_operation") - searcherator_operation = route_function.searcherator_operation + found_route_functions = False + for ( + name, + route_function, + ) in api_controller_instance._controller_class_route_functions.items(): + assert reflect.has_metadata(SEARCH_OPERATOR_OBJECT, route_function.as_view) + + searcherator_operation = reflect.get_metadata( + SEARCH_OPERATOR_OBJECT, route_function.as_view + ) assert isinstance(searcherator_operation, SearcheratorOperation) + if name in has_kwargs: assert searcherator_operation.view_func_has_kwargs found_route_functions = True @@ -269,18 +273,22 @@ async def items_8(self, **kwargs): client = TestAsyncClient(AsyncSomeAPIController) async def test_Search_operation_used(self): - some_api_route_functions = dict( - inspect.getmembers( - self.AsyncSomeAPIController, - lambda member: hasattr(member, ROUTE_FUNCTION), - ) - ) has_kwargs = ("items_3", "items_4") + + api_controller_instance = get_api_controller(self.AsyncSomeAPIController) + found_route_functions = False + for ( + name, + route_function, + ) in api_controller_instance._controller_class_route_functions.items(): + assert reflect.has_metadata( + SEARCH_OPERATOR_OBJECT, route_function.as_view + ) - for name, route_function in some_api_route_functions.items(): - assert hasattr(route_function, "searcherator_operation") - searcherator_operation = route_function.searcherator_operation + searcherator_operation = reflect.get_metadata( + SEARCH_OPERATOR_OBJECT, route_function.as_view + ) assert isinstance(searcherator_operation, AsyncSearcheratorOperation) if name in has_kwargs: assert searcherator_operation.view_func_has_kwargs diff --git a/tests/test_throthling/test_throttle_controller.py b/tests/test_throthling/test_throttle_controller.py index 580f3862..da6bfc18 100644 --- a/tests/test_throthling/test_throttle_controller.py +++ b/tests/test_throthling/test_throttle_controller.py @@ -4,6 +4,7 @@ from ninja.constants import NOT_SET from ninja_extra import ControllerBase, api_controller, http_get +from ninja_extra.controllers.utils import get_api_controller from ninja_extra.testing import TestClient from ninja_extra.throttling import ( AnonRateThrottle, @@ -48,7 +49,7 @@ def test_all_controller_func_has_throttling_decorator(self): cloned_controller = api_controller( "/throttled-controller", throttle=DynamicRateThrottle(rate="5/min") )(type("ThrottlingControllerSample", (ThrottlingControllerSample,), {})) - api_controller_instance = cloned_controller.get_api_controller() + api_controller_instance = get_api_controller(cloned_controller) for ( _, func, From 669333b09d53696d1dfeaf4d0bea18aa90891e38 Mon Sep 17 00:00:00 2001 From: Ezeudoh Tochukwu Date: Tue, 16 Dec 2025 22:24:46 +0100 Subject: [PATCH 2/3] Added test for sub-classable controller #319. Thanks to @anentropic --- tests/test_controller.py | 59 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_controller.py b/tests/test_controller.py index 4f121759..45b4bf10 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -24,6 +24,22 @@ from .utils import AsyncFakeAuth, FakeAuth +class ReportControllerBase(ControllerBase): + @http_get("") + def report(self): + return {"controller": type(self).__name__} + + +@api_controller("/alpha", urls_namespace="alpha") +class AlphaReportController(ReportControllerBase): + pass + + +@api_controller("/beta", urls_namespace="beta") +class BetaReportController(ReportControllerBase): + pass + + @api_controller class SomeController: pass @@ -384,3 +400,46 @@ def test_namespaced_controller_detail(client): def test_default_url_name(client): assert reverse("api-1.0.0:get_event", kwargs={"id": 5}) == "/api/events/5" + + +def test_controller_subclass_routes_remain_isolated(): + api = NinjaExtraAPI() + api.register_controllers(AlphaReportController, BetaReportController) + client = testing.TestClient(api) + + alpha_response = client.get("/alpha") + beta_response = client.get("/beta") + + assert alpha_response.status_code == 200 + assert beta_response.status_code == 200 + assert alpha_response.json() == {"controller": "AlphaReportController"} + assert beta_response.json() == {"controller": "BetaReportController"} + + +def test_controller_multi_level_inheritance_routes_isolated(): + """Test that route isolation works with multi-level inheritance.""" + + # Middle layer doesn't override the route + class MiddleReportController(ReportControllerBase): + pass + + @api_controller("/gamma") + class GammaReportController(MiddleReportController): + pass + + @api_controller("/delta") + class DeltaReportController(MiddleReportController): + pass + + api = NinjaExtraAPI() + api.register_controllers(GammaReportController) + api.register_controllers(DeltaReportController) + client = testing.TestClient(api) + + gamma_response = client.get("/gamma") + delta_response = client.get("/delta") + + assert gamma_response.status_code == 200 + assert delta_response.status_code == 200 + assert gamma_response.json() == {"controller": "GammaReportController"} + assert delta_response.json() == {"controller": "DeltaReportController"} From b1d54a1eee0087ade6010c72fe2b7cb530100419 Mon Sep 17 00:00:00 2001 From: Ezeudoh Tochukwu Date: Tue, 16 Dec 2025 22:26:38 +0100 Subject: [PATCH 3/3] Added test for controller reuse across multiple NinjaExtraAPI instances #293. Thanks to @jdiego --- tests/test_api_instance.py | 87 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/test_api_instance.py b/tests/test_api_instance.py index c35c93f0..4aeb1b37 100644 --- a/tests/test_api_instance.py +++ b/tests/test_api_instance.py @@ -1,3 +1,4 @@ +import typing as t from unittest import mock import pytest @@ -84,3 +85,89 @@ def example(self): res = client.get("/another/example") assert res.status_code == 200 assert res.content == b'"Create Response Works"' + + +def test_same_controller_two_apis_works(): + @api_controller("/ping") + class P: + @http_get("") + def ping(self): + return {"ok": True} + + a = NinjaExtraAPI(urls_namespace="a") + b = NinjaExtraAPI(urls_namespace="b") + + a.register_controllers(P) + b.register_controllers(P) # triggers clone path + + assert TestClient(a).get("/ping").json() == {"ok": True} + assert TestClient(b).get("/ping").json() == {"ok": True} + + +def test_openapi_schema_params_are_correct_on_two_apis(): + @api_controller("/") + class ItemsController: + @http_get("/items_1") + def items_1(self, ordering: t.Optional[str] = None): + return {"ok": True} + + # Two independent API instances + api_a = NinjaExtraAPI(title="A") + api_b = NinjaExtraAPI(title="B") + + api_a.register_controllers(ItemsController) + api_b.register_controllers(ItemsController) + + expected_params = [ + { + "in": "query", + "name": "ordering", + "required": False, + "schema": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Ordering", + }, + } + ] + + # Check API A schema + schema_a = api_a.get_openapi_schema() + op_a = schema_a["paths"]["/api/items_1"]["get"] + assert op_a["parameters"] == expected_params + + # Check API B schema + schema_b = api_b.get_openapi_schema() + op_b = schema_b["paths"]["/api/items_1"]["get"] + assert op_b["parameters"] == expected_params + + # (Optional) also confirm the route actually works on both APIs + ca = TestClient(api_a) + cb = TestClient(api_b) + assert ca.get("/items_1").status_code == 200 + assert cb.get("/items_1").status_code == 200 + + +def test_clone_is_cached_per_api_not_recreated(): + """Register the same original class twice on the same API -> reuse cached clone, no new routers.""" + + @api_controller("/x") + class X: + @http_get("") + def ok(self): + return {"ok": True} + + a = NinjaExtraAPI(urls_namespace="a") + b = NinjaExtraAPI(urls_namespace="b") + + # Mount on A (original) + a.register_controllers(X) + # Mount on B (clone) + b.register_controllers(X) + # Re-register same original on B (should reuse the cached clone; no new routers added) + before = len(b._routers) + b.register_controllers(X) + after = len(b._routers) + assert before == after + + # Optional: ensure path exists and works + assert TestClient(b).get("/x").json() == {"ok": True}