diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py index b35a5cc32..d3ed0a052 100644 --- a/src/pyrecest/_backend/pytorch/linalg.py +++ b/src/pyrecest/_backend/pytorch/linalg.py @@ -174,11 +174,13 @@ def solve_sylvester(a, b, q): a = a.to(dtype=common_dtype) b = b.to(dtype=common_dtype) q = q.to(dtype=common_dtype) - if ( - a.shape == b.shape + is_real_shared_symmetric_factor = ( + not is_complex(a) + and a.shape == b.shape and _torch.all(a == b) and _torch.all(_torch.abs(a - a.transpose(-2, -1)) < 1e-6) - ): + ) + if is_real_shared_symmetric_factor: eigvals, eigvecs = eigh(a) if _torch.all(eigvals >= 1e-6): tilde_q = eigvecs.transpose(-2, -1) @ q @ eigvecs diff --git a/tests/test_pytorch_backend.py b/tests/test_pytorch_backend.py index d1882c131..2579ea6e6 100644 --- a/tests/test_pytorch_backend.py +++ b/tests/test_pytorch_backend.py @@ -264,6 +264,21 @@ def test_solve_sylvester_promotes_mixed_dtypes(self): self.assertEqual(result.dtype, pytorch_backend.float64) self.assertTrue(pytorch_backend.allclose(result, expected)) + def test_solve_sylvester_complex_symmetric_uses_general_solver(self): + a = pytorch_backend.array( + [[2.0 + 0.0j, 0.0 + 1.0j], [0.0 + 1.0j, 3.0 + 0.0j]], + dtype=pytorch_backend.complex128, + ) + q = pytorch_backend.array( + [[1.0 + 2.0j, 0.5 - 0.25j], [-1.0 + 0.75j, 2.0 - 1.0j]], + dtype=pytorch_backend.complex128, + ) + + result = pytorch_backend.linalg.solve_sylvester(a, a, q) + + residual = a @ result + result @ a + self.assertTrue(pytorch_backend.allclose(residual, q, atol=1e-10, rtol=1e-10)) + def test_sqrtm_complex_result_uses_matching_complex_precision(self): dtype_pairs = ( (pytorch_backend.float32, pytorch_backend.complex64),