Skip to content

Commit 4d128ee

Browse files
authored
Merge pull request #129 from igerber/review-ci-tests
Add backend-aware test parameter scaling for pure Python CI
2 parents 5cc7814 + cf37b5a commit 4d128ee

10 files changed

Lines changed: 382 additions & 205 deletions

CLAUDE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ Tests mirror the source modules:
375375
- `tests/test_pretrends.py` - Tests for pre-trends power analysis
376376
- `tests/test_datasets.py` - Tests for dataset loading functions
377377

378+
Session-scoped `ci_params` fixture in `conftest.py` scales bootstrap iterations and TROP grid sizes in pure Python mode — use `ci_params.bootstrap(n)` and `ci_params.grid(values)` in new tests with `n_bootstrap >= 20`. For SE convergence tests (analytical vs bootstrap comparison), use `ci_params.bootstrap(n, min_n=199)` to ensure sufficient iterations.
379+
378380
### Test Writing Guidelines
379381

380382
**For fallback/error handling paths:**

tests/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
to avoid import-time subprocess latency.
66
"""
77

8+
import math
89
import os
910
import subprocess
1011

@@ -81,3 +82,48 @@ def test_comparison_with_r(require_r):
8182
"""
8283
if not r_available:
8384
pytest.skip("R or did package not available")
85+
86+
87+
# =============================================================================
88+
# CI Performance: Backend-Aware Parameter Scaling
89+
# =============================================================================
90+
91+
from diff_diff._backend import HAS_RUST_BACKEND
92+
93+
# Pure Python mode is active when the compiled backend is missing, or when
# the DIFF_DIFF_BACKEND environment variable (default "auto") is set to
# "python" in any letter case.
_PURE_PYTHON_MODE = (
    not HAS_RUST_BACKEND
    or os.environ.get("DIFF_DIFF_BACKEND", "auto").lower() == "python"
)
97+
98+
99+
class CIParams:
    """Scale test parameters in pure Python mode for CI performance.

    When the module-level ``_PURE_PYTHON_MODE`` flag is false, every value
    passes through unchanged. When it is true, bootstrap iteration counts
    and lambda grids are reduced to cut CI time while still exercising the
    same code paths.
    """

    # Requests at or below this size are never scaled. It also acts as the
    # lowest possible floor for scaled requests so that scaling can never
    # return less than an unscaled, smaller request would.
    _PASSTHROUGH_MAX = 10

    @staticmethod
    def bootstrap(n: int, *, min_n: int = 11) -> int:
        """Scale a bootstrap iteration count.

        Monotonic in ``n`` for any ``min_n``:
        ``bootstrap(n + 1, min_n=m) >= bootstrap(n, min_n=m)``.

        Args:
            n: Requested number of bootstrap iterations.
            min_n: Floor applied to the scaled count. Use a larger value for
                tests comparing analytical vs bootstrap SEs, which need more
                iterations for stable convergence. Values below 10 are
                treated as 10 to keep the result monotonic in ``n``.

        Returns:
            ``n`` unchanged when not in pure Python mode or when ``n <= 10``;
            otherwise ``int(sqrt(n) * 1.6)`` raised to the floor and capped
            at ``n``.
        """
        if not _PURE_PYTHON_MODE or n <= CIParams._PASSTHROUGH_MAX:
            return n
        scaled = int(math.sqrt(n) * 1.6)
        # Without this clamp a small min_n (e.g. 5) would make
        # bootstrap(11) < bootstrap(10), violating the monotonic guarantee.
        floor = max(min_n, CIParams._PASSTHROUGH_MAX)
        # Never return more iterations than the caller originally asked for.
        return min(n, max(floor, scaled))

    @staticmethod
    def grid(values: list) -> list:
        """Scale a TROP lambda grid.

        Args:
            values: Candidate grid values.

        Returns:
            The same ``values`` object when not in pure Python mode or when
            it has three or fewer elements; otherwise a new three-element
            list holding the first, middle (index ``len(values) // 2``), and
            last entries.
        """
        if not _PURE_PYTHON_MODE or len(values) <= 3:
            return values
        return [values[0], values[len(values) // 2], values[-1]]
124+
125+
126+
@pytest.fixture(scope="session")
def ci_params():
    """Provide one ``CIParams`` instance per test session for backend-aware scaling."""
    scaler = CIParams()
    return scaler

tests/test_ci_params.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Tests for CIParams bootstrap scaling in conftest.py."""
2+
3+
import math
4+
5+
import tests.conftest as conftest_module
6+
from tests.conftest import CIParams
7+
8+
9+
class TestCIParamsBootstrap:
    """Checks for ``CIParams.bootstrap`` with the backend flag forced on or off."""

    def test_min_n_in_pure_python_mode(self, monkeypatch):
        """A large min_n lifts the scaled count in pure Python mode."""
        monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
        scaled = CIParams.bootstrap(499, min_n=199)
        assert scaled == 199

    def test_min_n_passthrough_in_rust_mode(self, monkeypatch):
        """With the pure Python flag off, min_n is ignored and n is returned."""
        monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", False)
        scaled = CIParams.bootstrap(499, min_n=199)
        assert scaled == 499

    def test_min_n_capped_at_original_request(self, monkeypatch):
        """The result never exceeds the requested n, even if min_n is larger."""
        monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
        scaled = CIParams.bootstrap(100, min_n=199)
        assert scaled == 100

    def test_n_lte_10_ignores_min_n(self, monkeypatch):
        """In pure Python mode, n <= 10 is returned unchanged despite a large min_n."""
        monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
        scaled = CIParams.bootstrap(10, min_n=199)
        assert scaled == 10

    def test_default_min_n_preserves_existing_behavior(self, monkeypatch):
        """The default min_n=11 reproduces the original sqrt-based scaling."""
        monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
        expected = max(11, int(math.sqrt(499) * 1.6))  # 35
        assert CIParams.bootstrap(499) == expected

tests/test_estimators.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,9 +2410,10 @@ def single_treated_unit_data(self):
24102410

24112411
return pd.DataFrame(data)
24122412

2413-
def test_basic_fit(self, sdid_panel_data):
2413+
def test_basic_fit(self, sdid_panel_data, ci_params):
24142414
"""Test basic SDID model fitting."""
2415-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2415+
n_boot = ci_params.bootstrap(50)
2416+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
24162417
results = sdid.fit(
24172418
sdid_panel_data,
24182419
outcome="outcome",
@@ -2428,9 +2429,10 @@ def test_basic_fit(self, sdid_panel_data):
24282429
assert results.n_treated == 5
24292430
assert results.n_control == 25
24302431

2431-
def test_att_direction(self, sdid_panel_data):
2432+
def test_att_direction(self, sdid_panel_data, ci_params):
24322433
"""Test that ATT is estimated in correct direction."""
2433-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2434+
n_boot = ci_params.bootstrap(50)
2435+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
24342436
results = sdid.fit(
24352437
sdid_panel_data,
24362438
outcome="outcome",
@@ -2489,9 +2491,10 @@ def test_unit_weights_nonnegative(self, sdid_panel_data):
24892491
for w in results.unit_weights.values():
24902492
assert w >= 0
24912493

2492-
def test_single_treated_unit(self, single_treated_unit_data):
2494+
def test_single_treated_unit(self, single_treated_unit_data, ci_params):
24932495
"""Test SDID with a single treated unit (classic SC scenario)."""
2494-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2496+
n_boot = ci_params.bootstrap(50)
2497+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
24952498
results = sdid.fit(
24962499
single_treated_unit_data,
24972500
outcome="outcome",
@@ -2554,9 +2557,10 @@ def test_placebo_inference(self, sdid_panel_data):
25542557
assert len(results.placebo_effects) > 0
25552558
assert results.se > 0
25562559

2557-
def test_bootstrap_inference(self, sdid_panel_data):
2560+
def test_bootstrap_inference(self, sdid_panel_data, ci_params):
25582561
"""Test bootstrap-based inference."""
2559-
sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=100, seed=42)
2562+
n_boot = ci_params.bootstrap(100)
2563+
sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=n_boot, seed=42)
25602564
results = sdid.fit(
25612565
sdid_panel_data,
25622566
outcome="outcome",
@@ -2567,7 +2571,7 @@ def test_bootstrap_inference(self, sdid_panel_data):
25672571
)
25682572

25692573
assert results.variance_method == "bootstrap"
2570-
assert results.n_bootstrap == 100
2574+
assert results.n_bootstrap == n_boot
25712575
assert results.se > 0
25722576
assert results.conf_int[0] < results.att < results.conf_int[1]
25732577

@@ -2627,9 +2631,10 @@ def test_pre_treatment_fit(self, sdid_panel_data):
26272631
assert results.pre_treatment_fit is not None
26282632
assert results.pre_treatment_fit >= 0
26292633

2630-
def test_summary_output(self, sdid_panel_data):
2634+
def test_summary_output(self, sdid_panel_data, ci_params):
26312635
"""Test that summary produces string output."""
2632-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2636+
n_boot = ci_params.bootstrap(50)
2637+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
26332638
results = sdid.fit(
26342639
sdid_panel_data,
26352640
outcome="outcome",
@@ -2645,9 +2650,10 @@ def test_summary_output(self, sdid_panel_data):
26452650
assert "ATT" in summary
26462651
assert "Unit Weights" in summary
26472652

2648-
def test_to_dict(self, sdid_panel_data):
2653+
def test_to_dict(self, sdid_panel_data, ci_params):
26492654
"""Test conversion to dictionary."""
2650-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2655+
n_boot = ci_params.bootstrap(50)
2656+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
26512657
results = sdid.fit(
26522658
sdid_panel_data,
26532659
outcome="outcome",
@@ -2664,9 +2670,10 @@ def test_to_dict(self, sdid_panel_data):
26642670
assert "n_post_periods" in result_dict
26652671
assert "pre_treatment_fit" in result_dict
26662672

2667-
def test_to_dataframe(self, sdid_panel_data):
2673+
def test_to_dataframe(self, sdid_panel_data, ci_params):
26682674
"""Test conversion to DataFrame."""
2669-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2675+
n_boot = ci_params.bootstrap(50)
2676+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
26702677
results = sdid.fit(
26712678
sdid_panel_data,
26722679
outcome="outcome",
@@ -2681,9 +2688,10 @@ def test_to_dataframe(self, sdid_panel_data):
26812688
assert len(df) == 1
26822689
assert "att" in df.columns
26832690

2684-
def test_repr(self, sdid_panel_data):
2691+
def test_repr(self, sdid_panel_data, ci_params):
26852692
"""Test string representation."""
2686-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2693+
n_boot = ci_params.bootstrap(50)
2694+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
26872695
results = sdid.fit(
26882696
sdid_panel_data,
26892697
outcome="outcome",
@@ -2697,9 +2705,10 @@ def test_repr(self, sdid_panel_data):
26972705
assert "SyntheticDiDResults" in repr_str
26982706
assert "ATT=" in repr_str
26992707

2700-
def test_is_significant_property(self, sdid_panel_data):
2708+
def test_is_significant_property(self, sdid_panel_data, ci_params):
27012709
"""Test is_significant property."""
2702-
sdid = SyntheticDiD(n_bootstrap=100, seed=42)
2710+
n_boot = ci_params.bootstrap(100)
2711+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
27032712
results = sdid.fit(
27042713
sdid_panel_data,
27052714
outcome="outcome",
@@ -2813,12 +2822,13 @@ def test_auto_infer_post_periods(self, sdid_panel_data):
28132822
assert results.pre_periods == [0, 1, 2, 3]
28142823
assert results.post_periods == [4, 5, 6, 7]
28152824

2816-
def test_with_covariates(self, sdid_panel_data):
2825+
def test_with_covariates(self, sdid_panel_data, ci_params):
28172826
"""Test SDID with covariates."""
28182827
# Add a covariate
28192828
sdid_panel_data["size"] = np.random.normal(100, 10, len(sdid_panel_data))
28202829

2821-
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
2830+
n_boot = ci_params.bootstrap(50)
2831+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
28222832
results = sdid.fit(
28232833
sdid_panel_data,
28242834
outcome="outcome",
@@ -2832,9 +2842,10 @@ def test_with_covariates(self, sdid_panel_data):
28322842
assert results is not None
28332843
assert sdid.is_fitted_
28342844

2835-
def test_confidence_interval_contains_estimate(self, sdid_panel_data):
2845+
def test_confidence_interval_contains_estimate(self, sdid_panel_data, ci_params):
28362846
"""Test that confidence interval contains the estimate."""
2837-
sdid = SyntheticDiD(n_bootstrap=100, seed=42)
2847+
n_boot = ci_params.bootstrap(100)
2848+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
28382849
results = sdid.fit(
28392850
sdid_panel_data,
28402851
outcome="outcome",
@@ -2847,9 +2858,10 @@ def test_confidence_interval_contains_estimate(self, sdid_panel_data):
28472858
lower, upper = results.conf_int
28482859
assert lower < results.att < upper
28492860

2850-
def test_reproducibility_with_seed(self, sdid_panel_data):
2861+
def test_reproducibility_with_seed(self, sdid_panel_data, ci_params):
28512862
"""Test that results are reproducible with the same seed."""
2852-
results1 = SyntheticDiD(n_bootstrap=50, seed=42).fit(
2863+
n_boot = ci_params.bootstrap(50)
2864+
results1 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit(
28532865
sdid_panel_data,
28542866
outcome="outcome",
28552867
treatment="treated",
@@ -2858,7 +2870,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data):
28582870
post_periods=[4, 5, 6, 7],
28592871
)
28602872

2861-
results2 = SyntheticDiD(n_bootstrap=50, seed=42).fit(
2873+
results2 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit(
28622874
sdid_panel_data,
28632875
outcome="outcome",
28642876
treatment="treated",
@@ -2870,7 +2882,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data):
28702882
assert results1.att == results2.att
28712883
assert results1.se == results2.se
28722884

2873-
def test_insufficient_pre_periods_warning(self):
2885+
def test_insufficient_pre_periods_warning(self, ci_params):
28742886
"""Test that SDID warns with very few pre-treatment periods."""
28752887
np.random.seed(42)
28762888

@@ -2909,7 +2921,8 @@ def test_insufficient_pre_periods_warning(self):
29092921

29102922
df = pd.DataFrame(data)
29112923

2912-
sdid = SyntheticDiD(n_bootstrap=30, seed=42)
2924+
n_boot = ci_params.bootstrap(30)
2925+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
29132926

29142927
# Should work but may warn about few pre-periods
29152928
# (Depending on implementation - some may warn, some may not)
@@ -2926,7 +2939,7 @@ def test_insufficient_pre_periods_warning(self):
29262939
assert np.isfinite(results.att)
29272940
assert results.se > 0
29282941

2929-
def test_single_pre_period_edge_case(self):
2942+
def test_single_pre_period_edge_case(self, ci_params):
29302943
"""Test SDID with single pre-treatment period (extreme edge case)."""
29312944
np.random.seed(42)
29322945

@@ -2964,7 +2977,8 @@ def test_single_pre_period_edge_case(self):
29642977

29652978
df = pd.DataFrame(data)
29662979

2967-
sdid = SyntheticDiD(n_bootstrap=30, seed=42)
2980+
n_boot = ci_params.bootstrap(30)
2981+
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
29682982

29692983
# With single pre-period, time weights will be trivially [1.0]
29702984
results = sdid.fit(
@@ -2981,7 +2995,7 @@ def test_single_pre_period_edge_case(self):
29812995
# Time weights should have single entry
29822996
assert len(results.time_weights) == 1
29832997

2984-
def test_more_pre_periods_than_control_units(self):
2998+
def test_more_pre_periods_than_control_units(self, ci_params):
29852999
"""Test SDID when n_pre_periods > n_control_units (underdetermined)."""
29863000
np.random.seed(42)
29873001

@@ -3020,7 +3034,8 @@ def test_more_pre_periods_than_control_units(self):
30203034
df = pd.DataFrame(data)
30213035

30223036
# Use regularization to help with underdetermined system
3023-
sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=30, seed=42)
3037+
n_boot = ci_params.bootstrap(30)
3038+
sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=n_boot, seed=42)
30243039

30253040
results = sdid.fit(
30263041
df,

tests/test_methodology_callaway.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class TestSEFormulas:
801801
"""Tests for standard error formula verification."""
802802

803803
@pytest.mark.slow
804-
def test_analytical_se_close_to_bootstrap_se(self):
804+
def test_analytical_se_close_to_bootstrap_se(self, ci_params):
805805
"""
806806
Analytical and bootstrap SEs should be within 20%.
807807
@@ -812,6 +812,7 @@ def test_analytical_se_close_to_bootstrap_se(self):
812812
This test is marked slow because it uses up to 499 bootstrap iterations (scaled down to no fewer than 199 in pure Python mode)
813813
for thorough validation of SE convergence.
814814
"""
815+
n_boot = ci_params.bootstrap(499, min_n=199)
815816
data = generate_staggered_data(
816817
n_units=300,
817818
n_periods=8,
@@ -821,7 +822,7 @@ def test_analytical_se_close_to_bootstrap_se(self):
821822
)
822823

823824
cs_anal = CallawaySantAnna(n_bootstrap=0)
824-
cs_boot = CallawaySantAnna(n_bootstrap=499, seed=42)
825+
cs_boot = CallawaySantAnna(n_bootstrap=n_boot, seed=42)
825826

826827
results_anal = cs_anal.fit(
827828
data, outcome='outcome', unit='unit',
@@ -893,12 +894,13 @@ def test_bootstrap_weight_moments_webb(self):
893894
var_w = np.var(weights)
894895
assert abs(var_w - 1.0) < 0.05, f"Webb Var(w) should be ~1.0, got {var_w}"
895896

896-
def test_bootstrap_produces_valid_inference(self):
897+
def test_bootstrap_produces_valid_inference(self, ci_params):
897898
"""Test that bootstrap produces valid inference with p-values and CIs.
898899
899900
Uses up to 99 bootstrap iterations (fewer in pure Python mode) - sufficient to verify the mechanism works
900901
without being slow for CI runs.
901902
"""
903+
n_boot = ci_params.bootstrap(99)
902904
data = generate_staggered_data(
903905
n_units=100,
904906
n_periods=6,
@@ -907,7 +909,7 @@ def test_bootstrap_produces_valid_inference(self):
907909
seed=42
908910
)
909911

910-
cs = CallawaySantAnna(n_bootstrap=99, seed=42)
912+
cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42)
911913
results = cs.fit(
912914
data, outcome='outcome', unit='unit',
913915
time='period', first_treat='first_treat'

0 commit comments

Comments
 (0)