Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
ebd5a39
feat: added log statement and hint to user that an empty dataframe in…
JoshLoecker Oct 9, 2025
7cab8f3
fix: do not drop na values; keep as much data as possible
JoshLoecker Oct 23, 2025
7282a34
fix: allow empty genomic values
JoshLoecker Oct 23, 2025
c903cc8
chore: move from `np.float(ing,64,32)` to python's `float`
JoshLoecker Nov 3, 2025
5b39a1c
chore: ruff formatting
JoshLoecker Nov 3, 2025
34451a9
chore(dev): ignore, but provide warning for, unused imports
JoshLoecker Nov 3, 2025
c0de528
fix: clipped values should use floats, not integers
JoshLoecker Nov 3, 2025
526da0d
chore: import required modules
JoshLoecker Nov 3, 2025
512fa4b
fix: set index name + single column name
JoshLoecker Nov 3, 2025
c319e8e
chore: more explicit variable name
JoshLoecker Nov 3, 2025
86ea376
chore: update docstring
JoshLoecker Nov 3, 2025
d8179bb
chore: do not modify input dataframe inplace
JoshLoecker Nov 3, 2025
838947f
fix: use a per-replicate count for weighting instead of a single weig…
JoshLoecker Nov 3, 2025
ae0797f
fix: use index value for ensembl ids
JoshLoecker Nov 3, 2025
f423d1b
refactor: do not use async
JoshLoecker Nov 3, 2025
7f99af7
refactor: do not use async
JoshLoecker Nov 3, 2025
926c139
refactor: use lowercase identifier names; provide default values
JoshLoecker Nov 3, 2025
4149eb9
refactor: import required modules
JoshLoecker Nov 3, 2025
f56b853
fix: use `.loc[]` to prevent copy-on-write warning
JoshLoecker Nov 3, 2025
6105b98
fix: explicitly check column names
JoshLoecker Nov 3, 2025
1219cc4
refactor: raise error if adjustment method not found
JoshLoecker Nov 3, 2025
9c71754
refactor: validate batch in loop
JoshLoecker Nov 3, 2025
4142156
refactor: use lowercase names
JoshLoecker Nov 3, 2025
89f05aa
refactor: import required modules
JoshLoecker Nov 3, 2025
60a883a
feat: replicate R's zFPKM module
JoshLoecker Nov 3, 2025
95dc2db
refactor: do not drop na values, keep as much data for as long as pos…
JoshLoecker Nov 3, 2025
bb1449a
refactor: explicit type cast
JoshLoecker Nov 3, 2025
5480246
refactor!: provide default min peak height and distance
JoshLoecker Nov 3, 2025
038a5fe
feat: remove NA values by default (replicates R functionality)
JoshLoecker Nov 3, 2025
31cb4d4
fix: do not build a list of `None`
JoshLoecker Nov 3, 2025
2f17279
fix: plot gaussian distribution with a peak of the fpkm value at `mu`…
JoshLoecker Nov 3, 2025
2d6b20f
refactor!: provide default min peak height and distance
JoshLoecker Nov 3, 2025
8076c76
refactor: pythonic method to collect merged gene z-scores
JoshLoecker Nov 3, 2025
9bcda5f
refactor: do not drop na values, keep as much data for as long as pos…
JoshLoecker Nov 3, 2025
6fde5bd
refactor: use new min zfpkm peak height/distance
JoshLoecker Nov 3, 2025
74ca5de
refactor: do not drop na values, keep as much data for as long as pos…
JoshLoecker Nov 3, 2025
666fb84
chore: remove `__main__` function
JoshLoecker Nov 3, 2025
4de5421
fix: convert identifiers to lowercase
JoshLoecker Nov 3, 2025
b917ac1
refactor: more robust data handling when creating the gene info file
JoshLoecker Nov 3, 2025
e512740
refactor: include more robust error handling for dataframes when coll…
JoshLoecker Nov 3, 2025
8d505f7
chore(dev): expand overloaded functions to add more type paths
JoshLoecker Nov 3, 2025
448627b
feat(dev): added ty.toml file for type hints
JoshLoecker Nov 3, 2025
3135061
Merge branch 'develop' into fix-zfpkm
JoshLoecker Nov 3, 2025
cebab44
fix(test): lowercase names
JoshLoecker Nov 3, 2025
37be94b
format: ruff formatting
JoshLoecker Nov 3, 2025
a19536b
feat(test): added `approx` tests
JoshLoecker Nov 3, 2025
0891f01
feat(test): added `density` tests
JoshLoecker Nov 3, 2025
330fc44
feat(test): added `find_peaks` tests
JoshLoecker Nov 3, 2025
4617863
feat(dev): use named tuples for return types
JoshLoecker Nov 3, 2025
688ad8e
style: update docstrings and types
JoshLoecker Nov 3, 2025
874fdf2
refactor: do not create temporary variable for aggregation function
JoshLoecker Nov 3, 2025
c6a0015
style: ruff formatting
JoshLoecker Nov 3, 2025
ae60218
feat(dev): use named tuples for return types
JoshLoecker Nov 3, 2025
c3ef2f5
Merge branch 'develop' into zfpkm-tests
JoshLoecker Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 130 additions & 107 deletions main/como/approx.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,64 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any
from typing import Literal, NamedTuple

