Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import torch


def _none_to_zeros(values, refs):
return tuple(
ref * 0 if value is None else value for value, ref in zip(values, refs)
)


class NumpyDoubleBackwardMixin:
"""
Adds a Numpy double backward method to any TensorProduct
Expand Down Expand Up @@ -43,6 +49,89 @@ def double_backward_cpu(
d.detach().cpu().numpy(),
)

def triple_backward_cpu(
self,
in1,
in2,
out_grad,
weights,
weights_dgrad,
in1_dgrad,
in2_dgrad,
out_tgrad,
weights_tgrad,
in1_tgrad,
in2_tgrad,
):
assert self.torch_op

in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda").requires_grad_(True)
in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda").requires_grad_(True)
weights_dgrad_torch = (
torch.tensor(weights_dgrad).to("cuda").requires_grad_(True)
)
out_tgrad_torch = torch.tensor(out_tgrad).to("cuda")
in1_tgrad_torch = torch.tensor(in1_tgrad).to("cuda")
in2_tgrad_torch = torch.tensor(in2_tgrad).to("cuda")
weights_tgrad_torch = torch.tensor(weights_tgrad).to("cuda")

out_torch = self.forward(in1_torch, in2_torch, weights_torch)
in1_grad, in2_grad, weights_grad = torch.autograd.grad(
outputs=out_torch,
inputs=[in1_torch, in2_torch, weights_torch],
grad_outputs=out_grad_torch,
create_graph=True,
retain_graph=True,
)
double_grads = torch.autograd.grad(
outputs=[in1_grad, in2_grad, weights_grad],
inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
create_graph=True,
retain_graph=True,
allow_unused=True,
)
double_grads = _none_to_zeros(
double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch)
)
triple_grads = torch.autograd.grad(
outputs=double_grads,
inputs=[
in1_torch,
in2_torch,
weights_torch,
out_grad_torch,
in1_dgrad_torch,
in2_dgrad_torch,
weights_dgrad_torch,
],
grad_outputs=[
in1_tgrad_torch,
in2_tgrad_torch,
weights_tgrad_torch,
out_tgrad_torch,
],
allow_unused=True,
)
triple_grads = _none_to_zeros(
triple_grads,
(
in1_torch,
in2_torch,
weights_torch,
out_grad_torch,
in1_dgrad_torch,
in2_dgrad_torch,
weights_dgrad_torch,
),
)

return tuple(grad.detach().cpu().numpy() for grad in triple_grads)


class NumpyDoubleBackwardMixinConv:
"""
Expand Down Expand Up @@ -95,3 +184,98 @@ def double_backward_cpu(
c.detach().cpu().numpy(),
d.detach().cpu().numpy(),
)

def triple_backward_cpu(
self,
in1,
in2,
out_grad,
weights,
weights_dgrad,
in1_dgrad,
in2_dgrad,
out_tgrad,
weights_tgrad,
in1_tgrad,
in2_tgrad,
graph,
):
assert self.torch_op

in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO-someday: I wonder if we can combine all of these derivative functions into one to compact this file.

in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda").requires_grad_(True)
in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda").requires_grad_(True)
weights_dgrad_torch = (
torch.tensor(weights_dgrad).to("cuda").requires_grad_(True)
)
out_tgrad_torch = torch.tensor(out_tgrad).to("cuda")
in1_tgrad_torch = torch.tensor(in1_tgrad).to("cuda")
in2_tgrad_torch = torch.tensor(in2_tgrad).to("cuda")
weights_tgrad_torch = torch.tensor(weights_tgrad).to("cuda")

torch_rows = torch.tensor(graph.rows, device="cuda")
torch_cols = torch.tensor(graph.cols, device="cuda")
torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")

out_torch = self.forward(
in1_torch,
in2_torch,
weights_torch,
torch_rows,
torch_cols,
torch_transpose_perm,
)
in1_grad, in2_grad, weights_grad = torch.autograd.grad(
outputs=out_torch,
inputs=[in1_torch, in2_torch, weights_torch],
grad_outputs=out_grad_torch,
create_graph=True,
retain_graph=True,
)
double_grads = torch.autograd.grad(
outputs=[in1_grad, in2_grad, weights_grad],
inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
create_graph=True,
retain_graph=True,
allow_unused=True,
)
double_grads = _none_to_zeros(
double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch)
)
triple_grads = torch.autograd.grad(
outputs=double_grads,
inputs=[
in1_torch,
in2_torch,
weights_torch,
out_grad_torch,
in1_dgrad_torch,
in2_dgrad_torch,
weights_dgrad_torch,
],
grad_outputs=[
in1_tgrad_torch,
in2_tgrad_torch,
weights_tgrad_torch,
out_tgrad_torch,
],
allow_unused=True,
)
triple_grads = _none_to_zeros(
triple_grads,
(
in1_torch,
in2_torch,
weights_torch,
out_grad_torch,
in1_dgrad_torch,
in2_dgrad_torch,
weights_dgrad_torch,
),
)

