Skip to content

Commit e7759d8

Browse files
authored
Merge pull request #3 from AdvancedPhotonSource/tests
Add tests to the project
2 parents 63e3dbf + e15925f commit e7759d8

File tree

10 files changed

+2075
-8
lines changed

10 files changed

+2075
-8
lines changed

.github/workflows/tests.yml

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
# -----------------------------------------------------------------------
11+
# Downloader tests (~1 min)
12+
# -----------------------------------------------------------------------
13+
test-downloader:
14+
runs-on: ubuntu-latest
15+
defaults:
16+
run:
17+
working-directory: src/downloader
18+
steps:
19+
- uses: actions/checkout@v4
20+
21+
- uses: actions/setup-python@v5
22+
with:
23+
python-version: "3.12"
24+
25+
- name: Install dependencies
26+
run: |
27+
pip install -r requirements.txt
28+
pip install pytest pyyaml
29+
30+
- name: Run tests
31+
run: pytest tests/ -v --tb=short
32+
33+
# -----------------------------------------------------------------------
34+
# Simulator tests - pure Python logic only (~1 min)
35+
# -----------------------------------------------------------------------
36+
test-simulator:
37+
runs-on: ubuntu-latest
38+
defaults:
39+
run:
40+
working-directory: src/simulator
41+
steps:
42+
- uses: actions/checkout@v4
43+
44+
- uses: actions/setup-python@v5
45+
with:
46+
python-version: "3.12"
47+
48+
- name: Install dependencies
49+
run: pip install numpy pytest pymatgen
50+
51+
- name: Run tests
52+
run: pytest tests/ -v --tb=short
53+
54+
# -----------------------------------------------------------------------
55+
# Trainer tests (~3 min)
56+
# -----------------------------------------------------------------------
57+
test-trainer:
58+
runs-on: ubuntu-latest
59+
defaults:
60+
run:
61+
working-directory: src/trainer
62+
steps:
63+
- uses: actions/checkout@v4
64+
65+
- uses: actions/setup-python@v5
66+
with:
67+
python-version: "3.12"
68+
69+
- name: Install CPU-only PyTorch and dependencies
70+
run: |
71+
pip install torch --index-url https://download.pytorch.org/whl/cpu
72+
pip install pytorch-lightning numpy pyyaml tqdm matplotlib scikit-learn pytest
73+
74+
- name: Run tests
75+
run: pytest tests/ -v --tb=short
76+
77+
# -----------------------------------------------------------------------
78+
# UI backend tests (~3 min)
79+
# -----------------------------------------------------------------------
80+
test-ui-backend:
81+
runs-on: ubuntu-latest
82+
defaults:
83+
run:
84+
working-directory: src/ui
85+
steps:
86+
- uses: actions/checkout@v4
87+
88+
- uses: actions/setup-python@v5
89+
with:
90+
python-version: "3.12"
91+
92+
- name: Install CPU-only PyTorch and dependencies
93+
run: |
94+
pip install torch --index-url https://download.pytorch.org/whl/cpu
95+
pip install -r requirements.txt
96+
pip install pytest httpx
97+
98+
- name: Run tests
99+
run: pytest tests/ -v --tb=short
100+
101+
# -----------------------------------------------------------------------
102+
# UI frontend tests (~2 min)
103+
# -----------------------------------------------------------------------
104+
test-ui-frontend:
105+
runs-on: ubuntu-latest
106+
defaults:
107+
run:
108+
working-directory: src/ui/frontend
109+
steps:
110+
- uses: actions/checkout@v4
111+
112+
- uses: actions/setup-node@v4
113+
with:
114+
node-version: "18"
115+
cache: "npm"
116+
cache-dependency-path: src/ui/frontend/package-lock.json
117+
118+
- name: Install dependencies
119+
run: npm ci
120+
121+
- name: Install Vitest
122+
run: npm install --save-dev vitest
123+
124+
- name: Run tests
125+
run: npx vitest run --reporter=verbose

