Commit f30a835

perf(aggregation): Prevent cuda sync in normalize (#557)
1 parent 0a1ecfd commit f30a835

1 file changed, +5 -4 lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 5 additions & 4 deletions
@@ -51,11 +51,12 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
     sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
     therefore `G` divided by the sum of its diagonal elements.
     """
+
     squared_frobenius_norm = gramian.diagonal().sum()
-    if squared_frobenius_norm < eps:
-        output = torch.zeros_like(gramian)
-    else:
-        output = gramian / squared_frobenius_norm
+    condition = squared_frobenius_norm < eps
+
+    # Use torch.where rather than a if-else to avoid cuda synchronization.
+    output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)
     return cast(PSDMatrix, output)
 
 
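Note on the change: a Python `if` on a 0-dim tensor calls `bool()` on it, which for a CUDA tensor requires copying the value to the host and so forces a synchronization. `torch.where` keeps the selection on the device. The following is a minimal standalone sketch of the two variants, not the library's code: the function names `normalize_branching` and `normalize_where` are hypothetical, plain `torch.Tensor` stands in for `PSDMatrix`, and the check runs on CPU, where it can only show that both variants return the same values.

```python
import torch


def normalize_branching(gramian: torch.Tensor, eps: float) -> torch.Tensor:
    # Variant before the commit: the `if` calls bool() on a 0-dim tensor.
    # On a CUDA tensor this needs a device-to-host copy, hence a sync.
    squared_frobenius_norm = gramian.diagonal().sum()
    if squared_frobenius_norm < eps:
        return torch.zeros_like(gramian)
    return gramian / squared_frobenius_norm


def normalize_where(gramian: torch.Tensor, eps: float) -> torch.Tensor:
    # Variant after the commit: the condition stays a 0-dim bool tensor
    # and torch.where broadcasts it over both candidate outputs.
    squared_frobenius_norm = gramian.diagonal().sum()
    condition = squared_frobenius_norm < eps
    return torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)


matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
gramian = matrix @ matrix.T   # a small PSD matrix with trace 30
tiny = torch.zeros(2, 2)      # trace below eps, so both variants return zeros

same_regular = torch.allclose(normalize_branching(gramian, 1e-5), normalize_where(gramian, 1e-5))
same_tiny = torch.equal(normalize_branching(tiny, 1e-5), normalize_where(tiny, 1e-5))
print(same_regular, same_tiny)
```

The `torch.where` variant always evaluates both candidate outputs, including the division, and selects between them afterwards. That extra arithmetic is the price of not blocking the host on the device.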
