1212from collections import abc
1313from copy import copy
1414from functools import partial , reduce
15- from typing import cast
1615
1716import boost_histogram as bh
1817import numpy as np
4443RangeType = tuple [float | None , float | None ]
4544
4645
47- # TODO add flow
48-
49-
5046def histogram (
5147 x : xr .DataArray ,
5248 / ,
@@ -535,30 +531,51 @@ def get_coord(name: str, ax: bh.axis.Axis, dtype: np.dtype, flow: bool) -> xr.Da
535531 attrs = dict (bin_type = type (ax ).__name__ , underflow = underflow , overflow = overflow )
536532
537533 if isinstance (ax , bh .axis .Integer ):
538- lefts = ax .edges [:- 1 ].astype (dtype )
534+ if dtype .kind not in "uib" :
535+ raise TypeError (f"Cannot use Integer axis for dtype { dtype } " )
536+
537+ lefts = ax .edges [:- 1 ].astype ("int" )
538+
539+ # deal with bool variables
540+ if dtype .kind == "b" and not (underflow or overflow ):
541+ lefts = lefts .astype ("bool" )
542+
539543 # use min/max possible encoded values to indicate flow
544+ bins_dtype = lefts .dtype
540545 if underflow :
541- if dtype .kind == "u" :
542- dtype = np .dtype (f"i{ min (dtype .itemsize * 2 , 8 )} " )
543- vmin = np .iinfo (dtype ).min
544- lefts = np .concatenate (([vmin ], lefts ), dtype = dtype )
546+ vmin = np .iinfo (bins_dtype ).min
547+ lefts = np .concatenate (([vmin ], lefts ), dtype = bins_dtype )
545548 if overflow :
546- vmax = np .iinfo (dtype ).max
547- lefts = np .concatenate ((lefts , [vmax ]), dtype = dtype )
549+ vmax = np .iinfo (bins_dtype ).max
550+ lefts = np .concatenate ((lefts , [vmax ]), dtype = bins_dtype )
548551
549552 elif isinstance (ax , bh .axis .IntCategory ):
550- lefts = np .asarray ([cast (int , ax .bin (i )) for i in range (ax .size )])
551- lefts = lefts .astype (dtype )
553+ if dtype .kind not in "uib" :
554+ raise TypeError (f"Cannot use Integer axis for dtype { dtype } " )
555+
556+ lefts = np .asarray ([ax .bin (i ) for i in range (ax .size )], dtype = "int" )
557+
558+ # deal with bool variables
559+ if dtype .kind == "b" and not (underflow or overflow ):
560+ lefts = lefts .astype ("bool" )
561+
562+ bins_dtype = lefts .dtype
552563 if overflow :
553- lefts = np .concatenate ((lefts , [np .iinfo (dtype ).max ]), dtype = dtype )
564+ lefts = np .concatenate (
565+ (lefts , [np .iinfo (bins_dtype ).max ]), dtype = bins_dtype
566+ )
554567
555568 elif isinstance (ax , bh .axis .StrCategory ):
569+ if dtype .kind not in "SU" :
570+ raise TypeError (f"Cannot use StrCategory axis for dtype { dtype } " )
556571 lefts = np .asarray ([ax .bin (i ) for i in range (ax .size )])
557572 if overflow :
558573 lefts = np .concatenate ((lefts , ["_flow_bin" ]))
559574
560575 else :
561- lefts = ax .edges [:- 1 ].astype (dtype , casting = "safe" )
576+ if dtype .kind not in "biuf" :
577+ raise TypeError (f"Cannot use { type (ax ).__name__ } axis for dtype { dtype } " )
578+ lefts = ax .edges [:- 1 ]
562579 attrs ["right_edge" ] = ax .edges [- 1 ]
563580 if underflow :
564581 lefts = np .concatenate (([- np .inf ], lefts ))
@@ -575,13 +592,19 @@ def _bins_name(variable: str) -> str:
def get_edges(coord: xr.DataArray) -> xr.DataArray:
    """Return the bin edge positions described by a bins coordinate.

    The coordinate values are taken as-is and a single right edge is
    inserted after the last regular bin (i.e. before the trailing overflow
    entry when the ``overflow`` attribute is set).

    Parameters
    ----------
    coord
        Bins coordinate. Its ``attrs`` must contain ``"bin_type"``; for any
        bin type other than ``"Integer"`` it must also contain
        ``"right_edge"``. The ``"overflow"`` attribute is optional and
        defaults to False.

    Returns
    -------
    xr.DataArray
        One-dimensional array named after ``coord`` with one more element
        than ``coord``.

    Raises
    ------
    TypeError
        If the bin type is ``"IntCategory"`` or ``"StrCategory"``: categories
        have no edges.
    KeyError
        If ``"bin_type"`` (or ``"right_edge"`` when it is needed) is missing
        from ``coord.attrs``.
    """
    name = coord.name
    bin_type = coord.attrs["bin_type"]
    if bin_type in ["IntCategory", "StrCategory"]:
        raise TypeError(f"Edges not available for {bin_type} bins type.")

    overflow = coord.attrs.get("overflow", False)
    values = coord.values

    if bin_type == "Integer":
        # Integer bins have unit width: the right edge is the last regular
        # entry plus one. With overflow the final entry is the flow entry,
        # so the last regular one sits just before it.
        # Work on the raw NumPy values rather than on a 0-d DataArray so the
        # inserted element is a plain scalar.
        # NOTE(review): for boolean coordinates the inserted value is cast
        # back to bool by np.insert -- confirm this is the intended result.
        right_edge = values[-2 if overflow else -1] + 1
    else:
        right_edge = coord.attrs["right_edge"]

    # Keep the overflow entry (if any) as the very last element.
    insert = values.size - 1 if overflow else values.size
    values = np.insert(values, insert, [right_edge])

    return xr.DataArray(values, dims=[name], name=name)
587610
0 commit comments