import numpy as np
import numpy.typing as npt


def _coerce_to_float_array(a):
"""Helper to ensure input is a 1D float array."""
class RegularizedArray(NamedTuple):
    """Result of `_regularize_values()`: sorted, tie-collapsed (x, y) data plus NaN bookkeeping."""

    # sorted unique x values
    x: npt.NDArray[np.number]
    # aggregated y values, aligned with `x`
    y: npt.NDArray[np.number]
    # boolean mask marking the entries of `y` that are not NaN
    not_na: npt.NDArray[bool]
    # True iff any NaN remained in `y` after regularization
    kept_na: bool


class Approx(NamedTuple):
    """Return value of `approx()`: the evaluation points and the interpolated values."""

    # the `xout` points that were used for interpolation
    x: npt.NDArray[float]
    # the interpolated values
    y: npt.NDArray[float]


def _coerce_to_float_array(a: Sequence[int | float | np.number]):
"""Coerce input to a 1D float array.

Args:
a: the array to coerce

Returns:
A floating point 1D array.
"""
arr = np.asarray(a, dtype=float)
if arr.ndim != 1:
arr = arr.ravel()
return arr


def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[str, Any]:
"""Minimal reimplementation of R's regularize.values() used by approx().
def _regularize_values(
x: npt.NDArray[np.number],
y: npt.NDArray[np.number],
*,
na_rm: bool,
ties: Callable[[np.ndarray], float] | Literal["mean", "first", "last", "min", "max", "median", "sum"] = "mean",
) -> RegularizedArray:
"""Reimplementation of R's regularize.values() used by approx().

- Removes NA pairs if na_rm is True (else requires x to have no NA).
- Sorts by x (stable).
- Collapses duplicate x via `ties` aggregator.
Returns dict with:
x: sorted unique x
y: aggregated y aligned with x
not_na: boolean mask of y that are not NaN
kept_na: True iff any NaN remained in y after regularization

Args:
x: the x values to regularize
y: the corresponding y values
na_rm: should NA values be removed?
ties: how to handle duplicate x values; can be a string or a callable

