Skip to content

Commit 4b7dc2b

Browse files
committed
Up edges creation
1 parent f196765 commit 4b7dc2b

File tree

3 files changed

+242
-105
lines changed

3 files changed

+242
-105
lines changed

src/xarray_histogram/core.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from collections import abc
1313
from copy import copy
1414
from functools import partial, reduce
15-
from typing import cast
1615

1716
import boost_histogram as bh
1817
import numpy as np
@@ -44,9 +43,6 @@
4443
RangeType = tuple[float | None, float | None]
4544

4645

47-
# TODO add flow
48-
49-
5046
def histogram(
5147
x: xr.DataArray,
5248
/,
@@ -535,30 +531,51 @@ def get_coord(name: str, ax: bh.axis.Axis, dtype: np.dtype, flow: bool) -> xr.Da
535531
attrs = dict(bin_type=type(ax).__name__, underflow=underflow, overflow=overflow)
536532

537533
if isinstance(ax, bh.axis.Integer):
538-
lefts = ax.edges[:-1].astype(dtype)
534+
if dtype.kind not in "uib":
535+
raise TypeError(f"Cannot use Integer axis for dtype {dtype}")
536+
537+
lefts = ax.edges[:-1].astype("int")
538+
539+
# deal with bool variables
540+
if dtype.kind == "b" and not (underflow or overflow):
541+
lefts = lefts.astype("bool")
542+
539543
# use min/max possible encoded values to indicate flow
544+
bins_dtype = lefts.dtype
540545
if underflow:
541-
if dtype.kind == "u":
542-
dtype = np.dtype(f"i{min(dtype.itemsize * 2, 8)}")
543-
vmin = np.iinfo(dtype).min
544-
lefts = np.concatenate(([vmin], lefts), dtype=dtype)
546+
vmin = np.iinfo(bins_dtype).min
547+
lefts = np.concatenate(([vmin], lefts), dtype=bins_dtype)
545548
if overflow:
546-
vmax = np.iinfo(dtype).max
547-
lefts = np.concatenate((lefts, [vmax]), dtype=dtype)
549+
vmax = np.iinfo(bins_dtype).max
550+
lefts = np.concatenate((lefts, [vmax]), dtype=bins_dtype)
548551

549552
elif isinstance(ax, bh.axis.IntCategory):
550-
lefts = np.asarray([cast(int, ax.bin(i)) for i in range(ax.size)])
551-
lefts = lefts.astype(dtype)
553+
if dtype.kind not in "uib":
554+
raise TypeError(f"Cannot use Integer axis for dtype {dtype}")
555+
556+
lefts = np.asarray([ax.bin(i) for i in range(ax.size)], dtype="int")
557+
558+
# deal with bool variables
559+
if dtype.kind == "b" and not (underflow or overflow):
560+
lefts = lefts.astype("bool")
561+
562+
bins_dtype = lefts.dtype
552563
if overflow:
553-
lefts = np.concatenate((lefts, [np.iinfo(dtype).max]), dtype=dtype)
564+
lefts = np.concatenate(
565+
(lefts, [np.iinfo(bins_dtype).max]), dtype=bins_dtype
566+
)
554567

555568
elif isinstance(ax, bh.axis.StrCategory):
569+
if dtype.kind not in "SU":
570+
raise TypeError(f"Cannot use StrCategory axis for dtype {dtype}")
556571
lefts = np.asarray([ax.bin(i) for i in range(ax.size)])
557572
if overflow:
558573
lefts = np.concatenate((lefts, ["_flow_bin"]))
559574

560575
else:
561-
lefts = ax.edges[:-1].astype(dtype, casting="safe")
576+
if dtype.kind not in "biuf":
577+
raise TypeError(f"Cannot use {type(ax).__name__} axis for dtype {dtype}")
578+
lefts = ax.edges[:-1]
562579
attrs["right_edge"] = ax.edges[-1]
563580
if underflow:
564581
lefts = np.concatenate(([-np.inf], lefts))
@@ -575,13 +592,19 @@ def _bins_name(variable: str) -> str:
575592
def get_edges(coord: xr.DataArray) -> xr.DataArray:
576593
"""Return edges positions."""
577594
name = coord.name
578-
if coord.attrs["bin_type"] in ["Integer", "IntCategory", "StrCategory"]:
579-
return xr.DataArray(coord.values, dims=[name], name=name)
595+
bin_type = coord.attrs["bin_type"]
596+
if bin_type in ["IntCategory", "StrCategory"]:
597+
raise TypeError(f"Edges not available for {bin_type} bins type.")
580598

581-
# insert right_edge
599+
overflow = coord.attrs.get("overflow", False)
600+
601+
if bin_type == "Integer":
602+
right_edge = coord[-2 if overflow else -1] + 1
603+
else:
604+
right_edge = coord.attrs["right_edge"]
582605
values = coord.values
583-
insert = values.size - 1 if coord.attrs.get("overflow", False) else values.size
584-
values = np.insert(values, insert, [coord.attrs["right_edge"]])
606+
insert = values.size - 1 if overflow else values.size
607+
values = np.insert(values, insert, [right_edge])
585608

586609
return xr.DataArray(values, dims=[name], name=name)
587610

tests/test_accessor.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ def get_blank_histogram(n_var: int = 1) -> xr.DataArray:
2929
return h
3030

3131

32+
def get_hist(*axes: bh.axis.Axis, flow=True) -> xr.DataArray:
33+
data = []
34+
for ax in axes:
35+
x = get_array([2])
36+
if isinstance(ax, bh.axis.Integer | bh.axis.IntCategory):
37+
x = x.astype("int")
38+
if isinstance(ax, bh.axis.StrCategory):
39+
x = x.astype("U")
40+
data.append(x)
41+
return xh.histogramdd(*data, bins=axes, flow=flow)
42+
43+
3244
class TestAccessibility:
3345
@pytest.mark.parametrize("x", [get_array([20]), get_array([20, 5])], ids=id_x)
3446
@bool_param("density")
@@ -65,8 +77,42 @@ def test_variable_argument() -> None:
6577

6678

6779
class TestEdges:
80+
def test_edges(self) -> None:
81+
# Regular
82+
h = get_hist(bh.axis.Regular(3, 0.0, 0.3), flow=False)
83+
assert_allclose(h.hist.edges(), [0.0, 0.1, 0.2, 0.3])
84+
h = get_hist(bh.axis.Regular(3, 0.0, 0.3, underflow=False))
85+
assert_allclose(h.hist.edges(), [0.0, 0.1, 0.2, 0.3, np.inf])
86+
h = get_hist(bh.axis.Regular(3, 0.0, 0.3))
87+
assert_allclose(h.hist.edges(), [-np.inf, 0.0, 0.1, 0.2, 0.3, np.inf])
88+
89+
# Integer
90+
h = get_hist(bh.axis.Integer(0, 3), flow=False)
91+
assert_allclose(h.hist.edges(), [0, 1, 2, 3])
92+
h = get_hist(bh.axis.Integer(0, 3, underflow=False))
93+
dtype = h.var1_bins.dtype
94+
vmin = np.iinfo(dtype).min
95+
vmax = np.iinfo(dtype).max
96+
assert_allclose(h.hist.edges(), [0, 1, 2, 3, vmax])
97+
h = get_hist(bh.axis.Integer(0, 3))
98+
assert_allclose(h.hist.edges(), [vmin, 0, 1, 2, 3, vmax])
99+
100+
# Variable
101+
h = get_hist(bh.axis.Variable([0, 1, 3, 10]), flow=False)
102+
assert_allclose(h.hist.edges(), [0, 1, 3, 10])
103+
h = get_hist(bh.axis.Variable([0, 1, 3, 10], underflow=False))
104+
assert_allclose(h.hist.edges(), [0, 1, 3, 10, np.inf])
105+
h = get_hist(bh.axis.Variable([0, 1, 3, 10]))
106+
assert_allclose(h.hist.edges(), [-np.inf, 0, 1, 3, 10, np.inf])
107+
108+
# Not supported
109+
for ax in [bh.axis.IntCategory([0, 1, 2]), bh.axis.StrCategory(["a", "b"])]:
110+
h = get_hist(ax)
111+
with pytest.raises(TypeError):
112+
h.hist.edges()
113+
68114
def test_infer_right_edge(self) -> None:
69-
h = get_blank_histogram()
115+
h = get_hist(bh.axis.Regular(10, 0.0, 1.0), flow=False)
70116
# this reset right edge
71117
h = h.assign_coords(var1_bins=np.arange(0, 10))
72118

@@ -78,10 +124,6 @@ def test_infer_right_edge(self) -> None:
78124
with pytest.raises(ValueError):
79125
_ = h_wrong.hist
80126

81-
def test_basic(self) -> None:
82-
h = get_blank_histogram()
83-
assert_allclose(h.hist.edges("var1"), np.arange(0, 11))
84-
85127
def test_centers(self):
86128
h = get_blank_histogram()
87129
assert_allclose(h.hist.centers(), np.arange(0, 10) + 0.5)

0 commit comments

Comments
 (0)