Skip to content

Commit fc1d39d

Browse files
Add simulation cache (#195)
* Add simulation cache * Format * Pass ruff check
1 parent 1571b09 commit fc1d39d

5 files changed

Lines changed: 214 additions & 0 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
  changes:
    added:
      - Added caching of saved simulations

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"microdf_python",
1919
"plotly>=5.0.0",
2020
"requests>=2.31.0",
21+
"psutil>=5.9.0",
2122
]
2223

2324
[project.optional-dependencies]

src/policyengine/core/cache.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import logging
2+
from collections import OrderedDict
3+
4+
import psutil
5+
6+
logger = logging.getLogger(__name__)
7+
8+
_MEMORY_THRESHOLDS_GB = [8, 16, 32]
9+
_warned_thresholds: set[int] = set()
10+
11+
12+
class LRUCache[T]:
13+
"""Least-recently-used cache with configurable size limit and memory monitoring."""
14+
15+
def __init__(self, max_size: int = 100):
16+
self._max_size = max_size
17+
self._cache: OrderedDict[str, T] = OrderedDict()
18+
19+
def get(self, key: str) -> T | None:
20+
"""Get item from cache, marking it as recently used."""
21+
if key not in self._cache:
22+
return None
23+
self._cache.move_to_end(key)
24+
return self._cache[key]
25+
26+
def add(self, key: str, value: T) -> None:
27+
"""Add item to cache with LRU eviction when full."""
28+
if key in self._cache:
29+
self._cache.move_to_end(key)
30+
else:
31+
self._cache[key] = value
32+
if len(self._cache) > self._max_size:
33+
self._cache.popitem(last=False)
34+
35+
self._check_memory_usage()
36+
37+
def clear(self) -> None:
38+
"""Clear all items from cache."""
39+
self._cache.clear()
40+
_warned_thresholds.clear()
41+
42+
def __len__(self) -> int:
43+
return len(self._cache)
44+
45+
def _check_memory_usage(self) -> None:
46+
"""Check memory usage and warn at threshold crossings."""
47+
process = psutil.Process()
48+
memory_gb = process.memory_info().rss / (1024**3)
49+
50+
for threshold in _MEMORY_THRESHOLDS_GB:
51+
if memory_gb >= threshold and threshold not in _warned_thresholds:
52+
logger.warning(
53+
f"Memory usage has reached {memory_gb:.2f}GB (threshold: {threshold}GB). "
54+
f"Cache contains {len(self._cache)} items."
55+
)
56+
_warned_thresholds.add(threshold)

src/policyengine/core/simulation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33

44
from pydantic import BaseModel, Field
55

6+
from .cache import LRUCache
67
from .dataset import Dataset
78
from .dynamic import Dynamic
89
from .policy import Policy
910
from .tax_benefit_model_version import TaxBenefitModelVersion
1011

12+
_cache: LRUCache["Simulation"] = LRUCache(max_size=100)
13+
1114

1215
class Simulation(BaseModel):
1316
id: str = Field(default_factory=lambda: str(uuid4()))
@@ -25,12 +28,17 @@ def run(self):
2528
self.tax_benefit_model_version.run(self)
2629

2730
def ensure(self):
31+
cached_result = _cache.get(self.id)
32+
if cached_result:
33+
return cached_result
2834
try:
2935
self.tax_benefit_model_version.load(self)
3036
except Exception:
3137
self.run()
3238
self.save()
3339

40+
_cache.add(self.id, self)
41+
3442
def save(self):
3543
"""Save the simulation's output dataset."""
3644
self.tax_benefit_model_version.save(self)

tests/test_cache.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import os
2+
import tempfile
3+
4+
import pandas as pd
5+
from microdf import MicroDataFrame
6+
7+
from policyengine.core import Simulation
8+
from policyengine.core.cache import LRUCache
9+
from policyengine.tax_benefit_models.uk import (
10+
PolicyEngineUKDataset,
11+
UKYearData,
12+
uk_latest,
13+
)
14+
15+
16+
def test_simulation_cache_hit():
    """Test that simulation caching works with UK simulations."""
    people = pd.DataFrame(
        {
            "person_id": [1, 2, 3],
            "benunit_id": [1, 1, 2],
            "household_id": [1, 1, 2],
            "age": [30, 25, 40],
            "employment_income": [50000, 30000, 60000],
            "person_weight": [1.0, 1.0, 1.0],
        }
    )
    benunits = pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]})
    households = pd.DataFrame(
        {"household_id": [1, 2], "household_weight": [1.0, 1.0]}
    )

    person_df = MicroDataFrame(people, weights="person_weight")
    benunit_df = MicroDataFrame(benunits, weights="benunit_weight")
    household_df = MicroDataFrame(households, weights="household_weight")

    with tempfile.TemporaryDirectory() as tmpdir:
        dataset = PolicyEngineUKDataset(
            name="Test",
            description="Test dataset",
            filepath=os.path.join(tmpdir, "test.h5"),
            year=2024,
            data=UKYearData(
                person=person_df, benunit=benunit_df, household=household_df
            ),
        )

        simulation = Simulation(
            dataset=dataset,
            tax_benefit_model_version=uk_latest,
            output_dataset=dataset,
        )

        # Import the module-level cache used by Simulation.ensure().
        from policyengine.core.simulation import _cache

        # Manually add to cache (simulating what ensure() does on a miss).
        _cache.add(simulation.id, simulation)

        # The simulation is present and retrievable as the same object.
        assert simulation.id in _cache._cache
        assert len(_cache) >= 1
        assert _cache.get(simulation.id) is simulation

        # Clear cache for other tests.
        _cache.clear()
87+
88+
89+
def test_lru_cache_eviction():
    """Test that LRU cache properly evicts old items."""
    cache = LRUCache[str](max_size=3)

    for letter in ("a", "b", "c"):
        cache.add(letter, f"value_{letter}")

    assert len(cache) == 3
    # Touching 'a' makes 'b' the least recently used entry.
    assert cache.get("a") == "value_a"

    # A fourth item overflows the cache and evicts 'b'.
    cache.add("d", "value_d")

    assert len(cache) == 3
    assert cache.get("b") is None
    for letter in ("a", "c", "d"):
        assert cache.get(letter) == f"value_{letter}"
108+
109+
110+
def test_lru_cache_access_updates_order():
    """Test that accessing items updates their position in LRU order."""
    cache = LRUCache[str](max_size=3)

    for letter in ("a", "b", "c"):
        cache.add(letter, f"value_{letter}")

    # Reading 'a' promotes it, leaving 'b' as the oldest entry.
    cache.get("a")

    # Inserting a fourth item therefore evicts 'b'.
    cache.add("d", "value_d")

    assert cache.get("a") == "value_a"
    assert cache.get("b") is None
    assert cache.get("c") == "value_c"
    assert cache.get("d") == "value_d"
128+
129+
130+
def test_lru_cache_clear():
    """Test that clearing cache works properly."""
    cache = LRUCache[str](max_size=10)

    for key in ("a", "b", "c"):
        cache.add(key, f"value_{key}")

    assert len(cache) == 3

    cache.clear()

    # The cache is empty and every previous key now misses.
    assert len(cache) == 0
    for key in ("a", "b", "c"):
        assert cache.get(key) is None

0 commit comments

Comments
 (0)