From 6eb39ec6876b4c430f01e6218aa84884f257ff50 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Wed, 20 May 2026 21:32:50 -0700 Subject: [PATCH 1/3] Add PyTorch triple backward support --- .../_torch/NPDoubleBackwardMixin.py | 185 ++++++++++ .../openequivariance/_torch/TensorProduct.py | 123 ++++++- .../_torch/TensorProductConv.py | 142 +++++++ .../openequivariance/benchmark/correctness.py | 345 ++++++++++++++++++ .../benchmark/test_buffers.py | 110 ++++++ tests/batch_test.py | 207 ++++++++++- tests/conv_test.py | 186 ++++++++++ 7 files changed, 1296 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index caf94268..702295d5 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -1,6 +1,13 @@ import torch +def _none_to_zeros(values, refs): + return tuple( + ref * 0 if value is None else value + for value, ref in zip(values, refs) + ) + + class NumpyDoubleBackwardMixin: """ Adds a Numpy double backward method to any TensorProduct @@ -43,6 +50,89 @@ def double_backward_cpu( d.detach().cpu().numpy(), ) + def triple_backward_cpu( + self, + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ): + assert self.torch_op + + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda").requires_grad_(True) + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda").requires_grad_(True) + weights_dgrad_torch = ( + torch.tensor(weights_dgrad).to("cuda").requires_grad_(True) + ) + out_tgrad_torch = torch.tensor(out_tgrad).to("cuda") + in1_tgrad_torch = torch.tensor(in1_tgrad).to("cuda") + in2_tgrad_torch = torch.tensor(in2_tgrad).to("cuda") + weights_tgrad_torch = torch.tensor(weights_tgrad).to("cuda") + + out_torch = self.forward(in1_torch, in2_torch, weights_torch) + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + double_grads = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + create_graph=True, + retain_graph=True, + allow_unused=True, + ) + double_grads = _none_to_zeros( + double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch) + ) + triple_grads = torch.autograd.grad( + outputs=double_grads, + inputs=[ + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ], + grad_outputs=[ + in1_tgrad_torch, + in2_tgrad_torch, + weights_tgrad_torch, + out_tgrad_torch, + ], + allow_unused=True, + ) + triple_grads = _none_to_zeros( + triple_grads, + ( + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ), + ) + + return tuple(grad.detach().cpu().numpy() for grad in triple_grads) + class NumpyDoubleBackwardMixinConv: """ @@ -95,3 +185,98 @@ def double_backward_cpu( c.detach().cpu().numpy(), d.detach().cpu().numpy(), ) + + def triple_backward_cpu( + self, + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + graph, + ): + assert self.torch_op + + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda").requires_grad_(True) + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda").requires_grad_(True) + weights_dgrad_torch = ( + torch.tensor(weights_dgrad).to("cuda").requires_grad_(True) + ) + out_tgrad_torch = torch.tensor(out_tgrad).to("cuda") + in1_tgrad_torch = torch.tensor(in1_tgrad).to("cuda") + in2_tgrad_torch = torch.tensor(in2_tgrad).to("cuda") + weights_tgrad_torch = torch.tensor(weights_tgrad).to("cuda") + + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") + + out_torch = self.forward( + in1_torch, + in2_torch, + weights_torch, + torch_rows, + torch_cols, + torch_transpose_perm, + ) + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + double_grads = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + create_graph=True, + retain_graph=True, + allow_unused=True, + ) + double_grads = _none_to_zeros( + double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch) + ) + triple_grads = torch.autograd.grad( + outputs=double_grads, + inputs=[ + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ], + grad_outputs=[ + in1_tgrad_torch, + in2_tgrad_torch, + weights_tgrad_torch, + out_tgrad_torch, + ], + allow_unused=True, + ) + triple_grads = _none_to_zeros( + triple_grads, + ( + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ), + ) + + return tuple(grad.detach().cpu().numpy() for grad in triple_grads) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index c18b1231..de587b7b 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -200,6 +200,12 @@ def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_tp_double_backward + + def zero_if_none(grad_output, like): + if grad_output is None: + return torch.zeros_like(like) + return grad_output def setup_context(ctx, inputs, output): ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_dim = inputs @@ -218,7 +224,7 @@ def setup_context_double_backward(ctx, inputs, output): ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs def double_backward(ctx, E, F, G): - result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + result = double_backward_op( ctx.kernel, ctx.hash, ctx.L1_in, @@ -237,6 +243,121 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def setup_context_triple_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W_dgrad, + ) = inputs + + def triple_backward(ctx, t_L1_grad, t_L2_grad, t_W_grad, t_L3_dgrad): + t_L1_grad = zero_if_none(t_L1_grad, ctx.L1_in) + t_L2_grad = zero_if_none(t_L2_grad, ctx.L2_in) + t_W_grad = zero_if_none(t_W_grad, ctx.weights) + t_L3_dgrad = zero_if_none(t_L3_dgrad, ctx.L3_grad) + + g1_L1_dgrad, g1_L2_dgrad, g1_W, g1_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.weights, + ctx.L3_grad, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.weights), + ) + g2_L1, g2_L2, g2_W_dgrad, g2_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + ctx.L3_grad, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W_dgrad), + ) + g3_L1_dgrad, g3_L2, g3_W, g3_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + torch.zeros_like(ctx.L1_dgrad), + torch.zeros_like(ctx.L2_in), + t_W_grad, + ) + g4_L1, g4_L2_dgrad, g4_W, g4_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.weights, + ctx.L3_grad, + torch.zeros_like(ctx.L1_in), + torch.zeros_like(ctx.L2_dgrad), + t_W_grad, + ) + + g5_L1_dgrad, g5_L2, g5_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.weights, + t_L3_dgrad, + ) + g6_L1, g6_L2_dgrad, g6_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.weights, + t_L3_dgrad, + ) + g7_L1, g7_L2, g7_W_dgrad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + t_L3_dgrad, + ) + + grad_L1 = g2_L1 + g4_L1 + g6_L1 + g7_L1 + grad_L2 = g2_L2 + g3_L2 + g5_L2 + g7_L2 + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_L3_grad = g1_L3_grad + g2_L3_grad + g3_L3_grad + g4_L3_grad + grad_L1_dgrad = g1_L1_dgrad + g3_L1_dgrad + g5_L1_dgrad + grad_L2_dgrad = g1_L2_dgrad + g4_L2_dgrad + g6_L2_dgrad + grad_W_dgrad = g2_W_dgrad + g7_W_dgrad + + return ( + None, + None, + grad_L1, + grad_L2, + grad_W, + grad_L3_grad, + grad_L1_dgrad, + grad_L2_dgrad, + grad_W_dgrad, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_double_backward", + triple_backward, + setup_context=setup_context_triple_backward, + ) + def register_autocast(): global torch diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index c1087a63..c75b7397 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -330,6 +330,11 @@ def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + def zero_if_none(grad_output, like): + if grad_output is None: + return torch.zeros_like(like) + return grad_output + def setup_context(ctx, inputs, output): ( ctx.kernel, @@ -413,6 +418,143 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def setup_context_triple_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W_dgrad, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def triple_backward(ctx, t_L1_grad, t_L2_grad, t_W_grad, t_L3_dgrad): + t_L1_grad = zero_if_none(t_L1_grad, ctx.L1_in) + t_L2_grad = zero_if_none(t_L2_grad, ctx.L2_in) + t_W_grad = zero_if_none(t_W_grad, ctx.W) + t_L3_dgrad = zero_if_none(t_L3_dgrad, ctx.grad_output) + + common_args = ( + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + + g1_L1_dgrad, g1_L2_dgrad, g1_W, g1_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W, + ctx.grad_output, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W), + *common_args, + ) + g2_L1, g2_L2, g2_W_dgrad, g2_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + ctx.grad_output, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W_dgrad), + *common_args, + ) + g3_L1_dgrad, g3_L2, g3_W, g3_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.W, + ctx.grad_output, + torch.zeros_like(ctx.L1_dgrad), + torch.zeros_like(ctx.L2_in), + t_W_grad, + *common_args, + ) + g4_L1, g4_L2_dgrad, g4_W, g4_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.W, + ctx.grad_output, + torch.zeros_like(ctx.L1_in), + torch.zeros_like(ctx.L2_dgrad), + t_W_grad, + *common_args, + ) + + g5_L1_dgrad, g5_L2, g5_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.W, + t_L3_dgrad, + *common_args, + ) + g6_L1, g6_L2_dgrad, g6_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.W, + t_L3_dgrad, + *common_args, + ) + g7_L1, g7_L2, g7_W_dgrad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + t_L3_dgrad, + *common_args, + ) + + grad_L1 = g2_L1 + g4_L1 + g6_L1 + g7_L1 + grad_L2 = g2_L2 + g3_L2 + g5_L2 + g7_L2 + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_L3_grad = g1_L3_grad + g2_L3_grad + g3_L3_grad + g4_L3_grad + grad_L1_dgrad = g1_L1_dgrad + g3_L1_dgrad + g5_L1_dgrad + grad_L2_dgrad = g1_L2_dgrad + g4_L2_dgrad + g6_L2_dgrad + grad_W_dgrad = g2_W_dgrad + g7_W_dgrad + + return ( + None, + None, + grad_L1, + grad_L2, + grad_W, + grad_L3_grad, + grad_L1_dgrad, + grad_L2_dgrad, + grad_W_dgrad, + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_double_backward", + triple_backward, + setup_context=setup_context_triple_backward, + ) + def register_autocast(): torch.library.register_autocast( diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness.py index 45c45c4a..fa95f826 100644 --- a/openequivariance/openequivariance/benchmark/correctness.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -13,6 +13,8 @@ get_random_buffers_double_backward, get_random_buffers_forward_conv, get_random_buffers_forward, + get_random_buffers_triple_backward, + get_random_buffers_triple_backward_conv, ) from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.TensorProductBase import TensorProductBase @@ -316,6 +318,175 @@ def correctness_double_backward( return result +def correctness_triple_backward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +): + buffers = get_random_buffers_triple_backward( + problem, batch_size=batch_size, prng_seed=prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + + reference_implementation = E3NNTensorProduct + + result = {"thresh": correctness_threshold, "batch_size": batch_size} + tensors = [] + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) + buffers_copy = [buf.copy() for buf in buffers] + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers_copy + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) + weights_tgrad_reordered = tp.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not problem.shared_weights + ) + + if impl == CUETensorProduct and problem.shared_weights: + weights_reordered = weights_reordered[np.newaxis, :] + weights_dgrad_reordered = weights_dgrad_reordered[np.newaxis, :] + weights_tgrad_reordered = weights_tgrad_reordered[np.newaxis, :] + + if is_test_impl: + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ) = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] + + ( + in1_grad, + in2_grad, + weights_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + weights_dgrad_grad, + ) = tp.triple_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad_reordered, + in1_tgrad, + in2_tgrad, + ) + + if is_test_impl: + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ) = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] + + tensors.append( + ( + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not problem.shared_weights + ), + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + tp.reorder_weights_to_e3nn( + weights_dgrad_grad, has_batch_dim=not problem.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("in1_grad", tensors[0][0], tensors[1][0]), + ("in2_grad", tensors[0][1], tensors[1][1]), + ("weights_grad", tensors[0][2], tensors[1][2]), + ("output_grad", tensors[0][3], tensors[1][3]), + ("in1_double_grad", tensors[0][4], tensors[1][4]), + ("in2_double_grad", tensors[0][5], tensors[1][5]), + ("weights_double_grad", tensors[0][6], tensors[1][6]), + ]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) + + return result + + def correctness_forward_conv( conv, graph, @@ -636,3 +807,177 @@ def correctness_double_backward_conv( result[name] = check_similiarity(name, to_check, ground_truth, thresh) return result + + +def correctness_triple_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + buffers = get_random_buffers_triple_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + reference_problem = conv.config + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + + reference_tp = reference_implementation(reference_problem, torch_op=True) + + result = {"thresh": thresh} + tensors = [] + for i, tp in enumerate([conv, reference_tp]): + is_test_impl = i == 0 + buffers_copy = [buf.copy() for buf in buffers] + + if i == 1 and high_precision_ref: + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] + + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers_copy + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not conv.config.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not conv.config.shared_weights + ) + weights_tgrad_reordered = tp.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not conv.config.shared_weights + ) + + if is_test_impl: + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ) = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + ( + in1_grad, + in2_grad, + weights_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + weights_dgrad_grad, + ) = tp.triple_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad_reordered, + in1_tgrad, + in2_tgrad, + graph, + ) + + if is_test_impl: + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ) = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + tensors.append( + ( + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not conv.config.shared_weights + ), + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + tp.reorder_weights_to_e3nn( + weights_dgrad_grad, has_batch_dim=not conv.config.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("in1_grad", tensors[0][0], tensors[1][0]), + ("in2_grad", tensors[0][1], tensors[1][1]), + ("weights_grad", tensors[0][2], tensors[1][2]), + ("output_grad", tensors[0][3], tensors[1][3]), + ("in1_double_grad", tensors[0][4], tensors[1][4]), + ("in2_double_grad", tensors[0][5], tensors[1][5]), + ("weights_double_grad", tensors[0][6], tensors[1][6]), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + return result diff --git a/openequivariance/openequivariance/benchmark/test_buffers.py b/openequivariance/openequivariance/benchmark/test_buffers.py index c657d5bc..54612701 100644 --- a/openequivariance/openequivariance/benchmark/test_buffers.py +++ b/openequivariance/openequivariance/benchmark/test_buffers.py @@ -127,6 +127,61 @@ def get_random_buffers_double_backward( ) +def get_random_buffers_triple_backward( + tpp: TPProblem, batch_size: int, prng_seed: int +): + rng = np.random.default_rng(prng_seed) + + in1 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([batch_size, tpp.weight_numel]) + ) + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_dgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_tgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + + in1_dgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_dgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + in1_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + return ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) + + def get_random_buffers_forward_conv( tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int ): @@ -225,3 +280,58 @@ def get_random_buffers_double_backward_conv( in2_grad, out_double_grad, ) + + +def get_random_buffers_triple_backward_conv( + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): + rng = np.random.default_rng(prng_seed) + + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_dgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_tgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + + in1_dgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_dgrad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_tgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + in1_tgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_tgrad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + return ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) diff --git a/tests/batch_test.py b/tests/batch_test.py index ff1cd1ce..c58269b4 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -7,8 +7,12 @@ correctness_backward, correctness_double_backward, correctness_forward, + correctness_triple_backward, +) +from openequivariance.benchmark.test_buffers import ( + get_random_buffers_forward, + get_random_buffers_triple_backward, ) -from openequivariance.benchmark.test_buffers import get_random_buffers_forward from openequivariance.benchmark.problems import ( diffdock_problems, e3nn_torch_tetris_poly_problems, @@ -96,6 +100,207 @@ def test_tp_double_bwd(self, tp_and_problem): self.check_result(result, "weights_grad") +class TPTripleBackwardCorrectness: + def thresh(self, direction): + return {"triple_bwd": 5e-4}[direction] + + def check_result(self, result, fieldname): + with check: + error = result[fieldname]["diff_Linf_norm"] + thresh = result["thresh"] + assert result[fieldname]["pass"], ( + f"{fieldname} observed error={error:.5f} >= {thresh}" + ) + + @pytest.fixture(scope="class") + def extra_tp_constructor_args(self): + return {} + + @pytest.fixture(scope="class") + def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + tp = oeq.TensorProduct(problem, **extra_tp_constructor_args) + return tp, problem + + def test_tp_triple_bwd(self, tp_and_problem): + tp, problem = tp_and_problem + result = correctness_triple_backward( + problem=problem, + test_implementation=tp, + reference_implementation=None, + batch_size=4, + correctness_threshold=self.thresh("triple_bwd"), + prng_seed=12345, + ) + + for fieldname in [ + "in1_grad", + "in2_grad", + "weights_grad", + "output_grad", + "in1_double_grad", + "in2_double_grad", + "weights_double_grad", + ]: + self.check_result(result, fieldname) + + +class TestTripleBackwardUVUSingleIrrep(TPTripleBackwardCorrectness): + def id_func(m, i): + return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" + + @pytest.fixture( + params=[((2, 1, 2), (1, 1, 1))], + ids=lambda x: TestTripleBackwardUVUSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) + def problem(self, request, dtype): + m, i = request.param[0], request.param[1] + instructions = [(0, 0, 0, "uvu", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + +class TestTripleBackwardUVWSingleIrrep(TPTripleBackwardCorrectness): + def id_func(m, i): + return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" + + @pytest.fixture( + params=[((2, 2, 2), (1, 1, 1))], + ids=lambda x: TestTripleBackwardUVWSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) + def problem(self, request, dtype): + m, i = request.param[0], request.param[1] + instructions = [(0, 0, 0, "uvw", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + +class TestTripleBackwardSharedWeights(TPTripleBackwardCorrectness): + def id_func(m, i): + return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" + + @pytest.fixture( + params=[((2, 1, 2), (1, 1, 1))], + ids=lambda x: TestTripleBackwardSharedWeights.id_func(x[0], x[1]), + scope="class", + ) + def problem(self, request, dtype): + m, i = request.param[0], request.param[1] + instructions = [(0, 0, 0, "uvu", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=True, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + +class TestTripleBackwardDirectOps: + @pytest.fixture(scope="class") + def problem(self, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + return oeq.TPProblem( + "2x1e", + "1x1e", + "2x1e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + @pytest.fixture(scope="class") + def tp_and_problem(self, problem): + tp = oeq.TensorProduct(problem) + return tp, problem + + def test_direct_tp_double_backward_op_is_differentiable(self, tp_and_problem): + tp, problem = tp_and_problem + buffers = get_random_buffers_triple_backward( + problem, batch_size=3, prng_seed=12345 + ) + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers + + weights = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) + weights_tgrad = tp.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not problem.shared_weights + ) + + tensors = [ + torch.tensor(arr, device="cuda", requires_grad=True) + for arr in [ + in1, + in2, + weights, + out_grad, + in1_dgrad, + in2_dgrad, + weights_dgrad, + ] + ] + grad_outputs = [ + torch.tensor(arr, device="cuda") + for arr in [in1_tgrad, in2_tgrad, weights_tgrad, out_tgrad] + ] + + outputs = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + tp.kernel, tp.hash, *tensors + ) + grads = torch.autograd.grad( + outputs=outputs, + inputs=tensors, + grad_outputs=grad_outputs, + allow_unused=True, + ) + + assert len(grads) == 7 + assert all(grad is not None for grad in grads) + + class TestProductionModels(TPCorrectness): production_model_tpps = ( mace_problems() diff --git a/tests/conv_test.py b/tests/conv_test.py index 8471e593..19dd3789 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -10,7 +10,12 @@ correctness_backward_conv, correctness_double_backward_conv, correctness_forward_conv, + correctness_triple_backward_conv, ) +from openequivariance.benchmark.test_buffers import ( + get_random_buffers_triple_backward_conv, +) +from openequivariance.core.ConvolutionBase import CoordGraph from itertools import product import torch @@ -40,6 +45,14 @@ def graph(request): return graph +@pytest.fixture(scope="module") +def small_graph(): + coords = np.zeros((3, 3), dtype=np.float32) + rows = np.array([0, 1, 1, 2], dtype=np.int64) + cols = np.array([1, 0, 2, 1], dtype=np.int64) + return CoordGraph(coords, rows, cols, "small") + + @pytest.fixture(scope="module") def with_jax(request): return request.config.getoption("--jax") @@ -138,6 +151,179 @@ def test_tp_double_bwd(self, conv_object, graph): self.check_result(result, "weights_grad") +class ConvTripleBackwardCorrectness: + def thresh(self, direction): + return {"triple_bwd": 5e-4}[direction] + + def check_result(self, result, fieldname): + with check: + error = result[fieldname]["diff_Linf_norm"] + thresh = result["thresh"] + assert result[fieldname]["pass"], ( + f"{fieldname} observed error={error:.5f} >= {thresh}" + ) + + @pytest.fixture(scope="class") + def extra_conv_constructor_args(self): + return {} + + @pytest.fixture(params=["atomic", "deterministic"], scope="class") + def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + if request.param == "atomic": + return oeq.TensorProductConv( + problem, deterministic=False, **extra_conv_constructor_args + ) + elif request.param == "deterministic": + if not problem.shared_weights: + return oeq.TensorProductConv( + problem, deterministic=True, **extra_conv_constructor_args + ) + else: + pytest.skip("Shared weights not supported with deterministic") + + def test_tp_triple_bwd(self, conv_object, small_graph): + if conv_object is None: + pytest.skip("'conv_object' fixture returned None, skipping") + + result = correctness_triple_backward_conv( + conv_object, + small_graph, + thresh=self.thresh("triple_bwd"), + prng_seed=12345, + reference_implementation=None, + ) + + for fieldname in [ + "in1_grad", + "in2_grad", + "weights_grad", + "output_grad", + "in1_double_grad", + "in2_double_grad", + "weights_double_grad", + ]: + self.check_result(result, fieldname) + + +class TestTripleBackwardConvUVU(ConvTripleBackwardCorrectness): + def id_func(m, i): + return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" + + @pytest.fixture( + params=[((2, 1, 2), (1, 1, 1))], + ids=lambda x: TestTripleBackwardConvUVU.id_func(x[0], x[1]), + scope="class", + ) + def problem(self, request, dtype): + m, i = request.param[0], request.param[1] + instructions = [(0, 0, 0, "uvu", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + +class TestTripleBackwardConvDirectOps: + @pytest.fixture(scope="class") + def problem(self, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + return oeq.TPProblem( + "2x1e", + "1x1e", + "2x1e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + + @pytest.fixture(scope="class") + def conv_object(self, problem): + return oeq.TensorProductConv(problem, deterministic=True) + + def test_direct_conv_double_backward_op_is_differentiable( + self, conv_object, small_graph + ): + problem = conv_object.config + buffers = get_random_buffers_triple_backward_conv( + problem, small_graph.node_count, small_graph.nnz, prng_seed=12345 + ) + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers + + weights = conv_object.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad = conv_object.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) + weights_tgrad = conv_object.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not problem.shared_weights + ) + + tensors = [ + torch.tensor(arr, device="cuda", requires_grad=True) + for arr in [ + in1, + in2, + weights, + out_grad, + in1_dgrad, + in2_dgrad, + weights_dgrad, + ] + ] + grad_outputs = [ + torch.tensor(arr, device="cuda") + for arr in [in1_tgrad, in2_tgrad, weights_tgrad, out_tgrad] + ] + rows = torch.tensor(small_graph.rows, device="cuda") + cols = torch.tensor(small_graph.cols, device="cuda") + transpose_perm = torch.tensor(small_graph.transpose_perm, device="cuda") + + outputs = torch.ops.libtorch_tp_jit.jit_conv_double_backward( + conv_object.kernel, + conv_object.hash, + *tensors, + rows, + cols, + conv_object.workspace_buffer, + transpose_perm, + ) + grads = torch.autograd.grad( + outputs=outputs, + inputs=tensors, + grad_outputs=grad_outputs, + allow_unused=True, + ) + + assert len(grads) == 7 + assert all(grad is not None for grad in grads) + + class TestProductionModels(ConvCorrectness): production_model_tpps = ( mace_problems() + diffdock_problems() + [e3tools_problems()[0]] From c0d5712b45d16799e6762902ec7800032b909d28 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:03:13 -0700 Subject: [PATCH 2/3] better tests, alleged fix for bad deterministic test --- tests/batch_test.py | 104 +++------------------------- tests/conv_test.py | 67 +++--------------- tests/stream_test.py | 160 +++++++++++++++++++++++++++++++++++++++---- 3 files changed, 165 insertions(+), 166 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index c58269b4..197301f0 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -31,7 +31,12 @@ def dtype(request): class TPCorrectness: def thresh(self, direction): - return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] + return { + "fwd": 1e-5, + "bwd": 3e-4, + "double_bwd": 3e-4, + "triple_bwd": 5e-4, + }[direction] def check_result(self, result, fieldname): with check: @@ -99,32 +104,10 @@ def test_tp_double_bwd(self, tp_and_problem): self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") - -class TPTripleBackwardCorrectness: - def thresh(self, direction): - return {"triple_bwd": 5e-4}[direction] - - def check_result(self, result, fieldname): - with check: - error = result[fieldname]["diff_Linf_norm"] - thresh = result["thresh"] - assert result[fieldname]["pass"], ( - f"{fieldname} observed error={error:.5f} >= {thresh}" - ) - - @pytest.fixture(scope="class") - def extra_tp_constructor_args(self): - return {} - - @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): + def test_tp_triple_bwd(self, tp_and_problem, with_jax): if with_jax: pytest.skip("N/A for JAX") - tp = oeq.TensorProduct(problem, **extra_tp_constructor_args) - return tp, problem - - def test_tp_triple_bwd(self, tp_and_problem): tp, problem = tp_and_problem result = correctness_triple_backward( problem=problem, @@ -147,78 +130,6 @@ def test_tp_triple_bwd(self, tp_and_problem): self.check_result(result, fieldname) -class TestTripleBackwardUVUSingleIrrep(TPTripleBackwardCorrectness): - def id_func(m, i): - return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - - @pytest.fixture( - params=[((2, 1, 2), (1, 1, 1))], - ids=lambda x: TestTripleBackwardUVUSingleIrrep.id_func(x[0], x[1]), - scope="class", - ) - def problem(self, request, dtype): - m, i = request.param[0], request.param[1] - instructions = [(0, 0, 0, "uvu", True)] - return oeq.TPProblem( - f"{m[0]}x{i[0]}e", - f"{m[1]}x{i[1]}e", - f"{m[2]}x{i[2]}e", - instructions, - shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, - ) - - -class TestTripleBackwardUVWSingleIrrep(TPTripleBackwardCorrectness): - def id_func(m, i): - return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - - @pytest.fixture( - params=[((2, 2, 2), (1, 1, 1))], - ids=lambda x: TestTripleBackwardUVWSingleIrrep.id_func(x[0], x[1]), - scope="class", - ) - def problem(self, request, dtype): - m, i = request.param[0], request.param[1] - instructions = [(0, 0, 0, "uvw", True)] - return oeq.TPProblem( - f"{m[0]}x{i[0]}e", - f"{m[1]}x{i[1]}e", - f"{m[2]}x{i[2]}e", - instructions, - shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, - ) - - -class TestTripleBackwardSharedWeights(TPTripleBackwardCorrectness): - def id_func(m, i): - return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - - @pytest.fixture( - params=[((2, 1, 2), (1, 1, 1))], - ids=lambda x: TestTripleBackwardSharedWeights.id_func(x[0], x[1]), - scope="class", - ) - def problem(self, request, dtype): - m, i = request.param[0], request.param[1] - instructions = [(0, 0, 0, "uvu", True)] - return oeq.TPProblem( - f"{m[0]}x{i[0]}e", - f"{m[1]}x{i[1]}e", - f"{m[2]}x{i[2]}e", - instructions, - shared_weights=True, - internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, - ) - - class TestTripleBackwardDirectOps: @pytest.fixture(scope="class") def problem(self, dtype, with_jax): @@ -446,6 +357,7 @@ def thresh(self, direction): "fwd": 1e-5, "bwd": 5e-4, # Expect higher errors for shared weights "double_bwd": 5e-4, + "triple_bwd": 5e-4, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") diff --git a/tests/conv_test.py b/tests/conv_test.py index 19dd3789..e57a2109 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -60,7 +60,12 @@ def with_jax(request): class ConvCorrectness: def thresh(self, direction): - return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] + return { + "fwd": 3e-4, + "bwd": 3e-4, + "double_bwd": 3e-4, + "triple_bwd": 5e-4, + }[direction] def check_result(self, result, fieldname): with check: @@ -150,47 +155,16 @@ def test_tp_double_bwd(self, conv_object, graph): self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") - -class ConvTripleBackwardCorrectness: - def thresh(self, direction): - return {"triple_bwd": 5e-4}[direction] - - def check_result(self, result, fieldname): - with check: - error = result[fieldname]["diff_Linf_norm"] - thresh = result["thresh"] - assert result[fieldname]["pass"], ( - f"{fieldname} observed error={error:.5f} >= {thresh}" - ) - - @pytest.fixture(scope="class") - def extra_conv_constructor_args(self): - return {} - - @pytest.fixture(params=["atomic", "deterministic"], scope="class") - def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): + def test_tp_triple_bwd(self, conv_object, graph, with_jax): if with_jax: pytest.skip("N/A for JAX") - if request.param == "atomic": - return oeq.TensorProductConv( - problem, deterministic=False, **extra_conv_constructor_args - ) - elif request.param == "deterministic": - if not problem.shared_weights: - return oeq.TensorProductConv( - problem, deterministic=True, **extra_conv_constructor_args - ) - else: - pytest.skip("Shared weights not supported with deterministic") - - def test_tp_triple_bwd(self, conv_object, small_graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") result = correctness_triple_backward_conv( conv_object, - small_graph, + graph, thresh=self.thresh("triple_bwd"), prng_seed=12345, reference_implementation=None, @@ -208,30 +182,6 @@ def test_tp_triple_bwd(self, conv_object, small_graph): self.check_result(result, fieldname) -class TestTripleBackwardConvUVU(ConvTripleBackwardCorrectness): - def id_func(m, i): - return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - - @pytest.fixture( - params=[((2, 1, 2), (1, 1, 1))], - ids=lambda x: TestTripleBackwardConvUVU.id_func(x[0], x[1]), - scope="class", - ) - def problem(self, request, dtype): - m, i = request.param[0], request.param[1] - instructions = [(0, 0, 0, "uvu", True)] - return oeq.TPProblem( - f"{m[0]}x{i[0]}e", - f"{m[1]}x{i[1]}e", - f"{m[2]}x{i[2]}e", - instructions, - shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, - ) - - class TestTripleBackwardConvDirectOps: @pytest.fixture(scope="class") def problem(self, dtype, with_jax): @@ -430,6 +380,7 @@ def thresh(self, direction): "fwd": 1e-5, "bwd": 7.5e-2, # Expect higher errors for shared weights "double_bwd": 5e-1, + "triple_bwd": 5e-1, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") diff --git a/tests/stream_test.py b/tests/stream_test.py index 42ac4dd2..5e151ec6 100644 --- a/tests/stream_test.py +++ b/tests/stream_test.py @@ -91,6 +91,65 @@ def conv_buffers(edge_index, tpp, gen): return (X, Y, W, edge_index[0], edge_index[1]) +@pytest.fixture +def conv_det_buffers(edge_index, tpp, gen): + edge_index, _ = edge_index.sort_by("row") + _, sender_perm = edge_index.sort_by("col") + X = torch.rand( + edge_index.num_rows, tpp.irreps_in1.dim, device="cuda", generator=gen + ) + Y = torch.rand( + edge_index.num_cols, tpp.irreps_in2.dim, device="cuda", generator=gen + ) + W = torch.rand(edge_index.num_cols, tpp.weight_numel, device="cuda", generator=gen) + return (X, Y, W, edge_index[0], edge_index[1], sender_perm) + + +def _none_to_zeros(values, refs): + return tuple( + ref * 0 if value is None else value for value, ref in zip(values, refs) + ) + + +def _triple_backward_from_output(out, X, Y, W): + out_grad = torch.ones_like(out).requires_grad_(True) + in1_dgrad = torch.ones_like(X).requires_grad_(True) + in2_dgrad = torch.ones_like(Y).requires_grad_(True) + w_dgrad = torch.ones_like(W).requires_grad_(True) + + in1_grad, in2_grad, w_grad = torch.autograd.grad( + outputs=out, + inputs=(X, Y, W), + grad_outputs=out_grad, + create_graph=True, + retain_graph=True, + ) + + double_grads = torch.autograd.grad( + outputs=(in1_grad, in2_grad, w_grad), + inputs=(X, Y, W, out_grad), + grad_outputs=(in1_dgrad, in2_dgrad, w_dgrad), + create_graph=True, + retain_graph=True, + allow_unused=True, + ) + double_grads = _none_to_zeros(double_grads, (X, Y, W, out_grad)) + + triple_grads = torch.autograd.grad( + outputs=double_grads, + inputs=(X, Y, W, out_grad, in1_dgrad, in2_dgrad, w_dgrad), + grad_outputs=( + torch.ones_like(X), + torch.ones_like(Y), + torch.ones_like(W), + torch.ones_like(out), + ), + allow_unused=True, + ) + + return sum(torch.norm(grad) for grad in triple_grads if grad is not None) + + @pytest.fixture def oeq_tp_fwd(tpp, tp_buffers): tp_oeq = TensorProduct(tpp) @@ -160,6 +219,29 @@ def double_backward_fn(X, Y, W): ) +@pytest.fixture +def oeq_tp_triple_bwd(tpp, tp_buffers): + tp_oeq = TensorProduct(tpp) + + def triple_backward_fn(X, Y, W): + X.requires_grad_(True) + Y.requires_grad_(True) + W.requires_grad_(True) + out = tp_oeq(X, Y, W) + return _triple_backward_from_output(out, X, Y, W) + + return Executable( + triple_backward_fn, + tp_buffers, + [ + KE("forward", 1), + KE("backward", 4), + KE("double_backward_A", 5), + KE("double_backward_B", 5), + ], + ) + + @pytest.fixture def oeq_conv_atomic_fwd(tpp, conv_buffers): tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=False) @@ -239,30 +321,55 @@ def double_backward_fn(X, Y, W, receivers, senders): @pytest.fixture -def oeq_conv_det_fwd(tpp, conv_buffers): +def oeq_conv_atomic_triple_bwd(tpp, conv_buffers): tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=False) - return Executable(tp_conv, conv_buffers, [KE("forward", 1), KE("fixup_forward", 1)]) + def triple_backward_fn(X, Y, W, receivers, senders): + X.requires_grad_(True) + Y.requires_grad_(True) + W.requires_grad_(True) + out = tp_conv(X, Y, W, receivers, senders) + return _triple_backward_from_output(out, X, Y, W) + + return Executable( + triple_backward_fn, + conv_buffers, + [ + KE("forward", 1), + KE("backward", 4), + KE("double_backward_A", 5), + KE("double_backward_B", 5), + ], + ) @pytest.fixture -def oeq_conv_det_bwd(tpp, conv_buffers): - tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=False) +def oeq_conv_det_fwd(tpp, conv_det_buffers): + tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=True) + + return Executable( + tp_conv, conv_det_buffers, [KE("forward", 1), KE("fixup_forward", 1)] + ) + + +@pytest.fixture +def oeq_conv_det_bwd(tpp, conv_det_buffers): + tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=True) # Set up backward-executing callable - def backward_fn(X, Y, W, receivers, senders): + def backward_fn(X, Y, W, receivers, senders, sender_perm): X.requires_grad_(True) Y.requires_grad_(True) W.requires_grad_(True) output = tp_conv( - X, Y, W, receivers, senders + X, Y, W, receivers, senders, sender_perm ).sum() # Scalar output for backward output.backward() return output return Executable( backward_fn, - conv_buffers, + conv_det_buffers, [ KE("forward", 1), KE("fixup_forward", 1), @@ -273,16 +380,16 @@ def backward_fn(X, Y, W, receivers, senders): @pytest.fixture -def oeq_conv_det_double_bwd(tpp, conv_buffers): - tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=False) +def oeq_conv_det_double_bwd(tpp, conv_det_buffers): + tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=True) - def double_backward_fn(X, Y, W, receivers, senders): + def double_backward_fn(X, Y, W, receivers, senders, sender_perm): # First forward pass X.requires_grad_(True) Y.requires_grad_(True) W.requires_grad_(True) - out = tp_conv(X, Y, W, receivers, senders) + out = tp_conv(X, Y, W, receivers, senders, sender_perm) out_grad = out.clone().detach().requires_grad_(True) # First backward (gradients w.r.t inputs) @@ -308,7 +415,7 @@ def double_backward_fn(X, Y, W, receivers, senders): return Executable( double_backward_fn, - conv_buffers, + conv_det_buffers, [ KE("forward", 1), KE("fixup_forward", 2), @@ -321,17 +428,46 @@ def double_backward_fn(X, Y, W, receivers, senders): ) +@pytest.fixture +def oeq_conv_det_triple_bwd(tpp, conv_det_buffers): + tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=True) + + def triple_backward_fn(X, Y, W, receivers, senders, sender_perm): + X.requires_grad_(True) + Y.requires_grad_(True) + W.requires_grad_(True) + out = tp_conv(X, Y, W, receivers, senders, sender_perm) + return _triple_backward_from_output(out, X, Y, W) + + return Executable( + triple_backward_fn, + conv_det_buffers, + [ + KE("forward", 1), + KE("fixup_forward", 6), + KE("backward", 4), + KE("fixup_backward", 4), + KE("double_backward_A", 5), + KE("double_backward_B", 5), + KE("fixup_double_backwardB", 5), + ], + ) + + @pytest.fixture( params=[ "oeq_tp_fwd", "oeq_tp_bwd", "oeq_tp_double_bwd", + "oeq_tp_triple_bwd", "oeq_conv_atomic_fwd", "oeq_conv_atomic_bwd", "oeq_conv_atomic_double_bwd", + "oeq_conv_atomic_triple_bwd", "oeq_conv_det_fwd", "oeq_conv_det_bwd", "oeq_conv_det_double_bwd", + "oeq_conv_det_triple_bwd", ], ) def executable(request): From f54d389e127668a42c724bd3c2a25d4a88fe1425 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:06:47 -0700 Subject: [PATCH 3/3] ruff --- .../openequivariance/_torch/NPDoubleBackwardMixin.py | 3 +-- openequivariance/openequivariance/benchmark/test_buffers.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index 702295d5..4a74226e 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -3,8 +3,7 @@ def _none_to_zeros(values, refs): return tuple( - ref * 0 if value is None else value - for value, ref in zip(values, refs) + ref * 0 if value is None else value for value, ref in zip(values, refs) ) diff --git a/openequivariance/openequivariance/benchmark/test_buffers.py b/openequivariance/openequivariance/benchmark/test_buffers.py index 54612701..252c2064 100644 --- a/openequivariance/openequivariance/benchmark/test_buffers.py +++ b/openequivariance/openequivariance/benchmark/test_buffers.py @@ -127,9 +127,7 @@ def get_random_buffers_double_backward( ) -def get_random_buffers_triple_backward( - tpp: TPProblem, batch_size: int, prng_seed: int -): +def get_random_buffers_triple_backward(tpp: TPProblem, batch_size: int, prng_seed: int): rng = np.random.default_rng(prng_seed) in1 = np.array(