-
Notifications
You must be signed in to change notification settings - Fork 79
Accelerated ts.optimize by batching Frechet Cell Filter #439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Cursor <cursoragent@cursor.com>
|
Could we get some tests verifying identical numerical behavior of the old and new versions? Can be deleted before merging when we get rid of the unbatched version. |
|
@orionarcher I added |
torch_sim/math.py
Outdated
| num_tol = 1e-16 if dtype == torch.float64 else 1e-8 | ||
| batched = T.dim() == 3 | ||
|
|
||
| if batched: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why support both batched and unbatched versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I now removed all unbatched code
tests/test_math_frechet.py
Outdated
| class TestExpmFrechet: | ||
| """Tests for expm_frechet against scipy.linalg.expm_frechet.""" | ||
|
|
||
| def test_small_matrix(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we testing the batched or unbatched versions here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I incorporated the batched tests only in test_math.py
|
Haven't looked carefully on the implementation but I would potentially support to have 2 separate functions for batched (B, 3, 3) and unbatched (3,3) algorithms. This would also prevent graph breaks in the future, be easier to read, and in practice a state.cell is always (B, 3, 3), potentially with B=1. So we would always use the batched version anyway. |
9fd0eea to
db04cd6
Compare
Co-authored-by: Cursor <cursoragent@cursor.com>
db04cd6 to
ee790a3
Compare
|
@orionarcher I removed all unbatched and unused code while preserving the new performance speedups. Please see the PR description for a detailed list of changes. @thomasloux It’s indeed a good point, but for now it’s probably better to keep things clean and stick to the batched implementation only. By keeping only the batched implementation, we can remove quite a few lines of dead code. |
d6d5b46 to
06ebdae
Compare
Co-authored-by: Cursor <cursoragent@cursor.com>
1f467cf to
ff0e9f3
Compare
|
Again not checking the integration math, but noting that |
Update Summary
1. torch_sim/math.py
Removed unbatched/legacy code:
expm_frechet_block_enlarge(helper function for block enlargement method)_diff_pade3,_diff_pade5,_diff_pade7,_diff_pade9(Padé approximation helpers)expm_frechet_algo_64(original algorithm implementation)matrix_exp(custom matrix exponential function)vec,expm_frechet_kronform(Kronecker form helpers)expm_cond(condition number estimation)class expm(autograd Function class)_is_valid_matrix,_determine_eigenvalue_case(unbatched helpers)Refactored
expm_frechet:SPS"orblockEnlarge)Refactored
matrix_log_33:_ensure_batched,_determine_matrix_log_cases,_process_matrix_log_casehelpers2. torch_sim/optimizers/cell_filters.py
Vectorized compute_cell_forces:
expm_frechet(A_batch, E_batch)is now called once with alln_systems * 9matrices batched together3. tests/test_math.py
Refactored tests:
TestExpmFrechet:test_expm_frechet,test_small_norm_expm_frechet,test_fuzzTestExpmFrechetTorch:test_expm_frechet,test_fuzzAll updated to use 3x3 matrices and simplified by removing
methodparameter testing. Fuzz tests streamlined with fewer iterations.Removed tests:
test_problematic_matrix,test_medium_matrix(both numpy and torch versions)TestExpmFrechetTorchGradclassTests for comparing computation methods and large matrix performance no longer apply to the 3x3-specialized implementation.
Added tests:
TestExpmFrechet.test_large_norm_matrices- Tests scaling behavior for larger norm matricesTestLogM33.test_batched_positive_definite- Tests batched matrix logarithm with round-trip verificationTestFrechetCellFilterIntegration- Integration tests for the cell filter pipelinetest_wrap_positions_*- Tests for the newwrap_positionspropertyResults
The figure below shows the speedup achieved for 10-step atomic relaxation. The test is performed for a 8-atom cubic supercell of MgO using the

mace-mpamodel. Prior results are shown in blue, while new results are shown in red. The speedup is calculated asspeedup (%) = (baseline_time / current_time − 1) × 100. We observe a speedup up to 564% for large batches.