Returns:
A NamedTuple with:
- x: sorted unique x
- y: aggregated y aligned with x
- not_na: boolean mask of y that are not NaN
- kept_na: True iff any NaN remained in y after regularization
"""
if na_rm:
# Remove pairs where x or y is NA
Expand All @@ -39,8 +73,7 @@ def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[
kept_na = np.isnan(y).any()

if x.size == 0:
# THIS IS THE CORRECTED LINE:
return {"x": x, "y": y, "not_na": np.array([], dtype=bool), "kept_na": kept_na}
return RegularizedArray(x=x, y=y, not_na=np.array([], dtype=bool), kept_na=kept_na)

# Use a stable sort (mergesort) to match R's order()
order = np.argsort(x, kind="mergesort")
Expand All @@ -51,27 +84,26 @@ def _regularize_values(x: np.ndarray, y: np.ndarray, ties, na_rm: bool) -> dict[
if callable(ties):
agg_fn = ties
else:
ties_str = "mean" if ties is None else str(ties).lower()
if ties_str in ("mean", "avg", "average"):
# fmt: off
if ties in ("mean", "avg", "average"):
agg_fn = np.mean
elif ties_str in ("first", "left"):

elif ties in ("first", "left"):
def agg_fn(a):
return a[0]
elif ties_str in ("last", "right"):

elif ties in ("last", "right"):
def agg_fn(a):
return a[-1]
elif ties_str == "min":
elif ties == "min":
agg_fn = np.min
elif ties_str == "max":
elif ties == "max":
agg_fn = np.max
elif ties_str == "median":
elif ties == "median":
agg_fn = np.median
elif ties_str == "sum":
elif ties == "sum":
agg_fn = np.sum
else:
raise ValueError("Unsupported `ties`; use a callable or one of 'mean', 'first', 'last', 'min', 'max', 'median', 'sum'.")
# fmt: on

# Find unique x values and their indices/counts
unique_x, start_idx, counts = np.unique(x_sorted, return_index=True, return_counts=True)
Expand All @@ -86,8 +118,8 @@ def agg_fn(a):

not_na = ~np.isnan(y_agg)
# Check if any NaNs remain in the *aggregated* y values
kept_na_final = np.any(~not_na) if np.any(np.isnan(y_agg)) else False
return {"x": unique_x, "y": y_agg, "not_na": not_na, "kept_na": kept_na_final}
kept_na_final = bool(np.any(~not_na) if np.any(np.isnan(y_agg)) else False)
return RegularizedArray(x=unique_x, y=y_agg, not_na=not_na, kept_na=kept_na_final)


def approx(
Expand All @@ -98,11 +130,11 @@ def approx(
n: int = 50,
yleft: float | None = None,
yright: float | None = None,
rule: int | Sequence[int] = 1,
rule: int | Sequence[int] | npt.NDArray[int] = 1,
f: float = 0.0,
ties: str | Callable[[np.ndarray], float] = "mean",
ties: Callable[[np.ndarray], float] | Literal["mean", "first", "last", "min", "max", "median", "sum"] = "mean",
na_rm: bool = True,
) -> dict[str, np.ndarray]:
) -> Approx:
"""Faithful Python port of R's `approx` function.

