diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 26a348d1..0d603d68 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -180,19 +180,15 @@ def test_nonzero(x): for idx in sh.ndindex(x.shape): if x[idx] != 0: indices.append(idx) - if x.ndim == 0: - assert out_size == len( - indices - ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" - else: - for i in range(out_size): - idx = tuple(int(x[i]) for x in out) - f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" - f_element = f"x[{idx}]={x[idx]}" - assert idx in indices, f"{f_idx} results in {f_element}, a zero element" - assert ( - idx == indices[i] - ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" + + for i in range(out_size): + idx = tuple(int(x[i]) for x in out) + f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" + f_element = f"x[{idx}]={x[idx]}" + assert idx in indices, f"{f_idx} results in {f_element}, a zero element" + assert ( + idx == indices[i] + ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" except Exception as exc: ph.add_note(exc, repro_snippet) raise