From dd1252ee0f6e9ff4469e86ac7641a4b688bc10f7 Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Sat, 17 Jan 2026 18:35:36 -0800
Subject: [PATCH 1/5] Triple backward written.

---
 .../openequivariance/jax/TensorProduct.py     | 81 +++++++++++++++----
 1 file changed, 66 insertions(+), 15 deletions(-)

diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
index 452e7bb7..9db9e46d 100644
--- a/openequivariance/openequivariance/jax/TensorProduct.py
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -1,4 +1,5 @@
 import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from openequivariance.jax import extlib
@@ -7,7 +8,6 @@ from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax
 
-
 @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
 def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
@@ -16,10 +16,18 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward_call(X, Y, W, **attrs)
 
 
-def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
+def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
 
 
+def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ):
+    X, Y, W = inputs
+    return backward(X, Y, W, dZ, irrep_dtype, attrs)
+
+
+forward.defvjp(forward_fwd, forward_bwd)
+
+
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5))
 def backward(X, Y, W, dZ, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
@@ -30,33 +38,76 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
         ),
     )
-
     return backward_call(X, Y, W, dZ, **attrs)
 
 
-def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
+def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs):
     return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
 
 
-def double_backward(irrep_dtype, attrs, inputs, derivatives):
+def backward_bwd(irrep_dtype, attrs, inputs, derivs):
+    X, Y, W, dZ = inputs
+    ddX, ddY, ddW = derivs
+    return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+
+
+backward.defvjp(backward_fwd, backward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8))
+def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
     double_backward_call = jax.ffi.ffi_call(
         "tp_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
-    return double_backward_call(*inputs, *derivatives, **attrs)
+    return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs)
+
+
+def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
+    out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, ddX, ddY, ddW)
+
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs):
+    X, Y, W, dZ, ddX, ddY, ddW = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs)
+
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs)
+
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs)
+
+
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs)
 
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
 
-def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
-    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW
 
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW
 
-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+
+double_backward.defvjp(double_backward_fwd, triple_backward)
 
 
 class TensorProduct(LoopUnrollTP):
@@ -156,4 +207,4 @@ def double_backward_cpu(
             out_grad_jax,
         )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
 
-    return in1_grad, in2_grad, weights_grad, out_dgrad
+    return in1_grad, in2_grad, weights_grad, out_dgrad
\ No newline at end of file
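Note on PATCH 1/5: the structure above is the standard JAX recipe for higher-order custom derivatives. `backward` is the VJP rule of `forward` and is itself a `custom_vjp` function; `double_backward` is the VJP rule of `backward`; and `triple_backward` closes the chain as the VJP rule of `double_backward`. A minimal sketch of the same chaining pattern on a toy scalar function (the `x ** 3` stand-in and all names below are illustrative, not from the patch):

    import jax

    @jax.custom_vjp
    def f(x):
        return x ** 3

    def f_fwd(x):
        return f(x), (x,)

    def f_bwd(res, dz):
        (x,) = res
        return (f_grad(x, dz),)  # delegate to another custom_vjp function

    @jax.custom_vjp
    def f_grad(x, dz):           # VJP of f: 3x^2 * dz
        return 3.0 * x**2 * dz

    def f_grad_fwd(x, dz):
        return f_grad(x, dz), (x, dz)

    def f_grad_bwd(res, ddx):    # VJP of f_grad w.r.t. (x, dz)
        x, dz = res
        return (6.0 * x * dz * ddx, 3.0 * x**2 * ddx)

    f.defvjp(f_fwd, f_bwd)
    f_grad.defvjp(f_grad_fwd, f_grad_bwd)

    print(jax.grad(jax.grad(jax.grad(f)))(2.0))  # 6.0, the third derivative of x**3

Because `f_grad_bwd` is built from plain differentiable primitives, the third `jax.grad` needs no further custom rule — the same reason `triple_backward` above can be an ordinary function composed of `double_backward` and `backward` calls.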
From 4aabbc5bc5696603412a73cfb67485ef468ea4ba Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Sat, 17 Jan 2026 19:07:13 -0800
Subject: [PATCH 2/5] Triple backward seems to work.

---
 .../openequivariance/jax/TensorProductConv.py | 135 ++++++++++++++----
 1 file changed, 111 insertions(+), 24 deletions(-)

diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 3aaee28a..6d678f55 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -1,3 +1,5 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from typing import Optional
@@ -8,13 +10,12 @@ from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax
 
-import jax
-import jax.numpy as jnp
-
 from openequivariance.benchmark.logging_utils import getLogger
 
 logger = getLogger()
 
+def zeros_like(x):
+    return jnp.zeros_like(x)
 
 @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
@@ -24,12 +25,25 @@ def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, at
     return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)
 
 
-def forward_with_inputs(
+def forward_fwd(
     X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
 ):
-    return forward(
+    out = forward(
         X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
-    ), (X, Y, W, rows, cols, sender_perm, workspace)
+    )
+    return out, (X, Y, W)
+
+
+def forward_bwd(
+    rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+):
+    X, Y, W = inputs
+    return backward(
+        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+    )
+
+
+forward.defvjp(forward_fwd, forward_bwd)
 
 
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
@@ -45,39 +59,69 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
     return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
 
 
-def backward_with_inputs(
+def backward_fwd(
     X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
-    return backward(
+    out = backward(
         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-    ), (X, Y, W, dZ)  # rows, cols, sender_perm, workspace)
+    )
+    return out, (X, Y, W, dZ)
 
 
-def double_backward(
+def backward_bwd(
     rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
+):
+    X, Y, W, dZ = inputs
+    ddX, ddY, ddW = derivatives
+    return double_backward(
+        X,
+        Y,
+        W,
+        dZ,
+        ddX,
+        ddY,
+        ddW,
+        rows,
+        cols,
+        workspace,
+        sender_perm,
+        irrep_dtype,
+        attrs,
+    )
+
+
+backward.defvjp(backward_fwd, backward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12))
+def double_backward(
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
     double_backward_call = jax.ffi.ffi_call(
         "conv_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
     return double_backward_call(
-        *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
+        X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs
    )
 
 
-def backward_autograd(
-    rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+def double_backward_fwd(
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
-    return backward(
-        inputs[0],
-        inputs[1],
-        inputs[2],
+    out = double_backward(
+        X,
+        Y,
+        W,
         dZ,
+        ddX,
+        ddY,
+        ddW,
         rows,
         cols,
         workspace,
@@ -85,10 +129,53 @@ def backward_autograd(
         irrep_dtype,
         attrs,
     )
+    return out, (X, Y, W, dZ, ddX, ddY, ddW)
+
+
+def triple_backward(
+    rows,
+    cols,
+    workspace,
+    sender_perm,
+    irrep_dtype,
+    attrs,
+    residuals,
+    tangent_outputs,
+):
+    X, Y, W, dZ, ddX, ddY, ddW = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+
+    common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs)
+
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args)
+
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args)
 
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args)
 
-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args)
+
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
+
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW
+
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW
+
+
+double_backward.defvjp(double_backward_fwd, triple_backward)
 
 
 class TensorProductConv(LoopUnrollConv):
@@ -295,4 +382,4 @@ def double_backward_cpu(
         np.asarray(in2_grad),
         np.asarray(weights_grad),
         np.asarray(out_dgrad),
-    )
+    )
\ No newline at end of file
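Why PATCH 2/5 needs exactly seven calls: every output of `double_backward` is multilinear in (X, Y, W, dZ, ddX, ddY, ddW), so its VJP splits term by term. Ops 1-4 reuse `double_backward` itself with selected tangent slots zeroed, ops 5-7 reuse `backward` against `t_ddZ`, and the per-input sums collect each term exactly once. A hedged numerical check of the whole chain — it assumes `tp_conv`, `X`, `Y`, `W`, and `edge_index` are set up as in the README tutorial, and uses `modes=["rev"]` because only VJP rules are defined:

    from jax.test_util import check_grads

    def loss(X, Y, W):
        Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
        return (Z**2).sum()

    # order=3 recurses through backward, double_backward, and
    # triple_backward in turn, comparing against numerical differences.
    check_grads(loss, (X, Y, W), order=3, modes=["rev"])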
From 861b63633867d19e639787a40f202c240f44affd Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Sat, 17 Jan 2026 19:22:04 -0800
Subject: [PATCH 3/5] Fixed things up.

---
 .../openequivariance/jax/TensorProduct.py | 82 +++++++++++++++----
 1 file changed, 68 insertions(+), 14 deletions(-)

diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
index 452e7bb7..607d9146 100644
--- a/openequivariance/openequivariance/jax/TensorProduct.py
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -1,4 +1,5 @@
 import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from openequivariance.jax import extlib
@@ -16,10 +17,18 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward_call(X, Y, W, **attrs)
 
 
-def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
+def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
 
 
+def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ):
+    X, Y, W = inputs
+    return backward(X, Y, W, dZ, irrep_dtype, attrs)
+
+
+forward.defvjp(forward_fwd, forward_bwd)
+
+
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5))
 def backward(X, Y, W, dZ, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
@@ -30,33 +39,78 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
         ),
     )
-
     return backward_call(X, Y, W, dZ, **attrs)
 
 
-def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
+def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs):
     return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
 
 
-def double_backward(irrep_dtype, attrs, inputs, derivatives):
+def backward_bwd(irrep_dtype, attrs, inputs, derivs):
+    X, Y, W, dZ = inputs
+    ddX, ddY, ddW = derivs
+    return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+
+
+backward.defvjp(backward_fwd, backward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8))
+def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
     double_backward_call = jax.ffi.ffi_call(
         "tp_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
-    return double_backward_call(*inputs, *derivatives, **attrs)
+    return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs)
+
+
+def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
+    out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, ddX, ddY, ddW)
+
+
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+
+def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs):
+    X, Y, W, dZ, ddX, ddY, ddW = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs)
+
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs)
+
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs)
+
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs)
+
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
 
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW
 
-def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
-    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW
 
 
-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+
+double_backward.defvjp(double_backward_fwd, triple_backward)
 
 
 class TensorProduct(LoopUnrollTP):
@@ -156,4 +210,4 @@ def double_backward_cpu(
             out_grad_jax,
         )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
 
-    return in1_grad, in2_grad, weights_grad, out_dgrad
+    return in1_grad, in2_grad, weights_grad, out_dgrad
\ No newline at end of file

From 9bf5a245abb924423220fa4eec62a7d3f4d42274 Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Sat, 17 Jan 2026 20:34:22 -0800
Subject: [PATCH 4/5] Fixed issues.

---
 README.md                                     |  5 ++
 .../openequivariance/jax/TensorProductConv.py | 82 +++++--------------
 tests/example_test.py                         |  4 +
 3 files changed, 29 insertions(+), 62 deletions(-)

diff --git a/README.md b/README.md
index c68e4fa9..4f228074 100644
--- a/README.md
+++ b/README.md
@@ -183,6 +183,11 @@ Z = tp_conv.forward(
     X, Y, W, edge_index[0], edge_index[1]
 )
 print(jax.numpy.linalg.norm(Z))
+
+# Test JAX JIT
+func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
+jitted = jax.jit(func)
+print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
 ```
 
 ## Citation and Acknowledgements
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 6d678f55..59cb2397 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -17,7 +17,7 @@ def zeros_like(x):
     return jnp.zeros_like(x)
 
-@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
         "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
@@ -31,22 +31,23 @@ def forward_fwd(
     out = forward(
         X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
     )
-    return out, (X, Y, W)
+    return out, (X, Y, W, rows, cols)
 
 
-def forward_bwd(
-    rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
-):
-    X, Y, W = inputs
-    return backward(
+def forward_bwd(
+    workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ
+):
+    X, Y, W, rows, cols = res
+    dX, dY, dW = backward(
         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
     )
+    return dX, dY, dW, None, None
 
 
 forward.defvjp(forward_fwd, forward_bwd)
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): backward_call = jax.ffi.ffi_call( "conv_backward", @@ -65,35 +66,26 @@ def backward_fwd( out = backward( X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs ) - return out, (X, Y, W, dZ) + return out, (X, Y, W, dZ, rows, cols) def backward_bwd( - rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives + workspace, sender_perm, irrep_dtype, attrs, res, derivatives ): - X, Y, W, dZ = inputs + X, Y, W, dZ, rows, cols = res ddX, ddY, ddW = derivatives - return double_backward( - X, - Y, - W, - dZ, - ddX, - ddY, - ddW, - rows, - cols, - workspace, - sender_perm, - irrep_dtype, - attrs, + + gX, gY, gW, gdZ = double_backward( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs ) + return gX, gY, gW, gdZ, None, None + backward.defvjp(backward_fwd, backward_bwd) -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12)) +@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) def double_backward( X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs ): @@ -129,12 +121,10 @@ def double_backward_fwd( irrep_dtype, attrs, ) - return out, (X, Y, W, dZ, ddX, ddY, ddW) + return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) def triple_backward( - rows, - cols, workspace, sender_perm, irrep_dtype, @@ -142,7 +132,7 @@ def triple_backward( residuals, tangent_outputs, ): - X, Y, W, dZ, ddX, ddY, ddW = residuals + X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals t_dX, t_dY, t_dW, t_ddZ = tangent_outputs common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs) @@ -172,25 +162,13 @@ def triple_backward( grad_ddY = g1_ddY + g4_ddY + g6_ddY grad_ddW = g2_ddW + g7_ddW - return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW + return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None double_backward.defvjp(double_backward_fwd, triple_backward) class TensorProductConv(LoopUnrollConv): - r""" - Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one - key difference: integer arrays passed to this function must have dtype - ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). - - :param problem: Specification of the tensor product. - :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic - fixup-based algorithm. `Default`: ``False``. - :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, - the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. - """ - def __init__( self, config: TPProblem, deterministic: bool = False, kahan: bool = False ): @@ -199,7 +177,7 @@ def __init__( config, dp, extlib.postprocess_kernel, - idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version + idx_dtype=np.int32, torch_op=False, deterministic=deterministic, kahan=kahan, @@ -232,26 +210,6 @@ def forward( cols: jax.numpy.ndarray, sender_perm: Optional[jax.numpy.ndarray] = None, ) -> jax.numpy.ndarray: - r""" - Computes the fused CG tensor product + convolution. - - :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. - :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. 
-    :param W: Tensor of datatype ``problem.weight_dtype`` and shape
-
-        * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
-        * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
-
-    :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
-        datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
-    :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
-        datatype ``np.int32``.
-    :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
-        permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
-        Must be provided when ``deterministic=True``.
-
-    :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
-    """
         if not self.deterministic:
             sender_perm = self.dummy_transpose_perm
         else:
diff --git a/tests/example_test.py b/tests/example_test.py
index ae19f77e..993f5dfe 100644
--- a/tests/example_test.py
+++ b/tests/example_test.py
@@ -161,3 +161,7 @@ def test_tutorial_jax(with_jax):
     tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
     Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
     print(jax.numpy.linalg.norm(Z))
+
+    func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
+    jitted = jax.jit(func)
+    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
\ No newline at end of file
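What PATCH 4/5 changes: `rows` and `cols` leave `nondiff_argnums`, because `jax.custom_vjp` treats nondiff arguments as static values that must be hashable, so traced index arrays cannot go there — which is what blocked `jax.jit`. The fix is the usual one for integer operands: pass them as regular arguments, stash them in the residuals, and return `None` cotangents for them from each bwd rule. A minimal sketch of that pattern on a toy gather (names are illustrative; the `None` cotangent for the integer operand follows the same convention the patch uses):

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.custom_vjp, nondiff_argnums=(2,))  # only the float stays static
    def gather_scale(x, idx, scale):
        return scale * x[idx]

    def gather_scale_fwd(x, idx, scale):
        return gather_scale(x, idx, scale), (x, idx)

    def gather_scale_bwd(scale, res, dz):
        x, idx = res
        # Scatter-add the cotangent back into x; the integer index array
        # receives None, mirroring forward_bwd/backward_bwd above.
        return jnp.zeros_like(x).at[idx].add(scale * dz), None

    gather_scale.defvjp(gather_scale_fwd, gather_scale_bwd)

    x = jnp.arange(5.0)
    idx = jnp.array([0, 2, 2])
    g = jax.jit(jax.grad(lambda x: gather_scale(x, idx, 2.0).sum()))
    print(g(x))  # [2., 0., 4., 0., 0.]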
From 0f45a246904f55b61f1dba553e1f4700d18a3f97 Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Sat, 17 Jan 2026 20:40:14 -0800
Subject: [PATCH 5/5] Precommit.

---
 README.md                                     |  3 +-
 .../openequivariance/jax/TensorProduct.py     |  3 +-
 .../openequivariance/jax/TensorProductConv.py | 34 +++++++++++--------
 tests/example_test.py                         |  5 ++-
 4 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 4f228074..3efe59a7 100644
--- a/README.md
+++ b/README.md
@@ -185,8 +185,7 @@ Z = tp_conv.forward(
 print(jax.numpy.linalg.norm(Z))
 
 # Test JAX JIT
-func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
-jitted = jax.jit(func)
+jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
 print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
 ```
 
diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
index f9c0c01e..05e4b097 100644
--- a/openequivariance/openequivariance/jax/TensorProduct.py
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -8,6 +8,7 @@ from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax
 
+
 @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
 def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
@@ -209,4 +210,4 @@ def double_backward_cpu(
             out_grad_jax,
         )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
 
-    return in1_grad, in2_grad, weights_grad, out_dgrad
\ No newline at end of file
+    return in1_grad, in2_grad, weights_grad, out_dgrad
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 59cb2397..7439cd4e 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -14,9 +14,11 @@
 logger = getLogger()
 
+
 def zeros_like(x):
     return jnp.zeros_like(x)
 
+
 @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
@@ -34,9 +36,7 @@ def forward_fwd(
     return out, (X, Y, W, rows, cols)
 
 
-def forward_bwd(
-    workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ
-):
+def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
     X, Y, W, rows, cols = res
     dX, dY, dW = backward(
         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
@@ -60,23 +60,29 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
     return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
 
 
-def backward_fwd(
-    X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-):
-    out = backward(
-        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-    )
+def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
     return out, (X, Y, W, dZ, rows, cols)
 
 
-def backward_bwd(
-    workspace, sender_perm, irrep_dtype, attrs, res, derivatives
-):
+def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
     X, Y, W, dZ, rows, cols = res
     ddX, ddY, ddW = derivatives
 
     gX, gY, gW, gdZ = double_backward(
-        X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+        X,
+        Y,
+        W,
+        dZ,
+        ddX,
+        ddY,
+        ddW,
+        rows,
+        cols,
+        workspace,
+        sender_perm,
+        irrep_dtype,
+        attrs,
     )
     return gX, gY, gW, gdZ, None, None
 
@@ -340,4 +346,4 @@ def double_backward_cpu(
         np.asarray(in2_grad),
         np.asarray(weights_grad),
         np.asarray(out_dgrad),
-    )
\ No newline at end of file
+    )
diff --git a/tests/example_test.py b/tests/example_test.py
index 993f5dfe..e8d23cb7 100644
--- a/tests/example_test.py
+++ b/tests/example_test.py
@@ -162,6 +162,5 @@ def test_tutorial_jax(with_jax):
     Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
     print(jax.numpy.linalg.norm(Z))
 
-    func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
-    jitted = jax.jit(func)
-    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
\ No newline at end of file
+    jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
+    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
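PATCH 5/5 is formatting-only, so the end state of the series supports `jax.jit` composed with nested reverse-mode AD. A hedged closing sketch under the same assumptions as the README tutorial (`tp_conv`, `X`, `Y`, `W`, and `edge_index` defined as there):

    import jax

    @jax.jit
    def energy(X, Y, W):
        Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
        return (Z**2).sum()

    force = jax.grad(energy, argnums=0)              # 1st order: backward
    force_sq = lambda X: (force(X, Y, W) ** 2).sum()
    g2 = jax.grad(force_sq)                          # 2nd order: double_backward
    g3 = jax.grad(lambda X: (g2(X) ** 2).sum())      # 3rd order: triple_backward
    print(g3(X).shape)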