ENH: torch.asarray device propagation#296
Closed
crusaderky wants to merge 1 commit intodata-apis:mainfrom
Closed
Conversation
Contributor
There was a problem hiding this comment.
Pull Request Overview
This pull request implements a workaround for torch.asarray’s device propagation issue by adjusting imports and type annotations, along with minor device‐related tweaks across multiple array API compatibility modules.
- Update type hints for device parameters and return types (from str to Device)
- Replace the custom Device import with torch.device in torch/_typing.py
- Introduce a new asarray function in torch/_aliases.py to address dtype propagation
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| array_api_compat/torch/_typing.py | Updated import of Device from torch and modified the all order |
| array_api_compat/torch/_info.py | Adjusted docstrings to reflect Device type for default device methods |
| array_api_compat/torch/_aliases.py | Added a new asarray function and removed redundant torch.asarray calls |
| array_api_compat/numpy/_info.py | Changed device type annotations from str to Device |
| array_api_compat/dask/array/_info.py | Changed device type annotations from str to Device |
| array_api_compat/cupy/_info.py | Updated device type annotations and added Notes in docstring |
| array_api_compat/common/_aliases.py | Updated inline comments to include Dask in creation functions |
Comments suppressed due to low confidence (3)
array_api_compat/torch/_typing.py:3
- The change to import Device directly from torch (using torch.device) may affect the expected Device behavior compared to the previous custom definition. Please verify that this change preserves the intended device propagation semantics throughout the code.
from torch import device as Device, dtype as DType, Tensor as Array
array_api_compat/torch/_aliases.py:228
- In the new asarray function, only torch.Tensor inputs have their device set automatically. Consider verifying that non-tensor objects are handled correctly when device is None.
if device is None and isinstance(obj, torch.Tensor):
array_api_compat/torch/_aliases.py:308
- Removal of the explicit torch.asarray conversion in functions like prod, sum, any, and all might lead to unintended type issues if the input is not already a tensor. Confirm that callers always supply a tensor or that the new asarray function is applied upstream to avoid runtime errors.
- x = torch.asarray(x)
Contributor
Author
|
CI failures unrelated (#297) |
Member
|
close/reopen to rerun the CI |
Contributor
Author
|
I can't reopen a PR someone else closed. |
Member
|
Apologies for the glitch. No idea why it's not got reopened. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
asarray: device does not propagate from input to output afterset_default_devicepytorch/pytorch#150199prod,sum,anyandall, which had the same issue induced by array-api-compat