From 24c9b6ab021ba71930d7412be5e50a5d70e0cdbd Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 16 Jan 2026 18:37:20 +0000 Subject: [PATCH 1/2] Add More Type Hints in Func Transforms --- python/mlx/_stub_patterns.txt | 4 +++- python/src/transforms.cpp | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mlx/_stub_patterns.txt b/python/mlx/_stub_patterns.txt index 637d7de4fe..7ca4369805 100644 --- a/python/mlx/_stub_patterns.txt +++ b/python/mlx/_stub_patterns.txt @@ -1,10 +1,12 @@ mlx.core.__prefix__: - from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, ParamSpec, TypeVar import sys if sys.version_info >= (3, 10): from typing import TypeAlias else: from typing_extensions import TypeAlias + P = ParamSpec("P") + R = TypeVar("R") mlx.core.__suffix__: from typing import Union diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 67863038eb..806b416180 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1333,7 +1333,7 @@ void init_transforms(nb::module_& m) { "argnums"_a = nb::none(), "argnames"_a = std::vector{}, nb::sig( - "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"), + "def value_and_grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Tuple[R, Any]]"), R"pbdoc( Returns a function which computes the value and gradient of ``fun``. @@ -1472,7 +1472,7 @@ void init_transforms(nb::module_& m) { "outputs"_a = nb::none(), "shapeless"_a = false, nb::sig( - "def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"), + "def compile(fun: Callable[P, R], inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable[P, R]"), R"pbdoc( Returns a compiled function which produces the same output as ``fun``. From 887a6ca1e16332e54c3cc961e3fff97dead7ac8e Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 16 Jan 2026 21:42:27 +0000 Subject: [PATCH 2/2] Fix Bugs in building --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 00c7514408..e17ee714f7 100644 --- a/setup.py +++ b/setup.py @@ -169,10 +169,15 @@ def run(self): self.copy_tree(regular_dir, inplace_dir) # Build type stubs. + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] + extdir = ext_fullpath.parent.parent.resolve() build_temp = Path(self.build_temp) / ext.name + env = os.environ.copy() + env["PYTHONPATH"] += f":{extdir}" subprocess.run( ["cmake", "--install", build_temp, "--component", "core_stub"], check=True, + env=env, )