return tuple(grad.detach().cpu().numpy() for grad in triple_grads)
123 changes: 122 additions & 1 deletion openequivariance/openequivariance/_torch/TensorProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G):

def register_autograd():
backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward
double_backward_op = torch.ops.libtorch_tp_jit.jit_tp_double_backward

def zero_if_none(grad_output, like):
if grad_output is None:
return torch.zeros_like(like)
return grad_output

def setup_context(ctx, inputs, output):
ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_dim = inputs
Expand All @@ -218,7 +224,7 @@ def setup_context_double_backward(ctx, inputs, output):
ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs

def double_backward(ctx, E, F, G):
result = torch.ops.libtorch_tp_jit.jit_tp_double_backward(
result = double_backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_in,
Expand All @@ -237,6 +243,121 @@ def double_backward(ctx, E, F, G):
setup_context=setup_context_double_backward,
)

def setup_context_triple_backward(ctx, inputs, output):
(
ctx.kernel,
ctx.hash,
ctx.L1_in,
ctx.L2_in,
ctx.weights,
ctx.L3_grad,
ctx.L1_dgrad,
ctx.L2_dgrad,
ctx.W_dgrad,
) = inputs

def triple_backward(ctx, t_L1_grad, t_L2_grad, t_W_grad, t_L3_dgrad):
t_L1_grad = zero_if_none(t_L1_grad, ctx.L1_in)
t_L2_grad = zero_if_none(t_L2_grad, ctx.L2_in)
t_W_grad = zero_if_none(t_W_grad, ctx.weights)
t_L3_dgrad = zero_if_none(t_L3_dgrad, ctx.L3_grad)

g1_L1_dgrad, g1_L2_dgrad, g1_W, g1_L3_grad = double_backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_dgrad,
ctx.L2_dgrad,
ctx.weights,
ctx.L3_grad,
t_L1_grad,
t_L2_grad,
torch.zeros_like(ctx.weights),
)
g2_L1, g2_L2, g2_W_dgrad, g2_L3_grad = double_backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_in,
ctx.L2_in,
ctx.W_dgrad,
ctx.L3_grad,
t_L1_grad,
t_L2_grad,
torch.zeros_like(ctx.W_dgrad),
)
g3_L1_dgrad, g3_L2, g3_W, g3_L3_grad = double_backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_dgrad,
ctx.L2_in,
ctx.weights,
ctx.L3_grad,
torch.zeros_like(ctx.L1_dgrad),
torch.zeros_like(ctx.L2_in),
t_W_grad,
)
g4_L1, g4_L2_dgrad, g4_W, g4_L3_grad = double_backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_in,
ctx.L2_dgrad,
ctx.weights,
ctx.L3_grad,
torch.zeros_like(ctx.L1_in),
torch.zeros_like(ctx.L2_dgrad),
t_W_grad,
)

g5_L1_dgrad, g5_L2, g5_W = backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_dgrad,
ctx.L2_in,
ctx.weights,
t_L3_dgrad,
)
g6_L1, g6_L2_dgrad, g6_W = backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_in,
ctx.L2_dgrad,
ctx.weights,
t_L3_dgrad,
)
g7_L1, g7_L2, g7_W_dgrad = backward_op(
ctx.kernel,
ctx.hash,
ctx.L1_in,
ctx.L2_in,
ctx.W_dgrad,
t_L3_dgrad,
)

grad_L1 = g2_L1 + g4_L1 + g6_L1 + g7_L1
grad_L2 = g2_L2 + g3_L2 + g5_L2 + g7_L2
grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
grad_L3_grad = g1_L3_grad + g2_L3_grad + g3_L3_grad + g4_L3_grad
grad_L1_dgrad = g1_L1_dgrad + g3_L1_dgrad + g5_L1_dgrad
grad_L2_dgrad = g1_L2_dgrad + g4_L2_dgrad + g6_L2_dgrad
grad_W_dgrad = g2_W_dgrad + g7_W_dgrad

return (
None,
None,
grad_L1,
grad_L2,
grad_W,
grad_L3_grad,
grad_L1_dgrad,
grad_L2_dgrad,
grad_W_dgrad,
)

torch.library.register_autograd(
"libtorch_tp_jit::jit_tp_double_backward",
triple_backward,
setup_context=setup_context_triple_backward,
)


def register_autocast():
global torch
Expand Down
Loading
Loading