FIX: Wrap torch.argsort to set stable=True by default#356
FIX: Wrap torch.argsort to set stable=True by default#356ev-br merged 4 commits intodata-apis:mainfrom
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR wraps torch.argsort to set stable=True by default, aligning it with the array API specification and matching the behavior of the existing sort wrapper.
- Adds a new
argsortfunction wrapper that defaultsstableparameter toTrue
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Remove the empty line with trailing whitespace inside the function body. This line serves no purpose and should be deleted. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
|
In pytorch, both Which looks correct and wanted indeed. |
|
Need to also add $ git diff
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 715182a..fc1688a 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -851,7 +851,8 @@ __all__ = ['asarray', 'result_type', 'can_cast',
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
- 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
+ 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod',
+ 'argsort', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',To verify: run data-apis/array-api-tests#390 with |
|
Also cross-ref data-apis/array-api-tests#390 (comment). |
Thanks for the tip 👍 |
betatim
left a comment
There was a problem hiding this comment.
Let's merge this.
Not sure if we need to wait for the -tests PR to be merged first.
|
Okay, I'll send a quick follow-up PR with $ git diff
diff --git a/tests/test_torch.py b/tests/test_torch.py
index 7adb4ab..a367c7b 100644
--- a/tests/test_torch.py
+++ b/tests/test_torch.py
@@ -117,3 +117,14 @@ def test_meshgrid():
assert Y.shape == Y_xy.shape
assert xp.all(Y == Y_xy)
+
+
+def test_argsort_stable():
+ """Verify that argsort defaults to a stable sort."""
+ # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper
+ # enforces the stable=True default.
+ # cf https://github.com/data-apis/array-api-compat/pull/356 and
+ # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329
+
+ t = xp.zeros(50) # should be >16
+ assert xp.all(xp.argsort(t) == xp.arange(50)) |
cross-ref data-apis#356 which wrapped torch.argsort to fix the default, and data-apis/array-api-tests#390 which made a matching change in the array-api-test suite.
cross-ref data-apis#356 which wrapped torch.argsort to fix the default, and data-apis/array-api-tests#390 which made a matching change in the array-api-test suite.
|
A follow-up in #358 |
Everything is in the title.
Fixes #354