This implementation aims to replicate the behavior of R's `approx`
Expand All @@ -114,26 +146,18 @@ def approx(
y: y-coordinates. If None, `x` is treated as `y` and `x` becomes `1..n`.
xout: Points at which to interpolate.
method: Interpolation method. "linear" (1) or "constant" (2).
n: If `xout` is not provided, interpolation happens at `n`
equally spaced points spanning the range of `x`.
yleft: Value to use for extrapolation to the left.
Defaults to `NA` (np.nan) if `rule` is 1.
yright: Value to use for extrapolation to the right.
Defaults to `NA` (np.nan) if `rule` is 1.
n: If `xout` is not provided, interpolation happens at `n` equally spaced points spanning the range of `x`.
yleft: Value to use for extrapolation to the left. Defaults to `NA` (np.nan) if `rule` is 1.
yright: Value to use for extrapolation to the right. Defaults to `NA` (np.nan) if `rule` is 1.
rule: Extrapolation rule.
- 1: Return `np.nan` for points outside the `x` range.
- 2: Use `yleft` and `yright` for points outside the range.
f: For `method="constant"`, determines the split point.
`f=0` is left-step, `f=1` is right-step, `f=0.5` is midpoint.
ties: How to handle duplicate `x` values.
Can be 'mean', 'first', 'last', 'min', 'max', 'median', 'sum',
or a callable function.
na_rm: If True, `NA` pairs are removed before interpolation.
If False, `NA`s in `x` cause an error, `NA`s in `y`
are propagated.
f: For `method="constant"`, determines the split point. `f=0` is left-step, `f=1` is right-step, `f=0.5` is midpoint.
ties: How to handle duplicate `x` values. Can be 'mean', 'first', 'last', 'min', 'max', 'median', 'sum', or a callable function.
na_rm: If True, `NA` pairs are removed before interpolation. If False, `NA`s in `x` cause an error, `NA`s in `y` are propagated.

Returns:
dict with:
`Approx` class with:
- 'x': numpy array of xout used
- 'y': numpy array of interpolated values
"""
Expand Down Expand Up @@ -164,25 +188,24 @@ def approx(
raise ValueError("invalid interpolation method")

# --- Rule normalization ---
if isinstance(rule, list | tuple | np.ndarray):
rlist = list(rule)
if not (1 <= len(rlist) <= 2):
if isinstance(rule, Sequence | np.ndarray):
rule_list: list[int] = list(rule) # type: ignore[invalid-argument-type] # This is a valid argument type, not sure why ty is upset
if not (1 <= len(rule_list) <= 2):
raise ValueError("`rule` must have length 1 or 2")
if len(rlist) == 1:
rlist = [rlist[0], rlist[0]]
if len(rule_list) == 1:
rule_list = [rule_list[0], rule_list[0]]
else:
rlist = [rule, rule]
rule_list = [rule, rule]

# --- Regularize values ---
# Sort by x, collapse ties, and handle NAs like R's regularize.values()
r = _regularize_values(x_arr, y_arr, ties, na_rm)
x_reg = r["x"]
y_reg = r["y"]
not_na_mask = r["not_na"]
r: RegularizedArray = _regularize_values(x_arr, y_arr, na_rm=na_rm, ties=ties)
x_reg: npt.NDArray[float] = r.x
y_reg: npt.NDArray[float] = r.y
not_na_mask: npt.NDArray[bool] = r.not_na
# no_na is True if we don't have to worry about NAs in y_reg
no_na = na_rm or (not r["kept_na"])
no_na: bool = na_rm or (not r.kept_na)
# nx is the number of *valid* (non-NA) points for interpolation
nx = x_reg.size if no_na else int(np.count_nonzero(not_na_mask))
nx: int = x_reg.size if no_na else int(np.count_nonzero(not_na_mask))

if np.isnan(nx):
raise ValueError("invalid length(x)")
Expand All @@ -194,98 +217,98 @@ def approx(
if nx == 0:
raise ValueError("zero non-NA points")

# --- Set extrapolation values (yleft/yright) ---
# set extrapolation values (yleft/yright)
# This logic matches R's C code (R_approxtest)
if no_na:
y_first = y_reg[0]
y_last = y_reg[-1]
y_first: float = y_reg[0]
y_last: float = y_reg[-1]
else:
# Find first and last non-NA y values
y_valid = y_reg[not_na_mask]
y_first = y_valid[0]
y_last = y_valid[-1]
y_valid: npt.NDArray[float] = y_reg[not_na_mask]
y_first: float = y_valid[0]
y_last: float = y_valid[-1]

yleft_val = (np.nan if int(rlist[0]) == 1 else y_first) if yleft is None else float(yleft)
yright_val = (np.nan if int(rlist[1]) == 1 else y_last) if yright is None else float(yright)
yleft_val: float = (np.nan if int(rule_list[0]) == 1 else y_first) if yleft is None else float(yleft)
yright_val: float = (np.nan if int(rule_list[1]) == 1 else y_last) if yright is None else float(yright)

# --- Fabricate xout if missing ---
# fabricate xout if it is missing
if xout is None:
if n <= 0:
raise ValueError("'approx' requires n >= 1")
if no_na:
x_first = x_reg[0]
x_last = x_reg[nx - 1]
x_first: float = x_reg[0]
x_last: float = x_reg[nx - 1]
else:
x_valid = x_reg[not_na_mask]
x_first = x_valid[0]
x_last = x_valid[-1]
xout_arr = np.linspace(x_first, x_last, num=int(n), dtype=float)
x_valid: npt.NDArray[float] = x_reg[not_na_mask]
x_first: float = x_valid[0]
x_last: float = x_valid[-1]
xout_arr: npt.NDArray[float] = np.linspace(x_first, x_last, num=int(n), dtype=float)
else:
xout_arr = _coerce_to_float_array(xout)
xout_arr: npt.NDArray[float] = _coerce_to_float_array(xout)

# --- C_ApproxTest checks ---
# replicate R's C_ApproxTest checks
if method_code == 2 and (not np.isfinite(f) or f < 0.0 or f > 1.0):
raise ValueError("approx(): invalid f value")
raise ValueError("approx(): invalid f value; with `method=2`, `f` must be in the range [0,1] (inclusive) or NA")
if not no_na:
# R re-filters x and y here if NAs remained
x_reg = x_reg[not_na_mask]
y_reg = y_reg[not_na_mask]
x_reg: npt.NDArray[float] = x_reg[not_na_mask]
y_reg: npt.NDArray[float] = y_reg[not_na_mask]

# --- Interpolation ---
# This section vectorized the logic from R's approx1 and R_approxfun
yout = np.empty_like(xout_arr, dtype=float)
mask_nan_xout = np.isnan(xout_arr)
# perform interpolation
# this section is a vectorized form of the logic from R's approx1 and R_approxfun
yout: npt.NDArray[float] = np.empty_like(xout_arr, dtype=float)
mask_nan_xout: npt.NDArray[bool] = np.isnan(xout_arr)
yout[mask_nan_xout] = np.nan
mask_valid = ~mask_nan_xout
mask_valid: npt.NDArray[bool] = ~mask_nan_xout

if x_reg.size == 0:
# No valid points to interpolate against
yout[mask_valid] = np.nan
return {"x": xout_arr, "y": yout}
return Approx(x=xout_arr, y=yout)

xv = xout_arr[mask_valid]
left_mask = xv < x_reg[0]
right_mask = xv > x_reg[-1]
mid_mask = ~(left_mask | right_mask)
xv: npt.NDArray[float] = xout_arr[mask_valid]
left_mask: npt.NDArray[bool] = xv < x_reg[0]
right_mask: npt.NDArray[bool] = xv > x_reg[-1]
mid_mask: npt.NDArray[bool] = ~(left_mask | right_mask)

res = np.empty_like(xv)
res: npt.NDArray[float] = np.empty_like(xv, dtype=float)
res[left_mask] = yleft_val
res[right_mask] = yright_val

if np.any(mid_mask):
xm = xv[mid_mask]
xm: npt.NDArray[float] = xv[mid_mask]

# Find indices
# j_right[k] = index of first x_reg > xm[k]
j_right = np.searchsorted(x_reg, xm, side="right")
j_right: npt.NDArray[int] = np.searchsorted(x_reg, xm, side="right")
# j_left[k] = index of first x_reg >= xm[k]
j_left = np.searchsorted(x_reg, xm, side="left")
j_left: npt.NDArray[int] = np.searchsorted(x_reg, xm, side="left")

# Points that exactly match an x_reg value
eq_mask = j_left != j_right
eq_mask: npt.NDArray[bool] = j_left != j_right
# Points that fall between x_reg values
in_interval_mask = ~eq_mask
in_interval_mask: npt.NDArray[bool] = ~eq_mask

res_mid = np.empty_like(xm)
res_mid: npt.NDArray[float] = np.empty_like(xm, dtype=float)

if np.any(eq_mask):
# For exact matches, use the corresponding y_reg value
# R's C code uses the 'j_left' index here
res_mid[eq_mask] = y_reg[j_left[eq_mask]]
res_mid[eq_mask] = y_reg[j_left[eq_mask]] # type: ignore[non-subscriptable] # we know this is of type npt.NDArray[float], not sure why ty is upset

if np.any(in_interval_mask):
# R's C code sets i = j-1, where j is the 'right' index
j = j_right[in_interval_mask]
i = j - 1
j: npt.NDArray[float] = j_right[in_interval_mask] # type: ignore[non-subscriptable] # we know this is of type npt.NDArray[float], not sure why ty is upset
i: npt.NDArray[float] = j - 1

# Get the surrounding x and y values
xi = x_reg[i]
xj = x_reg[j]
yi = y_reg[i]
yj = y_reg[j]
xi: npt.NDArray[float] = x_reg[i]
xj: npt.NDArray[float] = x_reg[j]
yi: npt.NDArray[float] = y_reg[i]
yj: npt.NDArray[float] = y_reg[j]

if method_code == 1: # linear
t = (xm[in_interval_mask] - xi) / (xj - xi)
t: npt.NDArray[float] = (xm[in_interval_mask] - xi) / (xj - xi)
res_mid[in_interval_mask] = yi + (yj - yi) * t
else: # constant
# This matches R_approxfun's constant logic
Expand All @@ -294,17 +317,17 @@ def approx(
elif f == 1.0:
res_mid[in_interval_mask] = yj
else:
# R computes (1-f)*yi + f*yj, but carefully
f1 = 1.0 - f
f2 = f
part = np.zeros_like(yi)
# computes R's (1-f)*yi + f*yj
f1: float = float(1.0 - f)
f2: float = float(f)
part: npt.NDArray[float] = np.zeros_like(yi, dtype=float)
if f1 != 0.0:
part = part + yi * f1
part: npt.NDArray[float] = part + yi * f1
if f2 != 0.0:
part = part + yj * f2
part: npt.NDArray[float] = part + yj * f2
res_mid[in_interval_mask] = part

res[mid_mask] = res_mid

yout[mask_valid] = res
return {"x": xout_arr, "y": yout}
return Approx(x=xout_arr, y=yout)
Loading