-
Notifications
You must be signed in to change notification settings - Fork 0
🧪 Add unit tests for calculate_qwf_enhanced_qag_model_loss #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -529,6 +529,78 @@ | |
| "metadata": {} | ||
| } | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "unit_tests_markdown" | ||
| }, | ||
| "source": [ | ||
| "## UNIT TESTS" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "id": "unit_tests_code" | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import unittest\n", | ||
| "from unittest.mock import patch\n", | ||
| "import numpy as np\n", | ||
| "\n", | ||
| "class TestCalculateQWFEnhancedQAGModelLoss(unittest.TestCase):\n", | ||
| " def setUp(self):\n", | ||
| " self.params = {\n", | ||
| " 'Affinity_Constant': 1e-12,\n", | ||
| " 'Affinity_Base': 0.15,\n", | ||
| " 'Shielding_Exp_Factor': 1000.0,\n", | ||
| " 'Variance_Factor_1': 2.0,\n", | ||
| " 'Variance_Factor_2': 15.0,\n", | ||
| " 'Variance_Denom_Offset': 0.0001,\n", | ||
| " 'Variance_to_Speed_Scaling': 1e-12,\n", | ||
| " 'Resonance_Dampening_Factor': 1e-48,\n", | ||
| " 'Tunneling_Boost_Factor': 0.4,\n", | ||
| " 'Info_Recovery_Sigmoid_Factor': 0.1,\n", | ||
| " 'Info_Recovery_Time_Offset': 50.0,\n", | ||
| " 'AVI_Decay_Time_Factor': 10.0\n", | ||
| " }\n", | ||
| " self.target_data = {\n", | ||
| " 'GalacticRotation_MeanQAGSpeed': 2.83e+07,\n", | ||
| " 'AVICosmicExpansion_FinalScaleFactor': 2.00,\n", | ||
| " 'ParticleSignature_ResonanceStrength': 1e6\n", | ||
| " }\n", | ||
| "\n", | ||
| " def test_loss_happy_path(self):\n", | ||
| " # Should return a valid float loss\n", | ||
| " loss = calculate_qwf_enhanced_qag_model_loss(self.params, self.target_data)\n", | ||
| " self.assertIsInstance(loss, float)\n", | ||
| " self.assertGreaterEqual(loss, 0.0)\n", | ||
| " self.assertLess(loss, 1e10) # Assuming the normal loss is smaller than the large penalty\n", | ||
| "\n", | ||
| " @patch('__main__.solve_ivp')\n", | ||
| " def test_loss_integration_failure(self, mock_solve_ivp):\n", | ||
| " # Mocking an integration failure (e.g., status != 0)\n", | ||
| " class MockSol:\n", | ||
| " status = -1\n", | ||
| " y = np.array([[-1.0]])\n", | ||
| " mock_solve_ivp.return_value = MockSol()\n", | ||
|
Comment on lines
+586
to
+589
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of defining a nested class |
||
| " \n", | ||
| " loss = calculate_qwf_enhanced_qag_model_loss(self.params, self.target_data)\n", | ||
| " self.assertGreaterEqual(loss, 1e10)\n", | ||
| "\n", | ||
| " def test_loss_missing_params_uses_defaults(self):\n", | ||
| " # Empty params dictionary should use defaults from QAG_DEFINITIONS\n", | ||
| " empty_params = {}\n", | ||
| " loss = calculate_qwf_enhanced_qag_model_loss(empty_params, self.target_data)\n", | ||
| " self.assertIsInstance(loss, float)\n", | ||
| " self.assertGreaterEqual(loss, 0.0)\n", | ||
| "\n", | ||
| "if __name__ == '__main__':\n", | ||
| " unittest.main(argv=['first-arg-is-ignored'], exit=False)" | ||
| ] | ||
| } | ||
| ] | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value
1e10is a magic number representingLARGE_PENALTYfrom the function under test. It's also used intest_loss_integration_failure(line 592). To improve maintainability and avoid hardcoding this value in multiple places, consider defining it as a class constant, for example:Then you can use
self.LARGE_PENALTYin your assertions, making the tests cleaner and easier to update if the penalty value changes.