diff --git a/main/como/approx.py b/main/como/approx.py index 3ef69a74..441a489f 100644 --- a/main/como/approx.py +++ b/main/como/approx.py @@ -1,30 +1,64 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from typing import Any +from typing import Literal, NamedTuple import numpy as np +import numpy.typing as npt -def _coerce_to_float_array(a): - """Helper to ensure input is a 1D float array.""" +class RegularizedArray(NamedTuple): + x: npt.NDArray[np.number] + y: npt.NDArray[np.number] + not_na: npt.NDArray[bool] + kept_na: bool + + +class Approx(NamedTuple): + x: npt.NDArray[float] + y: npt.NDArray[float] + + +def _coerce_to_float_array(a: Sequence[int | float | np.number]): + """Coerce input to a 1D float array. + + Args: + a: the array to coerce + + Returns: + A floating point 1D array. + """ arr = np.asarray(a, dtype=float) if arr.ndim != 1: arr = arr.ravel() return arr -def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[str, Any]: - """Minimal reimplementation of R's regularize.values() used by approx(). +def _regularize_values( + x: npt.NDArray[np.number], + y: npt.NDArray[np.number], + *, + na_rm: bool, + ties: Callable[[np.ndarray], float] | Literal["mean", "first", "last", "min", "max", "median", "sum"] = "mean", +) -> RegularizedArray: + """Reimplementation of R's regularize.values() used by approx(). - Removes NA pairs if na_rm is True (else requires x to have no NA). - Sorts by x (stable). - Collapses duplicate x via `ties` aggregator. - Returns dict with: - x: sorted unique x - y: aggregated y aligned with x - not_na: boolean mask of y that are not NaN - kept_na: True iff any NaN remained in y after regularization + + Args: + x: the x values to regularize + y: the corresponding y values + na_rm: should NA values be removed? 
+ ties: how to handle duplicate x values; can be a string or a callable + + Returns: + A NamedTuple with: + - x: sorted unique x + - y: aggregated y aligned with x + - not_na: boolean mask of y that are not NaN + - kept_na: True iff any NaN remained in y after regularization """ if na_rm: # Remove pairs where x or y is NA @@ -39,8 +73,7 @@ def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[ kept_na = np.isnan(y).any() if x.size == 0: - # THIS IS THE CORRECTED LINE: - return {"x": x, "y": y, "not_na": np.array([], dtype=bool), "kept_na": kept_na} + return RegularizedArray(x=x, y=y, not_na=np.array([], dtype=bool), kept_na=kept_na) # Use a stable sort (mergesort) to match R's order() order = np.argsort(x, kind="mergesort") @@ -51,27 +84,26 @@ def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[ if callable(ties): agg_fn = ties else: - ties_str = "mean" if ties is None else str(ties).lower() - if ties_str in ("mean", "avg", "average"): + # fmt: off + if ties in ("mean", "avg", "average"): agg_fn = np.mean - elif ties_str in ("first", "left"): - + elif ties in ("first", "left"): def agg_fn(a): return a[0] - elif ties_str in ("last", "right"): - + elif ties in ("last", "right"): def agg_fn(a): return a[-1] - elif ties_str == "min": + elif ties == "min": agg_fn = np.min - elif ties_str == "max": + elif ties == "max": agg_fn = np.max - elif ties_str == "median": + elif ties == "median": agg_fn = np.median - elif ties_str == "sum": + elif ties == "sum": agg_fn = np.sum else: raise ValueError("Unsupported `ties`; use a callable or one of 'mean', 'first', 'last', 'min', 'max', 'median', 'sum'.") + # fmt: on # Find unique x values and their indices/counts unique_x, start_idx, counts = np.unique(x_sorted, return_index=True, return_counts=True) @@ -86,8 +118,8 @@ def agg_fn(a): not_na = ~np.isnan(y_agg) # Check if any NaNs remain in the *aggregated* y values - kept_na_final = np.any(~not_na) if np.any(np.isnan(y_agg)) 
else False - return {"x": unique_x, "y": y_agg, "not_na": not_na, "kept_na": kept_na_final} + kept_na_final = bool(np.any(~not_na) if np.any(np.isnan(y_agg)) else False) + return RegularizedArray(x=unique_x, y=y_agg, not_na=not_na, kept_na=kept_na_final) def approx( @@ -98,11 +130,11 @@ def approx( n: int = 50, yleft: float | None = None, yright: float | None = None, - rule: int | Sequence[int] = 1, + rule: int | Sequence[int] | npt.NDArray[int] = 1, f: float = 0.0, - ties: str | Callable[[np.ndarray], float] = "mean", + ties: Callable[[np.ndarray], float] | Literal["mean", "first", "last", "min", "max", "median", "sum"] = "mean", na_rm: bool = True, -) -> dict[str, np.ndarray]: +) -> Approx: """Faithful Python port of R's `approx` function. This implementation aims to replicate the behavior of R's `approx` @@ -114,26 +146,18 @@ def approx( y: y-coordinates. If None, `x` is treated as `y` and `x` becomes `1..n`. xout: Points at which to interpolate. method: Interpolation method. "linear" (1) or "constant" (2). - n: If `xout` is not provided, interpolation happens at `n` - equally spaced points spanning the range of `x`. - yleft: Value to use for extrapolation to the left. - Defaults to `NA` (np.nan) if `rule` is 1. - yright: Value to use for extrapolation to the right. - Defaults to `NA` (np.nan) if `rule` is 1. + n: If `xout` is not provided, interpolation happens at `n` equally spaced points spanning the range of `x`. + yleft: Value to use for extrapolation to the left. Defaults to `NA` (np.nan) if `rule` is 1. + yright: Value to use for extrapolation to the right. Defaults to `NA` (np.nan) if `rule` is 1. rule: Extrapolation rule. - 1: Return `np.nan` for points outside the `x` range. - 2: Use `yleft` and `yright` for points outside the range. - f: For `method="constant"`, determines the split point. - `f=0` is left-step, `f=1` is right-step, `f=0.5` is midpoint. - ties: How to handle duplicate `x` values. 
- Can be 'mean', 'first', 'last', 'min', 'max', 'median', 'sum', - or a callable function. - na_rm: If True, `NA` pairs are removed before interpolation. - If False, `NA`s in `x` cause an error, `NA`s in `y` - are propagated. + f: For `method="constant"`, determines the split point. `f=0` is left-step, `f=1` is right-step, `f=0.5` is midpoint. + ties: How to handle duplicate `x` values. Can be 'mean', 'first', 'last', 'min', 'max', 'median', 'sum', or a callable function. + na_rm: If True, `NA` pairs are removed before interpolation. If False, `NA`s in `x` cause an error, `NA`s in `y` are propagated. Returns: - dict with: + `Approx` class with: - 'x': numpy array of xout used - 'y': numpy array of interpolated values """ @@ -164,25 +188,24 @@ def approx( raise ValueError("invalid interpolation method") # --- Rule normalization --- - if isinstance(rule, list | tuple | np.ndarray): - rlist = list(rule) - if not (1 <= len(rlist) <= 2): + if isinstance(rule, Sequence | np.ndarray): + rule_list: list[int] = list(rule) # type: ignore[invalid-argument-type] # This is a valid argument type, not sure why ty is upset + if not (1 <= len(rule_list) <= 2): raise ValueError("`rule` must have length 1 or 2") - if len(rlist) == 1: - rlist = [rlist[0], rlist[0]] + if len(rule_list) == 1: + rule_list = [rule_list[0], rule_list[0]] else: - rlist = [rule, rule] + rule_list = [rule, rule] - # --- Regularize values --- # Sort by x, collapse ties, and handle NAs like R's regularize.values() - r = _regularize_values(x_arr, y_arr, ties, na_rm) - x_reg = r["x"] - y_reg = r["y"] - not_na_mask = r["not_na"] + r: RegularizedArray = _regularize_values(x_arr, y_arr, na_rm=na_rm, ties=ties) + x_reg: npt.NDArray[float] = r.x + y_reg: npt.NDArray[float] = r.y + not_na_mask: npt.NDArray[bool] = r.not_na # no_na is True if we don't have to worry about NAs in y_reg - no_na = na_rm or (not r["kept_na"]) + no_na: bool = na_rm or (not r.kept_na) # nx is the number of *valid* (non-NA) points for 
interpolation - nx = x_reg.size if no_na else int(np.count_nonzero(not_na_mask)) + nx: int = x_reg.size if no_na else int(np.count_nonzero(not_na_mask)) if np.isnan(nx): raise ValueError("invalid length(x)") @@ -194,98 +217,98 @@ def approx( if nx == 0: raise ValueError("zero non-NA points") - # --- Set extrapolation values (yleft/yright) --- + # set extrapolation values (yleft/yright) # This logic matches R's C code (R_approxtest) if no_na: - y_first = y_reg[0] - y_last = y_reg[-1] + y_first: float = y_reg[0] + y_last: float = y_reg[-1] else: # Find first and last non-NA y values - y_valid = y_reg[not_na_mask] - y_first = y_valid[0] - y_last = y_valid[-1] + y_valid: npt.NDArray[float] = y_reg[not_na_mask] + y_first: float = y_valid[0] + y_last: float = y_valid[-1] - yleft_val = (np.nan if int(rlist[0]) == 1 else y_first) if yleft is None else float(yleft) - yright_val = (np.nan if int(rlist[1]) == 1 else y_last) if yright is None else float(yright) + yleft_val: float = (np.nan if int(rule_list[0]) == 1 else y_first) if yleft is None else float(yleft) + yright_val: float = (np.nan if int(rule_list[1]) == 1 else y_last) if yright is None else float(yright) - # --- Fabricate xout if missing --- + # fabricate xout if it is missing if xout is None: if n <= 0: raise ValueError("'approx' requires n >= 1") if no_na: - x_first = x_reg[0] - x_last = x_reg[nx - 1] + x_first: float = x_reg[0] + x_last: float = x_reg[nx - 1] else: - x_valid = x_reg[not_na_mask] - x_first = x_valid[0] - x_last = x_valid[-1] - xout_arr = np.linspace(x_first, x_last, num=int(n), dtype=float) + x_valid: npt.NDArray[float] = x_reg[not_na_mask] + x_first: float = x_valid[0] + x_last: float = x_valid[-1] + xout_arr: npt.NDArray[float] = np.linspace(x_first, x_last, num=int(n), dtype=float) else: - xout_arr = _coerce_to_float_array(xout) + xout_arr: npt.NDArray[float] = _coerce_to_float_array(xout) - # --- C_ApproxTest checks --- + # replicate R's C_ApproxTest checks if method_code == 2 and (not 
np.isfinite(f) or f < 0.0 or f > 1.0): - raise ValueError("approx(): invalid f value") + raise ValueError("approx(): invalid f value; with `method=2`, `f` must be in the range [0,1] (inclusive) or NA") if not no_na: # R re-filters x and y here if NAs remained - x_reg = x_reg[not_na_mask] - y_reg = y_reg[not_na_mask] + x_reg: npt.NDArray[float] = x_reg[not_na_mask] + y_reg: npt.NDArray[float] = y_reg[not_na_mask] - # --- Interpolation --- - # This section vectorized the logic from R's approx1 and R_approxfun - yout = np.empty_like(xout_arr, dtype=float) - mask_nan_xout = np.isnan(xout_arr) + # perform interpolation + # this section is a vectorized form of the logic from R's approx1 and R_approxfun + yout: npt.NDArray[float] = np.empty_like(xout_arr, dtype=float) + mask_nan_xout: npt.NDArray[bool] = np.isnan(xout_arr) yout[mask_nan_xout] = np.nan - mask_valid = ~mask_nan_xout + mask_valid: npt.NDArray[bool] = ~mask_nan_xout if x_reg.size == 0: # No valid points to interpolate against yout[mask_valid] = np.nan - return {"x": xout_arr, "y": yout} + return Approx(x=xout_arr, y=yout) - xv = xout_arr[mask_valid] - left_mask = xv < x_reg[0] - right_mask = xv > x_reg[-1] - mid_mask = ~(left_mask | right_mask) + xv: npt.NDArray[float] = xout_arr[mask_valid] + left_mask: npt.NDArray[bool] = xv < x_reg[0] + right_mask: npt.NDArray[bool] = xv > x_reg[-1] + mid_mask: npt.NDArray[bool] = ~(left_mask | right_mask) - res = np.empty_like(xv) + res: npt.NDArray[float] = np.empty_like(xv, dtype=float) res[left_mask] = yleft_val res[right_mask] = yright_val if np.any(mid_mask): - xm = xv[mid_mask] + xm: npt.NDArray[float] = xv[mid_mask] # Find indices # j_right[k] = index of first x_reg > xm[k] - j_right = np.searchsorted(x_reg, xm, side="right") + j_right: npt.NDArray[int] = np.searchsorted(x_reg, xm, side="right") # j_left[k] = index of first x_reg >= xm[k] - j_left = np.searchsorted(x_reg, xm, side="left") + j_left: npt.NDArray[int] = np.searchsorted(x_reg, xm, side="left") # Points 
that exactly match an x_reg value - eq_mask = j_left != j_right + eq_mask: npt.NDArray[bool] = j_left != j_right # Points that fall between x_reg values - in_interval_mask = ~eq_mask + in_interval_mask: npt.NDArray[bool] = ~eq_mask - res_mid = np.empty_like(xm) + res_mid: npt.NDArray[float] = np.empty_like(xm, dtype=float) if np.any(eq_mask): # For exact matches, use the corresponding y_reg value # R's C code uses the 'j_left' index here - res_mid[eq_mask] = y_reg[j_left[eq_mask]] + res_mid[eq_mask] = y_reg[j_left[eq_mask]] # type: ignore[non-subscriptable] # we know this is of type npt.NDArray[float], not sure why ty is upset if np.any(in_interval_mask): # R's C code sets i = j-1, where j is the 'right' index - j = j_right[in_interval_mask] - i = j - 1 + j: npt.NDArray[float] = j_right[in_interval_mask] # type: ignore[non-subscriptable] # we know this is of type npt.NDArray[float], not sure why ty is upset + i: npt.NDArray[float] = j - 1 # Get the surrounding x and y values - xi = x_reg[i] - xj = x_reg[j] - yi = y_reg[i] - yj = y_reg[j] + xi: npt.NDArray[float] = x_reg[i] + xj: npt.NDArray[float] = x_reg[j] + yi: npt.NDArray[float] = y_reg[i] + yj: npt.NDArray[float] = y_reg[j] if method_code == 1: # linear - t = (xm[in_interval_mask] - xi) / (xj - xi) + t: npt.NDArray[float] = (xm[in_interval_mask] - xi) / (xj - xi) res_mid[in_interval_mask] = yi + (yj - yi) * t else: # constant # This matches R_approxfun's constant logic @@ -294,17 +317,17 @@ def approx( elif f == 1.0: res_mid[in_interval_mask] = yj else: - # R computes (1-f)*yi + f*yj, but carefully - f1 = 1.0 - f - f2 = f - part = np.zeros_like(yi) + # computes R's (1-f)*yi + f*yj + f1: float = float(1.0 - f) + f2: float = float(f) + part: npt.NDArray[float] = np.zeros_like(yi, dtype=float) if f1 != 0.0: - part = part + yi * f1 + part: npt.NDArray[float] = part + yi * f1 if f2 != 0.0: - part = part + yj * f2 + part: npt.NDArray[float] = part + yj * f2 res_mid[in_interval_mask] = part res[mid_mask] = res_mid 
yout[mask_valid] = res - return {"x": xout_arr, "y": yout} + return Approx(x=xout_arr, y=yout) diff --git a/main/como/density.py b/main/como/density.py index 0f75623a..810ced26 100644 --- a/main/como/density.py +++ b/main/como/density.py @@ -6,9 +6,8 @@ import numpy as np import numpy.typing as npt -from scipy.interpolate import interp1d -from como.approx import approx +from como.approx import Approx, approx class DensityResult(NamedTuple): @@ -116,7 +115,7 @@ def dnorm(x: float, mean: NUMBER = 0.0, sd: NUMBER = 1.0, log: bool = False, fas # underflow boundary boundary = np.sqrt(-2.0 * m_ln2 * (dbl_min_exp + 1 - dbl_mant_dig)) if a > boundary: - return float(0.0) + return 0.0 # Now, to get full accuracy, split x into two parts, # x = x1+x2, such that |x2| <= 2^-16. @@ -228,12 +227,15 @@ def density( elif kernel == "optcosine": return np.sqrt(1 - 8 / np.pi**2) * np.pi**2 / 16 + if kernel != "gaussian": + raise NotImplementedError(f"Only 'gaussian' kernel is implemented; got '{kernel}'") + x: npt.NDArray[float] = np.asarray(x, dtype=float) has_weights = weights is not None weights: npt.NDArray[float] | None = np.asarray(weights, float) if weights is not None else None - if has_weights and (weights is not None and weights.size != n): - raise ValueError(f"The length of provided weights does not match the length of x: {weights.size} != {n}") + if has_weights and (weights is not None and weights.size != x.size): + raise ValueError(f"The length of provided weights does not match the length of x: {weights.size} != {x.size}") x_na: npt.NDArray[np.bool_] = np.isnan(x) if np.any(x_na): @@ -286,8 +288,9 @@ def density( if bw_calc <= 0: raise ValueError("Bandwidth 'bw' must be positive.") - from_ = float(from_ or min(x) - cut * bw_calc) - to_ = float(to_ or max(x) + cut * bw_calc) + # have to use `... if ... 
else` because `0` is falsey, resulting in the right-half being used instead of the user-provided value + from_ = float(from_ if from_ is not None else min(x) - cut * bw_calc) + to_ = float(to_ if to_ is not None else max(x) + cut * bw_calc) if not np.isfinite(from_): raise ValueError("'from_' is not finite.") @@ -313,11 +316,11 @@ def density( # xp=known x-coords, fp=known y-cords, x=unknown x-coords; returns interpolated (e.g., unknown) y-coords interp_x: npt.NDArray[float] = np.linspace(from_, to_, num=n_user) - interp_y: npt.NDArray[float] = approx(xords, kords, interp_x) + interp_y: Approx = approx(xords, kords, interp_x) return DensityResult( x=interp_x, - y=interp_y["y"], + y=interp_y.y, x_grid=xords, y_grid=kords, bw=bw_calc, diff --git a/tests/unit/test_approx.py b/tests/unit/test_approx.py new file mode 100644 index 00000000..772dbc77 --- /dev/null +++ b/tests/unit/test_approx.py @@ -0,0 +1,237 @@ +import numpy as np +import pytest + +from como.approx import _coerce_to_float_array, _regularize_values, approx + + +class TestCoerceToFloatArray: + """Tests for the _coerce_to_float_array helper function.""" + + def test_converts_int_list_to_array(self): + result = _coerce_to_float_array([1, 2, 3]) + assert isinstance(result, np.ndarray) + assert result.dtype == float + np.testing.assert_array_equal(result, [1.0, 2.0, 3.0]) + + def test_converts_float_list_to_array(self): + result = _coerce_to_float_array([1.0, 2.0, 3.0]) + assert isinstance(result, np.ndarray) + assert result.dtype == float + np.testing.assert_array_equal(result, [1.0, 2.0, 3.0]) + + +class TestRegularizeValues: + """Tests for the _regularize_values function.""" + + def test_removes_na_pairs_when_na_rm_true(self): + x: list[float] = np.array([1.0, 2.0, np.nan, 4.0]) + y: list[float] = np.array([10.0, np.nan, 30.0, 40.0]) + result = _regularize_values(x, y, na_rm=True, ties="mean") + np.testing.assert_array_equal(result.x, [1.0, 4.0]) + np.testing.assert_array_equal(result.y, [10.0, 40.0]) + + 
def test_raises_error_with_na_in_x_when_na_rm_false(self): + x: list[float] = np.array([1.0, np.nan, 3.0]) + y: list[float] = np.array([10.0, 20.0, 30.0]) + with pytest.raises(ValueError, match="NA values in x are not allowed"): + _regularize_values(x, y, na_rm=False, ties="mean") + + def test_sorts_by_x(self): + x: list[float] = np.array([3.0, 1.0, 2.0]) + y: list[float] = np.array([30.0, 10.0, 20.0]) + result = _regularize_values(x, y, na_rm=True, ties="mean") + np.testing.assert_array_equal(result.x, [1.0, 2.0, 3.0]) + np.testing.assert_array_equal(result.y, [10.0, 20.0, 30.0]) + + def test_aggregates_duplicates_with_mean(self): + x: list[float] = np.array([1.0, 1.0, 2.0]) + y: list[float] = np.array([10.0, 20.0, 30.0]) + result = _regularize_values(x, y, na_rm=True, ties="mean") + np.testing.assert_array_equal(result.x, [1.0, 2.0]) + np.testing.assert_array_equal(result.y, [15.0, 30.0]) + + def test_aggregates_duplicates_with_first(self): + x: list[float] = np.array([1.0, 1.0, 2.0]) + y: list[float] = np.array([10.0, 20.0, 30.0]) + result = _regularize_values(x, y, na_rm=True, ties="first") + np.testing.assert_array_equal(result.x, [1.0, 2.0]) + np.testing.assert_array_equal(result.y, [10.0, 30.0]) + + def test_aggregates_duplicates_with_last(self): + x: list[float] = np.array([1.0, 1.0, 2.0]) + y: list[float] = np.array([10.0, 20.0, 30.0]) + result = _regularize_values(x, y, na_rm=True, ties="last") + np.testing.assert_array_equal(result.x, [1.0, 2.0]) + np.testing.assert_array_equal(result.y, [20.0, 30.0]) + + def test_handles_empty_arrays(self): + x: list[float] = np.array([]) + y: list[float] = np.array([]) + result = _regularize_values(x, y, na_rm=True, ties="mean") + assert result.x.size == 0 + assert result.y.size == 0 + assert result.not_na.size == 0 + + def test_callable_ties_function(self): + x: list[float] = np.array([1.0, 1.0, 2.0]) + y: list[float] = np.array([10.0, 20.0, 30.0]) + result = _regularize_values(x, y, na_rm=True, ties=np.sum) + 
np.testing.assert_array_equal(result.x, [1.0, 2.0]) + np.testing.assert_array_equal(result.y, [30.0, 30.0]) + + +class TestApprox: + """Tests for the main approx function.""" + + def test_basic_linear_interpolation(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [1.5, 2.5] + result = approx(x, y, xout=xout) + np.testing.assert_array_almost_equal(result.y, [15.0, 25.0]) + + def test_one_argument_form(self): + y: list[float] = [10.0, 20.0, 30.0] + result = approx(y, xout=[1.5, 2.5]) + np.testing.assert_array_almost_equal(result.y, [15.0, 25.0]) + + def test_default_n_points(self): + x: list[float] = [1.0, 5.0] + y: list[float] = [10.0, 50.0] + result = approx(x, y) + assert len(result.x) == 50 + assert len(result.y) == 50 + + def test_custom_n_points(self): + x: list[float] = [1.0, 5.0] + y: list[float] = [10.0, 50.0] + result = approx(x, y, n=10) + assert len(result.x) == 10 + + def test_extrapolation_rule_1(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [0.5, 3.5] + result = approx(x, y, xout=xout, rule=1) + assert np.isnan(result.y[0]) + assert np.isnan(result.y[1]) + + def test_extrapolation_rule_2(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [0.5, 3.5] + result = approx(x, y, xout=xout, rule=2) + assert result.y[0] == 10.0 + assert result.y[1] == 30.0 + + def test_yleft_yright_parameters(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [0.5, 3.5] + result = approx(x, y, xout=xout, yleft=5.0, yright=35.0) + assert result.y[0] == 5.0 + assert result.y[1] == 35.0 + + def test_constant_method_f_0(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [1.5] + result = approx(x, y, xout=xout, method="constant", f=0.0) + assert result.y[0] == 10.0 + + def test_constant_method_f_1(self): + x: list[float] = 
[1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [1.5] + result = approx(x, y, xout=xout, method="constant", f=1.0) + assert result.y[0] == 20.0 + + def test_constant_method_f_05(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [1.5] + result = approx(x, y, xout=xout, method="constant", f=0.5) + assert result.y[0] == 15.0 + + def test_exact_match_returns_exact_value(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [2.0] + result = approx(x, y, xout=xout) + assert result.y[0] == 20.0 + + def test_na_in_xout_returns_nan(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [np.nan, 2.0] + result = approx(x, y, xout=xout) + assert np.isnan(result.y[0]) + assert result.y[1] == 20.0 + + def test_method_numeric_codes(self): + x: list[float] = [1.0, 2.0] + y: list[float] = [10.0, 20.0] + xout: list[float] = [1.5] + result1 = approx(x, y, xout=xout, method=1) + result2 = approx(x, y, xout=xout, method=2, f=0.0) + assert result1.y[0] == 15.0 + assert result2.y[0] == 10.0 + + def test_raises_error_different_length_xy(self): + x: list[float] = [1.0, 2.0] + y: list[float] = [10.0] + with pytest.raises(ValueError, match="x and y must have same length"): + approx(x, y) + + def test_raises_error_invalid_method(self): + x: list[float] = [1.0, 2.0] + y: list[float] = [10.0, 20.0] + with pytest.raises(ValueError, match="invalid interpolation method"): + approx(x, y, method="invalid") + + def test_raises_error_invalid_method_code(self): + x: list[float] = [1.0, 2.0] + y: list[float] = [10.0, 20.0] + with pytest.raises(ValueError, match="invalid interpolation method"): + approx(x, y, method=3) + + def test_raises_error_invalid_f(self): + x: list[float] = [1.0, 2.0] + y: list[float] = [10.0, 20.0] + with pytest.raises(ValueError, match="invalid f value"): + approx(x, y, method="constant", f=2.0) + + def 
test_raises_error_need_two_points_for_linear(self): + x: list[float] = [1.0] + y: list[float] = [10.0] + with pytest.raises(ValueError, match="need at least two non-NA values"): + approx(x, y, method="linear") + + def test_raises_error_zero_points(self): + x: list[float] = [] + y: list[float] = [] + with pytest.raises(ValueError, match="zero non-NA points"): + approx(x, y, method="constant") + + def test_handles_ties_mean(self): + x: list[float] = [1.0, 1.0, 2.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [1.5] + result = approx(x, y, xout=xout, ties="mean") + assert result.y[0] == 22.5 + + def test_rule_as_list(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, 20.0, 30.0] + xout: list[float] = [0.5, 3.5] + result = approx(x, y, xout=xout, rule=[1, 2]) + assert np.isnan(result.y[0]) + assert result.y[1] == 30.0 + + def test_na_rm_false_with_na_in_y(self): + x: list[float] = [1.0, 2.0, 3.0] + y: list[float] = [10.0, np.nan, 30.0] + xout: list[float] = [2.5] + result = approx(x, y, xout=xout, na_rm=False) + # After filtering out NA, should interpolate between 1.0->10.0 and 3.0->30.0 + assert result.y[0] == 25.0 diff --git a/tests/unit/test_density.py b/tests/unit/test_density.py new file mode 100644 index 00000000..c4d88721 --- /dev/null +++ b/tests/unit/test_density.py @@ -0,0 +1,226 @@ +from typing import Literal, cast + +import numpy as np +import numpy.typing as npt +import pytest +from numpy.testing import assert_allclose, assert_array_equal + +from como.density import DensityResult, bin_distance, density, dnorm, nrd0 + +KERNEL_TYPE = Literal["gaussian", "epanechnikov", "rectangular", "triangular", "biweight", "cosine", "optcosine"] + + +class TestBinDistance: + def test_basic_binning(self): + x: npt.NDArray[float] = np.array([0.5, 1.5, 2.5]) + weights: npt.NDArray[float] = np.array([1.0, 1.0, 1.0]) + result: npt.NDArray[float] = bin_distance(x, weights, lo=0, up=3, n=4) + + assert result.shape == (8,) + assert np.all(result >= 
0) + + def test_weighted_binning(self): + x: npt.NDArray[float] = np.array([1.0, 2.0]) + weights: npt.NDArray[float] = np.array([2.0, 1.0]) + result: npt.NDArray[float] = bin_distance(x, weights, lo=0, up=3, n=4) + + assert result.shape == (8,) + assert result.sum() > 0 + + def test_empty_array(self): + x: npt.NDArray[float] = np.array([]) + weights: npt.NDArray[float] = np.array([]) + result: npt.NDArray[float] = bin_distance(x, weights, lo=0, up=1, n=2) + + assert result.shape == (4,) + assert_array_equal(result, np.zeros(4)) + + def test_out_of_bounds_handling(self): + x: npt.NDArray[float] = np.array([10.0]) + weights: npt.NDArray[float] = np.array([1.0]) + result: npt.NDArray[float] = bin_distance(x, weights, lo=0, up=1, n=2) + + assert result.shape == (4,) + + +class TestDnorm: + def test_standard_normal(self): + # dnorm(0) for standard normal should be ~0.3989 + result: float = dnorm(0.0, mean=0.0, sd=1.0) + expected: float = 1.0 / np.sqrt(2 * np.pi) + assert_allclose(result, expected, rtol=1e-10) + + def test_with_mean_and_sd(self): + result: float = dnorm(5.0, mean=5.0, sd=2.0) + expected: float = 1.0 / (2.0 * np.sqrt(2 * np.pi)) + assert_allclose(result, expected, rtol=1e-10) + + def test_log_density(self): + result: float = dnorm(0.0, mean=0.0, sd=1.0, log=True) + expected: float = np.log(1.0 / np.sqrt(2 * np.pi)) + assert_allclose(result, expected, rtol=1e-10) + + def test_nan_inputs(self): + assert np.isnan(dnorm(np.nan)) + assert np.isnan(dnorm(0.0, mean=np.nan)) + assert np.isnan(dnorm(0.0, sd=np.nan)) + + def test_negative_sd(self): + result: float = dnorm(0.0, sd=-1.0) + assert np.isnan(result) + + def test_infinite_sd(self): + result: float = dnorm(0.0, sd=np.inf) + assert result == 0.0 or result == -np.inf + + def test_zero_sd_at_mean(self): + result: float = dnorm(5.0, mean=5.0, sd=0.0) + assert np.isinf(result) + + def test_zero_sd_away_from_mean(self): + result: float = dnorm(5.0, mean=3.0, sd=0.0, log=False) + assert result == 0.0 + + def 
test_fast_dnorm_flag(self): + result_fast: float = dnorm(1.0, fast_dnorm=True) + result_slow: float = dnorm(1.0, fast_dnorm=False) + assert_allclose(result_fast, result_slow, rtol=1e-5) + + def test_large_values(self): + result: float = dnorm(100.0, mean=0.0, sd=1.0) + assert result >= 0.0 + assert result < 1e-10 + + +class TestNrd0: + def test_basic_calculation(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result: float = nrd0(x) + assert result > 0 + assert np.isfinite(result) + + def test_with_constant_values(self): + x: npt.NDArray[float] = np.array([5.0, 5.0, 5.0, 5.0]) + result: float = nrd0(x) + assert result > 0 + assert np.isfinite(result) + + def test_single_nonzero_value(self): + x: npt.NDArray[float] = np.array([0.0, 7.0]) + result: float = nrd0(x) + assert result > 0 + + def test_all_zeros(self): + # if the input array is changed from a shape of `(3,)`, the result will change + # this is because `nrd0` takes the length of the input into account when calculating bandwidth + x: npt.NDArray[float] = np.array([0.0, 0.0, 0.0], dtype=float) + result: float = nrd0(x) + assert result == 0.7224674055842076 + + def test_insufficient_data(self): + with pytest.raises(ValueError, match="need at least 2 data points"): + nrd0(np.array([1.0])) + + def test_empty_array(self): + with pytest.raises(ValueError, match="need at least 2 data points"): + nrd0(np.array([])) + + +class TestDensity: + def test_basic_density(self): + x: npt.NDArray[float] = np.random.normal(0, 1, 100) + result: DensityResult = density(x) + + assert isinstance(result, DensityResult) + assert len(result.x) == 512 + assert len(result.y) == 512 + assert result.bw > 0 + assert np.all(result.y >= 0) + + def test_custom_bandwidth(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result: DensityResult = density(x, bw=0.5) + + assert_allclose(result.bw, 0.5, rtol=1e-10) + + def test_with_weights(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 
5.0]) + weights: npt.NDArray[float] = np.array([0.1, 0.2, 0.4, 0.2, 0.1]) + result: DensityResult = density(x, weights=weights, n=5) + + assert len(result.x) == 5 + assert np.all(result.y >= 0) + + def test_custom_grid_range(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result: DensityResult = density(x, from_=0, to_=6, n=100) + assert result.x[0] == 0 + assert result.x[-1] == 6 + assert len(result.x) == 100 + + def test_different_kernels(self): + x: npt.NDArray[float] = np.random.normal(0, 1, 50) + + for kernel in ["epanechnikov", "rectangular", "triangular", "biweight", "cosine", "optcosine"]: + with pytest.raises(NotImplementedError, match=f"Only 'gaussian' kernel is implemented; got '{kernel}'"): + density(x, kernel=cast(KERNEL_TYPE, kernel)) + + def test_kernel_only_mode(self): + result: npt.NDArray[float] = density([1, 2, 3], kernel="gaussian", kernel_only=True) + expected: float = 1 / (2 * np.sqrt(np.pi)) + assert isinstance(result, float) + assert_allclose(result, expected, rtol=1e-10) + + def test_na_removal(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + result: DensityResult = density(x, remove_na=True) + + assert len(result.x) == 512 + assert np.all(np.isfinite(result.y)) + + def test_na_without_removal(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + + with pytest.raises(ValueError, match="NA values found"): + density(x, remove_na=False) + + def test_insufficient_data_for_nrd0(self): + with pytest.raises(ValueError, match="at least two points"): + density([1.0], bw="nrd0") + + def test_adjust_parameter(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result1: DensityResult = density(x, bw=1.0, adjust=1.0) + result2: DensityResult = density(x, bw=1.0, adjust=2.0) + + assert_allclose(result2.bw, 2.0 * result1.bw, rtol=1e-10) + + def test_invalid_bandwidth_string(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + + with pytest.raises(TypeError, 
match="must be a number or 'nrd0'"): + density(x, bw="invalid") + + def test_negative_weights(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + weights: npt.NDArray[float] = np.array([1.0, -1.0, 1.0]) + + with pytest.raises(ValueError, match="Negative values found"): + density(x, weights=weights) + + def test_infinite_values_in_x(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, np.inf, 4.0, 5.0]) + result: DensityResult = density(x) + + assert np.all(np.isfinite(result.y)) + + def test_result_named_tuple_fields(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result: DensityResult = density(x, n=100) + + assert hasattr(result, "x") + assert hasattr(result, "y") + assert hasattr(result, "x_grid") + assert hasattr(result, "y_grid") + assert hasattr(result, "bw") + assert hasattr(result, "n") diff --git a/tests/unit/test_peak_finder.py b/tests/unit/test_peak_finder.py new file mode 100644 index 00000000..7d2d730d --- /dev/null +++ b/tests/unit/test_peak_finder.py @@ -0,0 +1,204 @@ +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +from como.peak_finder import ( + _encode_signs, + _enforce_minimum_peak_distance, + _validate_args, + find_peaks, +) + + +class TestValidateArgs: + def test_multidimensional_array_raises_error(self): + x: npt.NDArray[float] = np.array([[1, 2], [3, 4]]) + with pytest.raises(ValueError, match="Expected a 1D array, got 2D array instead"): + _validate_args(x, 1, 1, "0", 0.0, 1, 0.0) + + def test_nan_values_raise_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, np.nan, 4.0]) + with pytest.raises(ValueError, match="Input x contains NaNs"): + _validate_args(x, 1, 1, "0", 0.0, 1, 0.0) + + def test_negative_nups_raises_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'nups' must be non-negative"): + _validate_args(x, -1, 1, "0", 0.0, 1, 0.0) + + def test_negative_ndowns_raises_error(self): + x: 
npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'ndowns' must be non-negative"): + _validate_args(x, 1, -1, "0", 0.0, 1, 0.0) + + def test_invalid_zero_raises_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'zero' must be '0', '\\+', or '-'"): + _validate_args(x, 1, 1, "x", 0.0, 1, 0.0) # type: ignore[invalid-type-argument] + + def test_negative_min_peak_height_raises_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'min_peak_height' must be non-negative"): + _validate_args(x, 1, 1, "0", -1.0, 1, 0.0) + + def test_negative_min_peak_distance_raises_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'minpeakdistance' must be non-negative"): + _validate_args(x, 1, 1, "0", 0.0, -1, 0.0) + + def test_negative_threshold_raises_error(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="Argument 'threshold' must be non-negative"): + _validate_args(x, 1, 1, "0", 0.0, 1, -1.0) + + def test_valid_args_does_not_raise(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0]) + _validate_args(x, 1, 1, "0", 0.0, 1, 0.0) + + +class TestEncodeSigns: + def test_increasing_sequence(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 3.0, 4.0]) + result: str = _encode_signs(x, "0") + assert result == "+++" + + def test_decreasing_sequence(self): + x: npt.NDArray[float] = np.array([4.0, 3.0, 2.0, 1.0]) + result: str = _encode_signs(x, "0") + assert result == "---" + + def test_flat_sequence_with_zero(self): + x: npt.NDArray[float] = np.array([1.0, 1.0, 1.0]) + result: str = _encode_signs(x, "0") + assert result == "00" + + def test_flat_sequence_with_plus(self): + x: npt.NDArray[float] = np.array([1.0, 1.0, 1.0]) + result: str = _encode_signs(x, "+") + assert result == "++" + + def 
test_flat_sequence_with_minus(self): + x: npt.NDArray[float] = np.array([1.0, 1.0, 1.0]) + result: str = _encode_signs(x, "-") + assert result == "--" + + def test_flat_sequence_with_dollarsign(self): + x: npt.NDArray[float] = np.array([1.0, 1.0, 1.0]) + result: str = _encode_signs(x, "$") + assert result == "$$" + + def test_mixed_sequence(self): + x: npt.NDArray[float] = np.array([1.0, 2.0, 2.0, 3.0, 2.0]) + result: str = _encode_signs(x, "0") + assert result == "+0+-" + + +class TestEnforceMinimumPeakDistance: + def test_inplace_removes_close_peaks(self): + df: pd.DataFrame = pd.DataFrame( + { + "height": [10.0, 8.0, 5.0], + "peak_idx": [0, 2, 10], + "start_idx": [0, 1, 9], + "end_idx": [1, 3, 11], + } + ) + _enforce_minimum_peak_distance(df, min_peak_distance=5, inplace=True) + assert len(df) == 2 + assert df["peak_idx"].tolist() == [0, 10] + + def test_not_inplace_returns_new_dataframe(self): + df: pd.DataFrame = pd.DataFrame( + { + "height": [10.0, 8.0, 5.0], + "peak_idx": [0, 2, 10], + "start_idx": [0, 1, 9], + "end_idx": [1, 3, 11], + } + ) + result = _enforce_minimum_peak_distance(df, min_peak_distance=5, inplace=False) + assert len(result) == 2 + assert len(df) == 3 # original unchanged + assert result["peak_idx"].tolist() == [0, 10] + + def test_all_peaks_sufficiently_spaced(self): + df: pd.DataFrame = pd.DataFrame( + { + "height": [10.0, 8.0], + "peak_idx": [0, 10], + "start_idx": [0, 9], + "end_idx": [1, 11], + } + ) + _enforce_minimum_peak_distance(df, min_peak_distance=5, inplace=True) + assert len(df) == 2 + + +class TestFindPeaks: + def test_simple_peak(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 2.0, 3.0, 2.0, 1.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, min_peak_height=0.0) + assert len(result) == 1 + assert result.iloc[0]["peak_idx"] == 3 + assert result.iloc[0]["height"] == 3.0 + + def test_multiple_peaks(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0], dtype=float) + result: 
pd.DataFrame = find_peaks(x, min_peak_height=0.0) + assert len(result) == 3 + + def test_no_peaks(self): + x: npt.NDArray[float] = np.asarray([1.0, 2.0, 3.0, 4.0, 5.0], dtype=float) + result: pd.DataFrame = find_peaks(x) + assert len(result) == 0 + assert list(result.columns) == ["height", "peak_idx", "start_idx", "end_idx"] + + def test_min_peak_height_filter(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 0.0, 5.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, min_peak_height=2.0) + assert len(result) == 1 + assert result.iloc[0]["height"] == 5.0 + + def test_nups_and_ndowns(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 2.0, 3.0, 2.0, 1.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, nups=2, ndowns=2, min_peak_height=0.0) + assert len(result) == 1 + + def test_npeaks_limits_output(self): + x: npt.NDArray[float] = np.asarray([0.0, 5.0, 0.0, 3.0, 0.0, 1.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, npeaks=2, min_peak_height=0.0) + assert len(result) == 2 + + def test_sortstr_orders_by_height(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 0.0, 5.0, 0.0, 3.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, sortstr=True, min_peak_height=0.0) + heights = result["height"].tolist() + assert heights == sorted(heights, reverse=True) + + def test_min_peak_distance(self): + x: npt.NDArray[float] = np.asarray([0.0, 10.0, 0.0, 8.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, min_peak_distance=3, min_peak_height=0.0) + assert len(result) == 1 + assert result.iloc[0]["height"] == 10.0 + + def test_threshold_filter(self): + x: npt.NDArray[float] = np.asarray([1.0, 2.0, 1.5, 5.0, 1.0], dtype=float) + result: pd.DataFrame = find_peaks(x, threshold=2.0, min_peak_height=0.0) + assert len(result) == 1 + assert result.iloc[0]["height"] == 5.0 + + def test_accepts_list_input(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 2.0, 1.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, 
min_peak_height=0.0) + assert len(result) == 1 + + def test_accepts_numpy_array(self): + x: npt.NDArray[float] = np.array([0.0, 1.0, 2.0, 1.0, 0.0]) + result: pd.DataFrame = find_peaks(x, min_peak_height=0.0) + assert len(result) == 1 + + def test_custom_peak_pattern(self): + x: npt.NDArray[float] = np.asarray([0.0, 1.0, 2.0, 3.0, 2.0, 1.0, 0.0], dtype=float) + result: pd.DataFrame = find_peaks(x, peak_pattern=r"[+]{3,}[-]{3,}", min_peak_height=0.0) + assert len(result) == 1