From 19a18ba0d3c0422a22956800d80f0534c03242d4 Mon Sep 17 00:00:00 2001 From: kshitij-maths Date: Mon, 25 May 2026 11:55:34 +0200 Subject: [PATCH 1/5] test: automatic_shift --- ezyrb/plugin/automatic_shift.py | 2 +- tests/test_automatic_shift.py | 276 ++++++++++++++++++++++++++++++++ 2 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 tests/test_automatic_shift.py diff --git a/ezyrb/plugin/automatic_shift.py b/ezyrb/plugin/automatic_shift.py index 3c74685a..39e1e9e1 100644 --- a/ezyrb/plugin/automatic_shift.py +++ b/ezyrb/plugin/automatic_shift.py @@ -1,7 +1,7 @@ """Module for Scaler plugin""" import numpy as np - +import torch from ezyrb import Database, Snapshot, Parameter from .plugin import Plugin diff --git a/tests/test_automatic_shift.py b/tests/test_automatic_shift.py new file mode 100644 index 00000000..923e842c --- /dev/null +++ b/tests/test_automatic_shift.py @@ -0,0 +1,276 @@ +import numpy as np +import pytest +import torch +import torch.nn as nn +from unittest import TestCase +from unittest.mock import Mock + +from ezyrb import Database, Parameter, Snapshot +from ezyrb.plugin.automatic_shift import AutomaticShiftSnapshots + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.dummy_param = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + val = x.view(x.shape[0], -1).sum(dim=1, keepdim=True) * 0.0 + return val + self.dummy_param + + +class SimpleANN: + def __init__(self, stop_training=None): + self.model = DummyModel() + self.lr = 0.01 + self.l2_regularization = 0.0 + self.loss_trend = [] + self.stop_training = stop_training if stop_training else [1] + self.frequency_print = 100 + + def _build_model(self, x, y): + pass + + def fit(self, x, y): + pass + + def optimizer(self, params, lr, weight_decay): + return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay) + + def predict(self, x): + return np.zeros((x.shape[0], 1)) + + +class SimpleInterpolator: + def fit(self, x, y): + self.x_fit = np.asarray(x) + self.y_fit = np.asarray(y) + + def predict(self, x): + if hasattr(self, 'y_fit'): + return np.full((x.shape[0],), self.y_fit.mean()) + return np.zeros((x.shape[0],)) + + +class MockROM: + def __init__(self, db): + self.database = db + self.predict_full_database = db + self._full_database = None + + +class TestAutomaticShiftSnapshots(TestCase): + + def setUp(self): + self.space = np.array([0.0, 1.0, 2.0]) + self.db = Database() + + snap1 = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space.copy()) + snap2 = Snapshot(values=np.array([2.0, 3.0, 4.0]), space=self.space.copy()) + snap3 = Snapshot(values=np.array([3.0, 4.0, 5.0]), space=self.space.copy()) + + self.db.add(Parameter([1.0]), snap1) + self.db.add(Parameter([2.0]), snap2) + self.db.add(Parameter([3.0]), snap3) + + self.rom = MockROM(self.db) + + def test_constructor_stores_parameters(self): + shift_net = SimpleANN() + interp_net = SimpleANN() + interpolator = SimpleInterpolator() + + plugin = AutomaticShiftSnapshots( + shift_network=shift_net, + interp_network=interp_net, + interpolator=interpolator, + parameter_index=1, + reference_index=2, + barycenter_loss=5.0, + ) + + self.assertIs(plugin.shift_network, shift_net) + self.assertIs(plugin.interp_network, interp_net) + self.assertIs(plugin.interpolator, interpolator) + self.assertEqual(plugin.parameter_index, 1) + self.assertEqual(plugin.reference_index, 2) + self.assertEqual(plugin.barycenter_loss, 5.0) + + def test_fit_preprocessing_sets_reference_snapshot(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + reference_index=1, + ) + + plugin.fit_preprocessing(self.rom) + + expected_snap = self.db._pairs[1][1] + np.testing.assert_array_equal(plugin.reference_snapshot.values, expected_snap.values) + np.testing.assert_array_equal(plugin.reference_snapshot.space, expected_snap.space) + + def test_fit_preprocessing_calls_train_interp_network(self): + shift_net = SimpleANN() + interp_net = SimpleANN() + interp_net.fit = Mock() + + plugin = AutomaticShiftSnapshots( + shift_network=shift_net, + interp_network=interp_net, + interpolator=SimpleInterpolator(), + reference_index=0, + ) + + plugin.fit_preprocessing(self.rom) + + interp_net.fit.assert_called_once() + args, _ = interp_net.fit.call_args + np.testing.assert_array_equal(args[0], self.space.reshape(-1, 1)) + + def test_fit_preprocessing_calls_train_shift_network(self): + shift_net = SimpleANN() + shift_net._build_model = Mock() + + plugin = AutomaticShiftSnapshots( + shift_network=shift_net, + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + + plugin.fit_preprocessing(self.rom) + shift_net._build_model.assert_called_once() + + def test_fit_preprocessing_modifies_snapshots(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + self.assertIsNotNone(self.db._pairs[0][1].values) + + def test_fit_preprocessing_with_barycenter_loss_zero(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + barycenter_loss=0.0, + ) + plugin.fit_preprocessing(self.rom) + self.assertIsNotNone(plugin.reference_snapshot) + + def test_fit_preprocessing_with_barycenter_loss_nonzero(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + barycenter_loss=10.0, + ) + plugin.fit_preprocessing(self.rom) + self.assertIsNotNone(plugin.reference_snapshot) + + def test_predict_postprocessing_creates_full_database(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + plugin.predict_postprocessing(self.rom) + + self.assertIsInstance(self.rom._full_database, Database) + + def test_predict_postprocessing_preserves_snapshot_count(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + original_count = len(self.rom.predict_full_database) + plugin.predict_postprocessing(self.rom) + + self.assertEqual(len(self.rom._full_database), original_count) + + def test_predict_postprocessing_modifies_space(self): + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + plugin.predict_postprocessing(self.rom) + + new_spaces = [snap.space.copy() for _, snap in self.rom._full_database._pairs] + for new_space in new_spaces: + self.assertEqual(len(new_space), len(self.space)) + + def test_stop_training_integer_criterion(self): + shift_net = SimpleANN(stop_training=[2]) + plugin = AutomaticShiftSnapshots( + shift_network=shift_net, + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + self.assertEqual(len(shift_net.loss_trend), 2) + + def test_stop_training_float_criterion(self): + shift_net = SimpleANN(stop_training=[100.0]) + plugin = AutomaticShiftSnapshots( + shift_network=shift_net, + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(self.rom) + self.assertGreaterEqual(len(shift_net.loss_trend), 1) + + def test_single_snapshot_database(self): + db = Database() + snap = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space) + db.add(Parameter([1.0]), snap) + rom = MockROM(db) + + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + ) + plugin.fit_preprocessing(rom) + plugin.predict_postprocessing(rom) + self.assertEqual(len(rom._full_database), 1) + + def test_reference_index_boundary(self): + db = Database() + for i in range(5): + snap = Snapshot(values=np.array([float(i)]), space=np.array([0.5])) + db.add(Parameter([float(i)]), snap) + + rom = MockROM(db) + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + reference_index=4, + ) + plugin.fit_preprocessing(rom) + self.assertEqual(plugin.reference_snapshot.values[0], 4.0) + + def test_multidimensional_parameters_raise_valueerror(self): + db = Database() + snap1 = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space) + db.add(Parameter([1.0, 10.0]), snap1) + rom = MockROM(db) + + plugin = AutomaticShiftSnapshots( + shift_network=SimpleANN(), + interp_network=SimpleANN(), + interpolator=SimpleInterpolator(), + parameter_index=1, + ) + with self.assertRaises(ValueError): + plugin.fit_preprocessing(rom) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file From 5afabb310f75a8f82f3a7d37c859c5caa7e0dcdf Mon Sep 17 00:00:00 2001 From: kshitij-maths Date: Mon, 25 May 2026 11:58:52 +0200 Subject: [PATCH 2/5] test: database_splitter --- tests/test_database_splitter.py | 225 ++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 tests/test_database_splitter.py diff --git a/tests/test_database_splitter.py b/tests/test_database_splitter.py new file mode 100644 index 00000000..dab699e5 --- /dev/null +++ b/tests/test_database_splitter.py @@ -0,0 +1,225 @@ +import numpy as np +from unittest import TestCase +from ezyrb import Database +from ezyrb.plugin.database_splitter import DatabaseSplitter, DatabaseDictionarySplitter + +class DummyROM: + train_full_database = None + test_full_database = None + validation_full_database = None + predict_full_database = None + + def __init__(self, db): + self._database = db + + +class TestDatabaseSplitter(TestCase): + + def test_split_integers_train_size(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0) + splitter.fit_preprocessing(rom) + self.assertEqual(len(rom.train_full_database), 80) + + def test_split_integers_test_size(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0) + splitter.fit_preprocessing(rom) + self.assertEqual(len(rom.test_full_database), 20) + + def test_split_integers_validation_predict_empty(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0) + splitter.fit_preprocessing(rom) + self.assertEqual(len(rom.validation_full_database), 0) + self.assertEqual(len(rom.predict_full_database), 0) + + def test_split_integers_total_conserved(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=70, test=20, validation=5, predict=5) + splitter.fit_preprocessing(rom) + total = (len(rom.train_full_database) + + len(rom.test_full_database) + + len(rom.validation_full_database) + + len(rom.predict_full_database)) + self.assertEqual(total, 100) + + def test_split_integers_returns_database_instances(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0) + splitter.fit_preprocessing(rom) + self.assertIsInstance(rom.train_full_database, Database) + self.assertIsInstance(rom.test_full_database, Database) + self.assertIsInstance(rom.validation_full_database, Database) + self.assertIsInstance(rom.predict_full_database, Database) + + def test_split_integers_inconsistent_chunks_raises(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=70, test=20, validation=0, predict=0) + with self.assertRaises(ValueError): + splitter.fit_preprocessing(rom) + + + def test_split_floats_total_conserved(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=0.7, test=0.2, validation=0.05, + predict=0.05, seed=0) + splitter.fit_preprocessing(rom) + total = (len(rom.train_full_database) + + len(rom.test_full_database) + + len(rom.validation_full_database) + + len(rom.predict_full_database)) + self.assertEqual(total, 100) + + def test_split_floats_returns_database_instances(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=0.8, test=0.2, seed=0) + splitter.fit_preprocessing(rom) + self.assertIsInstance(rom.train_full_database, Database) + self.assertIsInstance(rom.test_full_database, Database) + + def test_split_floats_inconsistent_ratios_raises(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom = DummyROM(db) + splitter = DatabaseSplitter(train=0.7, test=0.2, validation=0.0, + predict=0.0) + with self.assertRaises(ValueError): + splitter.fit_preprocessing(rom) + + + def test_split_seed_reproducibility(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom1 = DummyROM(db) + DatabaseSplitter(train=0.8, test=0.2, seed=42).fit_preprocessing(rom1) + + rom2 = DummyROM(db) + DatabaseSplitter(train=0.8, test=0.2, seed=42).fit_preprocessing(rom2) + + np.testing.assert_array_equal( + rom1.train_full_database.parameters_matrix, + rom2.train_full_database.parameters_matrix, + ) + + def test_split_different_seeds_differ(self): + db = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + rom1 = DummyROM(db) + DatabaseSplitter(train=0.8, test=0.2, seed=0).fit_preprocessing(rom1) + + rom2 = DummyROM(db) + DatabaseSplitter(train=0.8, test=0.2, seed=99).fit_preprocessing(rom2) + + with self.assertRaises(AssertionError): + np.testing.assert_array_equal( + rom1.train_full_database.parameters_matrix, + rom2.train_full_database.parameters_matrix, + ) + + def test_split_dict_database_explicit_flattening(self): + db_a = Database(np.random.uniform(size=(100, 2)), + np.random.uniform(size=(100, 5))) + db_b = Database(np.random.uniform(size=(50, 2)), + np.random.uniform(size=(50, 5))) + + rom = DummyROM({'a': db_a, 'b': db_b}) + + splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0) + splitter.fit_preprocessing(rom) + + self.assertIsInstance(rom.train_full_database, Database) + + self.assertEqual(len(rom.train_full_database), 80) + + +class TestDatabaseDictionarySplitter(TestCase): + + def _make_dict_rom(self): + db_train = Database(np.random.uniform(size=(60, 2)), + np.random.uniform(size=(60, 5))) + db_test = Database(np.random.uniform(size=(20, 2)), + np.random.uniform(size=(20, 5))) + db_val = Database(np.random.uniform(size=(10, 2)), + np.random.uniform(size=(10, 5))) + db_pred = Database(np.random.uniform(size=(10, 2)), + np.random.uniform(size=(10, 5))) + db_dict = { + 'train': db_train, + 'test': db_test, + 'val': db_val, + 'pred': db_pred, + } + return DummyROM(db_dict), db_dict + + def test_train_key_assigned(self): + rom, db_dict = self._make_dict_rom() + DatabaseDictionarySplitter(train_key='train').fit_preprocessing(rom) + self.assertEqual(len(rom.train_full_database), 60) + + def test_test_key_assigned(self): + rom, db_dict = self._make_dict_rom() + DatabaseDictionarySplitter(test_key='test').fit_preprocessing(rom) + self.assertEqual(len(rom.test_full_database), 20) + + def test_validation_key_assigned(self): + rom, db_dict = self._make_dict_rom() + DatabaseDictionarySplitter(validation_key='val').fit_preprocessing(rom) + self.assertEqual(len(rom.validation_full_database), 10) + + def test_predict_key_assigned(self): + rom, db_dict = self._make_dict_rom() + DatabaseDictionarySplitter(predict_key='pred').fit_preprocessing(rom) + self.assertEqual(len(rom.predict_full_database), 10) + + def test_all_keys_assigned(self): + rom, db_dict = self._make_dict_rom() + splitter = DatabaseDictionarySplitter( + train_key='train', test_key='test', + validation_key='val', predict_key='pred', + ) + splitter.fit_preprocessing(rom) + self.assertEqual(len(rom.train_full_database), 60) + self.assertEqual(len(rom.test_full_database), 20) + self.assertEqual(len(rom.validation_full_database), 10) + self.assertEqual(len(rom.predict_full_database), 10) + + def test_assigned_database_is_same_object(self): + rom, db_dict = self._make_dict_rom() + splitter = DatabaseDictionarySplitter( + train_key='train', test_key='test', + ) + splitter.fit_preprocessing(rom) + self.assertIs(rom.train_full_database, db_dict['train']) + self.assertIs(rom.test_full_database, db_dict['test']) + + def test_unset_key_leaves_attribute_none(self): + rom, _ = self._make_dict_rom() + DatabaseDictionarySplitter(train_key='train').fit_preprocessing(rom) + self.assertIsNone(rom.test_full_database) + self.assertIsNone(rom.validation_full_database) + self.assertIsNone(rom.predict_full_database) + + def test_non_dict_database_raises(self): + db = Database(np.random.uniform(size=(10, 2)), + np.random.uniform(size=(10, 5))) + rom = DummyROM(db) + splitter = DatabaseDictionarySplitter(train_key='train') + with self.assertRaises(ValueError): + splitter.fit_preprocessing(rom) \ No newline at end of file From d60bf6caedf0130c9db632b3ed97d3c0593d8e29 Mon Sep 17 00:00:00 2001 From: kshitij-maths Date: Mon, 25 May 2026 12:01:49 +0200 Subject: [PATCH 3/5] test: aggregation --- tests/test_aggregation.py | 279 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 tests/test_aggregation.py diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py new file mode 100644 index 00000000..78f8cf36 --- /dev/null +++ b/tests/test_aggregation.py @@ -0,0 +1,279 @@ +import copy +import unittest +import numpy as np +from unittest import TestCase +from ezyrb import Database, RBF +from ezyrb.approximation.linear import Linear +from ezyrb.reduction.pod import POD +from ezyrb.reducedordermodel import ReducedOrderModel as ROM +from ezyrb.reducedordermodel import MultiReducedOrderModel as MROM +from ezyrb.plugin.aggregation import Aggregation +from ezyrb.plugin.database_splitter import DatabaseSplitter + +class MockROM: + validation_full_database = None + + def __init__(self, db): + self.validation_full_database = db + + def predict(self, db): + return db + +class MockMROM: + train_full_database = None + validation_full_database = None + predict_full_database = None + multi_predict_database = None + weights_predict = None + + def __init__(self, db, n_roms=2): + self.roms = {f'rom{i}': MockROM(db) for i in range(n_roms)} + self.train_full_database = db + self.validation_full_database = db + self.predict_full_database = db + self.multi_predict_database = {f'rom{i}': db for i in range(n_roms)} + self.weights_predict = {} + + +def _make_unit_db(): + space = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + params = np.array([[0.5], [1.5]]) + snaps = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]) + return Database(params, snaps, space=space) + + +def _make_integration_db(n_params=5, n_space=3): + mu = np.linspace(0.5, 3.0, n_params) + x = np.linspace(0, 2 * np.pi, n_space) + snaps = np.array([np.sin(m * x) for m in mu]) + space = x.reshape(-1, 1) + return Database(mu.reshape(-1, 1), snaps, space=space) + +def _relative_error(predicted, actual): + norms = np.linalg.norm(actual, axis=1) + norms = np.where(norms < 1e-12, 1.0, norms) + return np.mean(np.linalg.norm(predicted - actual, axis=1) / norms) + +class TestAggregation(TestCase): + + def setUp(self): + self.db = _make_unit_db() + + def test_constructor_default_fit_function_is_none(self): + agg = Aggregation() + self.assertIsNone(agg.fit_function) + + def test_constructor_default_predict_function_is_linear(self): + agg = Aggregation() + self.assertIsInstance(agg.predict_function, Linear) + + def test_constructor_custom_arguments(self): + agg = Aggregation(fit_function=RBF(), predict_function=RBF()) + self.assertIsInstance(agg.fit_function, RBF) + self.assertIsInstance(agg.predict_function, RBF) + + + def test_check_sum_gaussians_partial_zeros(self): + agg = Aggregation() + mrom = MockMROM(self.db, n_roms=2) + gaussians = np.array([[0.0, 0.8], [0.0, 0.2]]) + res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy()) + np.testing.assert_array_equal(res[:, 0], [0.5, 0.5]) + np.testing.assert_array_equal(res[:, 1], [0.8, 0.2]) + + def test_check_sum_gaussians_no_zeros_unchanged(self): + agg = Aggregation() + mrom = MockMROM(self.db, n_roms=2) + gaussians = np.array([[0.3, 0.7], [0.6, 0.3]]) + original = gaussians.copy() + res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy()) + np.testing.assert_array_equal(res, original) + + def test_check_sum_gaussians_all_zeros(self): + agg = Aggregation() + mrom = MockMROM(self.db, n_roms=2) + gaussians = np.zeros((2, 3)) + res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy()) + np.testing.assert_array_equal(res, np.full((2, 3), 0.5)) + + def test_check_sum_gaussians_equal_weight_matches_n_roms(self): + n_roms = 4 + agg = Aggregation() + mrom = MockMROM(self.db, n_roms=n_roms) + gaussians = np.zeros((n_roms, 2)) + res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy()) + np.testing.assert_array_almost_equal(res, np.full((n_roms, 2), 1.0 / n_roms)) + + + def test_compute_validation_weights_perfect_prediction_values(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation() + g = agg._compute_validation_weights(mrom, sigma=1.0, normalized=False) + np.testing.assert_array_almost_equal(g, np.ones_like(g)) + + def test_compute_validation_weights_normalized_sums_to_one(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation() + g = agg._compute_validation_weights(mrom, sigma=1.0, normalized=True) + np.testing.assert_array_almost_equal(g.sum(axis=0), np.ones_like(g[0])) + + def test_compute_validation_weights_shape(self): + mrom = MockMROM(self.db, n_roms=3) + agg = Aggregation() + g = agg._compute_validation_weights(mrom, sigma=1.0) + self.assertEqual(g.shape[0], 3) + + def test_compute_validation_weights_sigma_effect(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation() + g_large = agg._compute_validation_weights(mrom, sigma=1e6, normalized=False) + g_small = agg._compute_validation_weights(mrom, sigma=1e-6, normalized=False) + np.testing.assert_array_almost_equal(g_large, np.ones_like(g_large)) + np.testing.assert_array_almost_equal(g_small, np.ones_like(g_small)) + + + def test_optimize_sigma_returns_finite_value(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation() + sigma = agg._optimize_sigma(mrom) + self.assertTrue(np.isfinite(sigma).all()) + + def test_optimize_sigma_within_default_range(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation() + sigma = agg._optimize_sigma(mrom) + self.assertGreaterEqual(float(sigma), 1e-5) + self.assertLessEqual(float(sigma), 1e-2) + + def test_aggregation_no_fit_function(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation(fit_function=None, predict_function=RBF()) + agg.fit_postprocessing(mrom) + agg.predict_postprocessing(mrom) + self.assertIsNotNone(mrom.predict_full_database) + self.assertEqual(len(agg.predict_functions), 2) + + def test_aggregation_with_fit_function(self): + mrom = MockMROM(self.db, n_roms=1) + agg = Aggregation(fit_function=RBF(), predict_function=RBF()) + agg.fit_postprocessing(mrom) + agg.predict_postprocessing(mrom) + self.assertIsNotNone(mrom.predict_full_database) + + def test_nan_handling_in_weights(self): + mrom = MockMROM(self.db, n_roms=2) + agg = Aggregation(fit_function=None, predict_function=RBF()) + agg._compute_validation_weights = ( + lambda mrom, sigma, normalized=False: np.full((2, 2, 3), np.nan) + ) + agg._optimize_sigma = lambda mrom: 1e-3 + agg.fit_postprocessing(mrom) + self.assertEqual(len(agg.predict_functions), 2) + + +class TestAggregationIntegration(TestCase): + + @classmethod + def setUpClass(cls): + cls.db = _make_integration_db(n_params=5, n_space=3) + + def _make_splitter(self, seed=0): + return DatabaseSplitter( + train=2, test=0, validation=2, predict=1, seed=seed + ) + + def _build_and_fit_mrom(self, agg, seed=0): + splitter = self._make_splitter(seed=seed) + rom1 = ROM(self.db, POD(rank=1), RBF()) + rom2 = ROM(self.db, POD(rank=1), Linear()) + agg._optimize_sigma = lambda mrom: 1e-3 + mrom = MROM( + {'rbf': rom1, 'lin': rom2}, + plugins=[splitter, agg], + rom_plugin=splitter, + ) + mrom.fit() + return mrom + + def test_fit_does_not_raise(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + self._build_and_fit_mrom(agg) + + def test_fit_regression_path_does_not_raise(self): + splitter = self._make_splitter() + rom1 = ROM(self.db, POD(rank=1), RBF()) + agg = Aggregation(fit_function=RBF(), predict_function=RBF()) + mrom = MROM({'rbf': rom1}, plugins=[splitter, agg], rom_plugin=splitter) + mrom.fit() + + def test_predict_returns_database_instance(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + mrom = self._build_and_fit_mrom(agg) + mrom.predict(mrom.predict_full_database) + self.assertIsInstance(mrom.predict_full_database, Database) + + def test_predict_snapshot_shape(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + mrom = self._build_and_fit_mrom(agg) + mrom.predict(mrom.predict_full_database) + self.assertEqual(mrom.predict_full_database.snapshots_matrix.shape[1], 3) + + def test_predict_functions_count_matches_n_roms(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + self._build_and_fit_mrom(agg) + self.assertEqual(len(agg.predict_functions), 2) + + def test_weights_are_finite(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + mrom = self._build_and_fit_mrom(agg) + mrom.predict(mrom.predict_full_database) + for key, w in mrom.weights_predict.items(): + self.assertTrue(np.isfinite(w).all(), + msg=f"Non-finite weight for ROM '{key}'") + + def test_weights_sum_to_one(self): + agg = Aggregation(fit_function=None, predict_function=RBF()) + mrom = self._build_and_fit_mrom(agg) + mrom.predict(mrom.predict_full_database) + weight_sum = np.sum(list(mrom.weights_predict.values()), axis=0) + np.testing.assert_array_almost_equal( + weight_sum, np.ones_like(weight_sum), decimal=5 + ) + + def test_fit_reproducible_with_same_seed(self): + agg1 = Aggregation(fit_function=None, predict_function=RBF()) + agg2 = Aggregation(fit_function=None, predict_function=RBF()) + mrom1 = self._build_and_fit_mrom(agg1, seed=7) + mrom2 = self._build_and_fit_mrom(agg2, seed=7) + + pred_db1 = copy.deepcopy(mrom1.predict_full_database) + pred_db2 = copy.deepcopy(mrom2.predict_full_database) + mrom1.predict(pred_db1) + mrom2.predict(pred_db2) + + np.testing.assert_array_almost_equal( + mrom1.predict_full_database.snapshots_matrix, + mrom2.predict_full_database.snapshots_matrix, + decimal=10, + ) + + def test_fit_different_seeds_produce_different_predictions(self): + agg1 = Aggregation(fit_function=None, predict_function=RBF()) + agg2 = Aggregation(fit_function=None, predict_function=RBF()) + mrom1 = self._build_and_fit_mrom(agg1, seed=0) + mrom2 = self._build_and_fit_mrom(agg2, seed=99) + + pred_db1 = copy.deepcopy(mrom1.predict_full_database) + pred_db2 = copy.deepcopy(mrom2.predict_full_database) + mrom1.predict(pred_db1) + mrom2.predict(pred_db2) + + with self.assertRaises(AssertionError): + np.testing.assert_array_almost_equal( + mrom1.predict_full_database.snapshots_matrix, + mrom2.predict_full_database.snapshots_matrix, + decimal=10, + ) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 191c437e3bcf6b27d8343d3824d5b00b9f3ff33d Mon Sep 17 00:00:00 2001 From: kshitij-maths Date: Mon, 25 May 2026 12:05:28 +0200 Subject: [PATCH 4/5] tests: plugging parallel tests --- tests/test_parallel/test_aggregation.py | 7 +++++++ tests/test_parallel/test_automatic_shift.py | 8 ++++++++ tests/test_parallel/test_database_splitter.py | 5 +++++ 3 files changed, 20 insertions(+) create mode 100644 tests/test_parallel/test_aggregation.py create mode 100644 tests/test_parallel/test_automatic_shift.py create mode 100644 tests/test_parallel/test_database_splitter.py diff --git a/tests/test_parallel/test_aggregation.py b/tests/test_parallel/test_aggregation.py new file mode 100644 index 00000000..513a314c --- /dev/null +++ b/tests/test_parallel/test_aggregation.py @@ -0,0 +1,7 @@ +import pytest +import ezyrb +from ezyrb.parallel import ReducedOrderModel as ParallelROM + +ezyrb.ReducedOrderModel = ParallelROM + +from tests.test_aggregation import TestAggregation \ No newline at end of file diff --git a/tests/test_parallel/test_automatic_shift.py b/tests/test_parallel/test_automatic_shift.py new file mode 100644 index 00000000..1e1d72e9 --- /dev/null +++ b/tests/test_parallel/test_automatic_shift.py @@ -0,0 +1,8 @@ +import pytest +import ezyrb +import torch +from ezyrb.parallel import ReducedOrderModel as ParallelROM + +ezyrb.ReducedOrderModel = ParallelROM + +from tests.test_automatic_shift import TestAutomaticShiftSnapshots \ No newline at end of file diff --git a/tests/test_parallel/test_database_splitter.py b/tests/test_parallel/test_database_splitter.py new file mode 100644 index 00000000..22191325 --- /dev/null +++ b/tests/test_parallel/test_database_splitter.py @@ -0,0 +1,5 @@ +import pytest +import ezyrb.parallel +import ezyrb.plugin.database_splitter + +from tests.test_database_splitter import TestDatabaseSplitter, TestDatabaseDictionarySplitter \ No newline at end of file From 643a578c5a80a4bdec14d9404567dc144785f744 Mon Sep 17 00:00:00 2001 From: kshitij-maths Date: Mon, 25 May 2026 12:40:34 +0200 Subject: [PATCH 5/5] add squeeze for numpy compatibility --- tests/test_aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 78f8cf36..b85f00f2 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -142,8 +142,8 @@ def test_optimize_sigma_within_default_range(self): mrom = MockMROM(self.db, n_roms=2) agg = Aggregation() sigma = agg._optimize_sigma(mrom) - self.assertGreaterEqual(float(sigma), 1e-5) - self.assertLessEqual(float(sigma), 1e-2) + self.assertGreaterEqual(float(np.squeeze(sigma)), 1e-5) + self.assertLessEqual(float(np.squeeze(sigma)), 1e-2) def test_aggregation_no_fit_function(self): mrom = MockMROM(self.db, n_roms=2)