Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/mlx/_stub_patterns.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, ParamSpec, TypeVar
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
P = ParamSpec("P")
R = TypeVar("R")

mlx.core.__suffix__:
from typing import Union
Expand Down
4 changes: 2 additions & 2 deletions python/src/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,7 @@ void init_transforms(nb::module_& m) {
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
"def value_and_grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Tuple[R, Any]]"),
R"pbdoc(
Returns a function which computes the value and gradient of ``fun``.
Expand Down Expand Up @@ -1472,7 +1472,7 @@ void init_transforms(nb::module_& m) {
"outputs"_a = nb::none(),
"shapeless"_a = false,
nb::sig(
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
"def compile(fun: Callable[P, R], inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable[P, R]"),
R"pbdoc(
Returns a compiled function which produces the same output as ``fun``.
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,15 @@ def run(self):
self.copy_tree(regular_dir, inplace_dir)

# Build type stubs.
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
extdir = ext_fullpath.parent.parent.resolve()
build_temp = Path(self.build_temp) / ext.name
env = os.environ.copy()
env["PYTHONPATH"] += f":{extdir}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this actually fixes the problem you saw, the exact path has already been passed in the stubgen target:

# Run stubgen -m mlx.core -i python -p _stub_patterns.txt -o python/mlx
RECURSIVE
MODULE
"mlx.core"
PYTHON_PATH
"${CMAKE_CURRENT_SOURCE_DIR}/.."

Can you share the exact command that you used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the command is pip install git+https://github.com/ml-explore/mlx.git -vv and pip install -vv git+https://github.com/XXXXRT666/mlx.git@type-fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CMAKE_CURRENT_SOURCE_DIR = ~/mlx/python/src in my environment, but mlx.core is generated in ~/mlx/build/lib.macosx-11.0-arm64-cpython-310/mlx/core.cpython-310-darwin.so

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the information! Can you check if #3009 fixes the error for you?

subprocess.run(
["cmake", "--install", build_temp, "--component", "core_stub"],
check=True,
env=env,
)


Expand Down
Loading