pytest.ini

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[pytest]
2+
# Root-level pytest config ensures rootdir is the repo root,
3+
# preventing src/trainer/pyproject.toml from being used as config.
4+
testpaths =
5+
src/downloader/tests
6+
src/simulator/tests
7+
src/trainer/tests
8+
src/ui/tests
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""Unit tests for the downloader module."""
2+
3+
import os
4+
import sys
5+
import tempfile
6+
from pathlib import Path
7+
from unittest.mock import patch, MagicMock
8+
9+
import pytest
10+
import yaml
11+
12+
# Add parent directory to path so we can import the module
13+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
14+
15+
from downloader import (
16+
load_config,
17+
ensure_output_dir,
18+
get_api_key,
19+
passes_filters,
20+
_get_crystal_system_from_sg,
21+
is_spacegroup_stable,
22+
_write_cif_from_struct,
23+
)
24+
25+
26+
# ---------------------------------------------------------------------------
27+
# load_config
28+
# ---------------------------------------------------------------------------
29+
class TestLoadConfig:
30+
def test_valid_yaml(self, tmp_path):
31+
cfg_file = tmp_path / "config.yaml"
32+
cfg_file.write_text("output_directory: /data/raw_cif\nfilters:\n max_atoms: 500\n")
33+
cfg = load_config(cfg_file)
34+
assert cfg["output_directory"] == "/data/raw_cif"
35+
assert cfg["filters"]["max_atoms"] == 500
36+
37+
def test_empty_yaml_returns_empty_dict(self, tmp_path):
38+
cfg_file = tmp_path / "empty.yaml"
39+
cfg_file.write_text("")
40+
cfg = load_config(cfg_file)
41+
assert cfg == {}
42+
43+
def test_missing_file_raises(self, tmp_path):
44+
with pytest.raises(FileNotFoundError):
45+
load_config(tmp_path / "nonexistent.yaml")
46+
47+
48+
# ---------------------------------------------------------------------------
49+
# ensure_output_dir
50+
# ---------------------------------------------------------------------------
51+
class TestEnsureOutputDir:
52+
def test_creates_directory(self, tmp_path):
53+
new_dir = tmp_path / "a" / "b" / "c"
54+
assert not new_dir.exists()
55+
ensure_output_dir(new_dir)
56+
assert new_dir.is_dir()
57+
58+
def test_existing_directory_no_error(self, tmp_path):
59+
ensure_output_dir(tmp_path) # already exists
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# get_api_key
64+
# ---------------------------------------------------------------------------
65+
class TestGetApiKey:
66+
def test_returns_key_from_env(self, monkeypatch):
67+
monkeypatch.setenv("MP_API_KEY", "test-key-123")
68+
assert get_api_key() == "test-key-123"
69+
70+
def test_raises_when_missing(self, monkeypatch):
71+
monkeypatch.delenv("MP_API_KEY", raising=False)
72+
with pytest.raises(RuntimeError, match="MP_API_KEY"):
73+
get_api_key()
74+
75+
def test_raises_when_empty(self, monkeypatch):
76+
monkeypatch.setenv("MP_API_KEY", " ")
77+
with pytest.raises(RuntimeError, match="MP_API_KEY"):
78+
get_api_key()
79+
80+
81+
# ---------------------------------------------------------------------------
82+
# _get_crystal_system_from_sg
83+
# ---------------------------------------------------------------------------
84+
class TestGetCrystalSystemFromSg:
85+
@pytest.mark.parametrize(
86+
"sg_num, expected",
87+
[
88+
(1, 1), # Triclinic
89+
(2, 1), # Triclinic boundary
90+
(3, 2), # Monoclinic
91+
(15, 2), # Monoclinic boundary
92+
(16, 3), # Orthorhombic
93+
(74, 3), # Orthorhombic boundary
94+
(75, 4), # Tetragonal
95+
(142, 4), # Tetragonal boundary
96+
(143, 5), # Trigonal
97+
(167, 5), # Trigonal boundary
98+
(168, 6), # Hexagonal
99+
(194, 6), # Hexagonal boundary
100+
(195, 7), # Cubic
101+
(230, 7), # Cubic boundary
102+
],
103+
)
104+
def test_valid_space_groups(self, sg_num, expected):
105+
assert _get_crystal_system_from_sg(sg_num) == expected
106+
107+
def test_invalid_returns_none(self):
108+
assert _get_crystal_system_from_sg(231) is None
109+
assert _get_crystal_system_from_sg(None) is None
110+
assert _get_crystal_system_from_sg("abc") is None
111+
112+
def test_zero_and_negative_map_to_triclinic(self):
113+
# Code treats sg_num <= 2 as Triclinic (no lower-bound guard)
114+
assert _get_crystal_system_from_sg(0) == 1
115+
assert _get_crystal_system_from_sg(-1) == 1
116+
117+
118+
# ---------------------------------------------------------------------------
119+
# passes_filters
120+
# ---------------------------------------------------------------------------
121+
class TestPassesFilters:
122+
def _make_mock_structure(self, num_atoms=10, volume=100.0):
123+
"""Create a mock structure with controllable atom count and volume."""
124+
mock = MagicMock()
125+
mock.__len__ = MagicMock(return_value=num_atoms)
126+
mock.volume = volume
127+
return mock
128+
129+
def test_no_filters_passes(self):
130+
struct = self._make_mock_structure()
131+
assert passes_filters(struct, {}) is True
132+
133+
def test_max_atoms_pass(self):
134+
struct = self._make_mock_structure(num_atoms=100)
135+
assert passes_filters(struct, {"max_atoms": 500}) is True
136+
137+
def test_max_atoms_fail(self):
138+
struct = self._make_mock_structure(num_atoms=600)
139+
assert passes_filters(struct, {"max_atoms": 500}) is False
140+
141+
def test_max_atoms_boundary(self):
142+
struct = self._make_mock_structure(num_atoms=500)
143+
assert passes_filters(struct, {"max_atoms": 500}) is True
144+
145+
def test_min_volume_pass(self):
146+
struct = self._make_mock_structure(volume=200.0)
147+
assert passes_filters(struct, {"min_volume": 100.0}) is True
148+
149+
def test_min_volume_fail(self):
150+
struct = self._make_mock_structure(volume=50.0)
151+
assert passes_filters(struct, {"min_volume": 100.0}) is False
152+
153+
def test_max_volume_pass(self):
154+
struct = self._make_mock_structure(volume=500.0)
155+
assert passes_filters(struct, {"max_volume": 1000.0}) is True
156+
157+
def test_max_volume_fail(self):
158+
struct = self._make_mock_structure(volume=1500.0)
159+
assert passes_filters(struct, {"max_volume": 1000.0}) is False
160+
161+
def test_combined_filters(self):
162+
struct = self._make_mock_structure(num_atoms=100, volume=500.0)
163+
filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0}
164+
assert passes_filters(struct, filters) is True
165+
166+
def test_combined_filters_fail_atoms(self):
167+
struct = self._make_mock_structure(num_atoms=300, volume=500.0)
168+
filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0}
169+
assert passes_filters(struct, filters) is False
170+
171+
172+
# ---------------------------------------------------------------------------
173+
# is_spacegroup_stable
174+
# ---------------------------------------------------------------------------
175+
class TestIsSpacegroupStable:
176+
def test_all_same(self):
177+
grid = [[225, 225], [225, 225]]
178+
assert is_spacegroup_stable(grid) is True
179+
180+
def test_different_values(self):
181+
grid = [[225, 225], [225, 226]]
182+
assert is_spacegroup_stable(grid) is False
183+
184+
def test_single_value(self):
185+
grid = [[225]]
186+
assert is_spacegroup_stable(grid) is True
187+
188+
def test_empty_grid(self):
189+
grid = [[]]
190+
assert is_spacegroup_stable(grid) is False
191+
192+
def test_all_zeros(self):
193+
grid = [[0, 0], [0, 0]]
194+
assert is_spacegroup_stable(grid) is True
195+
196+
def test_zero_and_nonzero(self):
197+
grid = [[0, 225], [225, 225]]
198+
assert is_spacegroup_stable(grid) is False
199+
200+
201+
# ---------------------------------------------------------------------------
202+
# _write_cif_from_struct
203+
# ---------------------------------------------------------------------------
204+
class TestWriteCifFromStruct:
205+
def test_writes_file_with_comment(self, tmp_path):
206+
"""Test that CIF writing adds space group comment."""
207+
mock_structure = MagicMock()
208+
out_path = tmp_path / "test.cif"
209+
210+
# Mock CifWriter to write a minimal CIF file
211+
with patch("downloader.CifWriter") as MockCifWriter:
212+
mock_writer = MagicMock()
213+
MockCifWriter.return_value = mock_writer
214+
215+
def write_side_effect(path):
216+
with open(path, "w") as f:
217+
f.write("data_test\n_cell_length_a 5.0\n")
218+
219+
mock_writer.write_file.side_effect = write_side_effect
220+
221+
_write_cif_from_struct(mock_structure, out_path, sg_num=225, sg_symbol="Fm-3m")
222+
223+
assert out_path.exists()
224+
content = out_path.read_text()
225+
assert "Fm-3m" in content
226+
assert "_original_symmetry_space_group_name_H-M" in content

0 commit comments

Comments
 (0)