Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
# ARRAY_API_TESTS_SKIP_DTYPES
all_dtypes = sampled_from(_sorted_dtypes)
int_dtypes = sampled_from(dh.all_int_dtypes)
all_int_dtypes = sampled_from(dh.all_int_dtypes)
int_dtypes = sampled_from(dh.int_dtypes) # signed ints
uint_dtypes = sampled_from(dh.uint_dtypes)
real_dtypes = sampled_from(dh.real_dtypes)
# Warning: The hypothesis "floating_dtypes" is what we call
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def test_repeat(x, kw, data):
size = math.prod(shape) if axis is None else shape[axis]
repeat_strat = st.integers(1, 10)
repeats = data.draw(repeat_strat
| hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
| hh.arrays(dtype=hh.all_int_dtypes, elements=repeat_strat,
shape=st.sampled_from([(1,), (size,)])),
label="repeats")
if isinstance(repeats, int):
Expand Down
9 changes: 4 additions & 5 deletions meta_tests/test_array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import strategies as st

from array_api_tests import _array_module as xp
from array_api_tests.hypothesis_helpers import (int_dtypes, arrays,
two_mutually_broadcastable_shapes)
import array_api_tests.hypothesis_helpers as hh
from array_api_tests.shape_helpers import iter_indices, broadcast_shapes
from array_api_tests .array_helpers import exactly_equal, notequal, less

Expand All @@ -24,10 +23,10 @@ def test_notequal():
assert xp.all(xp.equal(notequal(a, b), res))


@given(two_mutually_broadcastable_shapes, int_dtypes, int_dtypes, st.data())
@given(hh.two_mutually_broadcastable_shapes, hh.all_int_dtypes, hh.all_int_dtypes, st.data())
def test_less(shapes, dtype1, dtype2, data):
x = data.draw(arrays(shape=shapes[0], dtype=dtype1))
y = data.draw(arrays(shape=shapes[1], dtype=dtype2))
x = data.draw(hh.arrays(shape=shapes[0], dtype=dtype1))
y = data.draw(hh.arrays(shape=shapes[1], dtype=dtype2))

res = less(x, y)

Expand Down