From 58cb42039fc1214898abb094ea776eac7d097576 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 5 Feb 2026 21:53:50 -0800 Subject: [PATCH 1/7] Making progress with vmap. --- .../openequivariance/jax/jvp/conv_prim.py | 68 ++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 9324b78..a7b19c2 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp from jax.extend import core -from jax.interpreters import mlir, ad +from jax.interpreters import mlir, ad, batching from openequivariance.jax.utils import clean_tensors # ============================================================================== @@ -505,3 +505,69 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose + +# ============================================================================== +# 14. Batching rules for all primitives +# ============================================================================== + +#def conv_fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): +# X, Y, W, rows, cols, workspace, sender_perm = vector_arg_values + +# B = X.shape[0] +# num_nodes = X.shape[1] + +# rows_offset = rows +# cols_offset = cols +# if B > 1: +# batch_offsets = jnp.arange(B) * num_nodes +# rows_offset = rows + batch_offsets[:, None] +# cols_offset = cols + batch_offsets[:, None] + +# X_flat = X.reshape(-1, *X.shape[2:]) +# Y_flat = Y.reshape(-1, *Y.shape[2:]) +# W_flat = W.reshape(-1, *W.shape[2:]) + +# rows_flat = rows_offset.reshape(-1) +# cols_flat = cols_offset.reshape(-1) + +# return conv_fwd_p.bind( +# X_flat, +# Y_flat, +# W_flat, +# rows_flat, +# cols_flat, +# workspace, +# sender_perm, +# L3_dim=L3_dim, +# kernel=kernel, +# hash=hash +# ).reshape(B, num_nodes, -1), 0 +# batching.primitive_batchers[conv_fwd_p] = conv_fwd_batch + +for primitive in [conv_fwd_p, conv_bwd_p, conv_dbwd_p, conv_fwd_jvp_p, conv_bwd_jvp_p]: + def batch_fn(vector_arg_values, batch_axes, **kwargs): + X = vector_arg_values[0] + B = X.shape[0] + num_nodes = X.shape[1] + + rows, cols, workspace, sender_perm = vector_arg_values[-4:] + rows_offset, cols_offset, sender_perm_offset = rows, cols, sender_perm + if B > 1: + batch_offsets = jnp.arange(B) * num_nodes + rows_offset = rows + batch_offsets[:, None] + cols_offset = cols + batch_offsets[:, None] + sender_perm_offset = (sender_perm + batch_offsets[:, None]).reshape(-1) + + new_args = [ + arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4] + ] + [rows_offset, cols_offset, workspace, sender_perm_offset] + + result = primitive.bind(*new_args, **kwargs) + if isinstance(result, tuple): + return tuple(r.reshape(B, -1, result.shape[-1]) for r in result), 0 + else: + return result.reshape(B, -1, result.shape[-1]), 0 + + batching.primitive_batchers[primitive] = batch_fn + + From 29e9ecae8bd5d22b5f18dc0b3e7d740f21438499 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 6 Feb 2026 20:27:03 -0800 Subject: [PATCH 2/7] More progress on vmap. 
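The batching rule treats a vmapped call as one large disjoint graph: the dense
operands are flattened from (B, num_nodes, ...) / (B, nnz, ...) down to a single
leading axis, the rows/cols index arrays are offset by b * num_nodes for the
b-th batch element, the flat primitive is bound once, and the output is reshaped
back to (B, num_nodes, -1). A minimal self-contained sketch of the same
flatten-and-offset trick on a toy segment-sum "convolution" (illustrative only;
toy_conv and its arguments are placeholders, not the OEQ primitive in this patch):

    import jax
    import jax.numpy as jnp

    def toy_conv(X, Y, rows, cols):
        # Gather sender features, combine with edge features, scatter-add to receivers.
        # The toy assumes X and Y share the same feature width.
        messages = X[cols] * Y
        return jax.ops.segment_sum(messages, rows, num_segments=X.shape[0])

    def batched_toy_conv(X, Y, rows, cols):
        # X: (B, num_nodes, feat), Y: (B, nnz, feat); rows/cols shared across the batch.
        B, num_nodes, feat = X.shape
        offsets = jnp.arange(B) * num_nodes          # shift edge indices per batch element
        rows_b = (rows[None, :] + offsets[:, None]).reshape(-1)
        cols_b = (cols[None, :] + offsets[:, None]).reshape(-1)
        Z_flat = toy_conv(X.reshape(-1, feat), Y.reshape(-1, Y.shape[-1]), rows_b, cols_b)
        return Z_flat.reshape(B, num_nodes, feat)    # back to the batched layout

The real rule in flatten_args additionally broadcasts operands whose batch axis
is None and passes the workspace argument through untouched.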
--- .../openequivariance/jax/jvp/conv_prim.py | 124 ++++++++++-------- tests/vmap_test.py | 64 +++++++++ 2 files changed, 134 insertions(+), 54 deletions(-) create mode 100644 tests/vmap_test.py diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index a7b19c2..f71379d 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -510,64 +510,80 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): # 14. Batching rules for all primitives # ============================================================================== -#def conv_fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): -# X, Y, W, rows, cols, workspace, sender_perm = vector_arg_values - -# B = X.shape[0] -# num_nodes = X.shape[1] - -# rows_offset = rows -# cols_offset = cols -# if B > 1: -# batch_offsets = jnp.arange(B) * num_nodes -# rows_offset = rows + batch_offsets[:, None] -# cols_offset = cols + batch_offsets[:, None] - -# X_flat = X.reshape(-1, *X.shape[2:]) -# Y_flat = Y.reshape(-1, *Y.shape[2:]) -# W_flat = W.reshape(-1, *W.shape[2:]) - -# rows_flat = rows_offset.reshape(-1) -# cols_flat = cols_offset.reshape(-1) - -# return conv_fwd_p.bind( -# X_flat, -# Y_flat, -# W_flat, -# rows_flat, -# cols_flat, -# workspace, -# sender_perm, -# L3_dim=L3_dim, -# kernel=kernel, -# hash=hash -# ).reshape(B, num_nodes, -1), 0 -# batching.primitive_batchers[conv_fwd_p] = conv_fwd_batch - -for primitive in [conv_fwd_p, conv_bwd_p, conv_dbwd_p, conv_fwd_jvp_p, conv_bwd_jvp_p]: - def batch_fn(vector_arg_values, batch_axes, **kwargs): - X = vector_arg_values[0] - B = X.shape[0] +def flatten_args(vector_arg_values, batch_axes): + X = vector_arg_values[0] + + if len(X.shape) == 2: + num_nodes = X.shape[0] + elif len(X.shape) == 3: num_nodes = X.shape[1] + else: + raise ValueError("Unexpected input shape for X: {}".format(X.shape)) + + B = None + for arg, ba in zip(vector_arg_values, batch_axes): + if ba is not None: + if B is None: + B = arg.shape[ba] + else: + assert B == arg.shape[ba], "Batch size mismatch among arguments" + + new_args = [] + for i, (arg, batch_axis) in enumerate(zip(vector_arg_values, batch_axes)): + if i != len(vector_arg_values) - 2: + if batch_axis is not None: + assert batch_axis == 0, "Batch axis must be 0 for all arguments" + if batch_axis is None and arg is not None: + arg = jnp.broadcast_to(arg, (B,) + arg.shape) + new_args.append(arg) + + vector_arg_values = new_args + + rows, cols, workspace, sender_perm = vector_arg_values[-4:] + rows_offset, cols_offset, sender_perm_offset = rows, cols, sender_perm + if B > 1: + batch_offsets = jnp.arange(B) * num_nodes + rows_offset = rows + batch_offsets[:, None] + cols_offset = cols + batch_offsets[:, None] + + if sender_perm is not None: + sender_perm_offset = (sender_perm + batch_offsets[:, None]) + + new_args = [ + arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4] + ] + [jnp.ravel(rows_offset), jnp.ravel(cols_offset), workspace, jnp.ravel(sender_perm_offset)] + + return new_args + +def unflatten_results(result, batch_size): + if not isinstance(result, tuple) and not isinstance(result, list): + return result.reshape(batch_size, -1, result.shape[-1]), 0 + else: + return tuple(r.reshape(batch_size, -1, r.shape[-1]) for r in result), (0 for _ in result) + + +def fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): + B = vector_arg_values[0].shape[0] + new_args = flatten_args(vector_arg_values, 
batch_axes) + result = conv_fwd_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash) + return unflatten_results(result, B) + +batching.primitive_batchers[conv_fwd_p] = fwd_batch - rows, cols, workspace, sender_perm = vector_arg_values[-4:] - rows_offset, cols_offset, sender_perm_offset = rows, cols, sender_perm - if B > 1: - batch_offsets = jnp.arange(B) * num_nodes - rows_offset = rows + batch_offsets[:, None] - cols_offset = cols + batch_offsets[:, None] - sender_perm_offset = (sender_perm + batch_offsets[:, None]).reshape(-1) - new_args = [ - arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4] - ] + [rows_offset, cols_offset, workspace, sender_perm_offset] +def batch_bwd(vector_arg_values, batch_axes, *, kernel, hash): + B = vector_arg_values[0].shape[0] + new_args = flatten_args(vector_arg_values, batch_axes) + result = conv_bwd_p.bind(*new_args, kernel=kernel, hash=hash) + return unflatten_results(result, B) - result = primitive.bind(*new_args, **kwargs) - if isinstance(result, tuple): - return tuple(r.reshape(B, -1, result.shape[-1]) for r in result), 0 - else: - return result.reshape(B, -1, result.shape[-1]), 0 +batching.primitive_batchers[conv_bwd_p] = batch_bwd - batching.primitive_batchers[primitive] = batch_fn +def batch_dbwd(vector_arg_values, batch_axes, *, kernel, hash): + B = vector_arg_values[0].shape[0] + new_args = flatten_args(vector_arg_values, batch_axes) + result = conv_dbwd_p.bind(*new_args, kernel=kernel, hash=hash) + return unflatten_results(result, B) +batching.primitive_batchers[conv_dbwd_p] = batch_dbwd \ No newline at end of file diff --git a/tests/vmap_test.py b/tests/vmap_test.py new file mode 100644 index 0000000..a07aaf2 --- /dev/null +++ b/tests/vmap_test.py @@ -0,0 +1,64 @@ +import pytest +import os + +@pytest.fixture +def with_jax(request): + return request.config.getoption("--jax") + +def test_tutorial_vmap(with_jax): + if not with_jax: + pytest.skip("Skipping JAX tutorial when testing PyTorch") + + os.environ["OEQ_NOTORCH"] = "1" + import openequivariance as oeq + import jax + import jax.numpy as jnp + + seed = 42 + key = jax.random.PRNGKey(seed) + + vmap_dim = 10 + batch_size = 1000 + X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") + instructions = [(0, 0, 0, "uvu", True)] + + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) + + edge_index = jnp.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jnp.int32, # NOTE: This int32, not int64 + ) + + node_ct, nonzero_ct = 3, 4 + X = jax.random.uniform( + key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jnp.float32 + ) + Y = jax.random.uniform( + key, + shape=(vmap_dim, nonzero_ct, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jnp.float32, + ) + W = jax.random.uniform( + key, + shape=(vmap_dim, nonzero_ct, problem.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jnp.float32, + ) + tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) + Z_vmap = jax.vmap(tp_conv.forward, in_axes=(None, 0, 0, None, None))(X, Y, W, edge_index[0], edge_index[1]) + + Z_loop = jnp.empty_like(Z_vmap) + for i in range(vmap_dim): + Z_loop = tp_conv.forward(X, Y[i], W[i], edge_index[0], edge_index[1]) + Z_vmap_i = Z_vmap[i] + + assert jnp.allclose(Z_vmap, Z_loop, atol=1e-5), "vmap and loop results do not match" + From 118122235c040bf362b8ea50f8192fbca7596212 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 6 Feb 2026 20:40:12 -0800 Subject: [PATCH 3/7] Making progress with a test. 
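Cast the per-batch index offsets to rows.dtype before adding them, so the
flattened rows/cols arrays keep the integer dtype of the caller's edge index
(the tests use int32) rather than the default dtype of jnp.arange. A small check
of the pitfall this guards against (illustrative only; the x64 flag is not set
anywhere in this patch):

    import jax
    jax.config.update("jax_enable_x64", True)   # makes jnp.arange default to int64
    import jax.numpy as jnp

    rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
    offsets = jnp.arange(2) * 3
    print((rows + offsets[:, None]).dtype)                     # int64 -- silently widened
    print((rows + offsets[:, None].astype(rows.dtype)).dtype)  # int32 -- matches the edge index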
--- openequivariance/openequivariance/jax/jvp/conv_prim.py | 2 +- tests/vmap_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index f71379d..cf6272d 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -542,7 +542,7 @@ def flatten_args(vector_arg_values, batch_axes): rows, cols, workspace, sender_perm = vector_arg_values[-4:] rows_offset, cols_offset, sender_perm_offset = rows, cols, sender_perm if B > 1: - batch_offsets = jnp.arange(B) * num_nodes + batch_offsets = (jnp.arange(B) * num_nodes).astype(rows.dtype) rows_offset = rows + batch_offsets[:, None] cols_offset = cols + batch_offsets[:, None] diff --git a/tests/vmap_test.py b/tests/vmap_test.py index a07aaf2..dcf1277 100644 --- a/tests/vmap_test.py +++ b/tests/vmap_test.py @@ -7,7 +7,7 @@ def with_jax(request): def test_tutorial_vmap(with_jax): if not with_jax: - pytest.skip("Skipping JAX tutorial when testing PyTorch") + pytest.skip("Skipping JAX VMAP when testing PyTorch") os.environ["OEQ_NOTORCH"] = "1" import openequivariance as oeq @@ -31,7 +31,7 @@ def test_tutorial_vmap(with_jax): [0, 1, 1, 2], [1, 0, 2, 1], ], - dtype=jnp.int32, # NOTE: This int32, not int64 + dtype=jnp.int32 ) node_ct, nonzero_ct = 3, 4 From 0c9950705cf7fbedadc7c8185004a63f1213ba20 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 6 Feb 2026 21:24:13 -0800 Subject: [PATCH 4/7] Finished writing tests. --- .../openequivariance/jax/jvp/conv_prim.py | 47 +++++++--- tests/vmap_test.py | 88 ++++++++----------- 2 files changed, 73 insertions(+), 62 deletions(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index cf6272d..5167c92 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -510,6 +510,16 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): # 14. 
Batching rules for all primitives # ============================================================================== +def find_batch_size(vector_arg_values, batch_axes): + B = None + for arg, ba in zip(vector_arg_values, batch_axes): + if ba is not None: + if B is None: + B = arg.shape[ba] + else: + assert B == arg.shape[ba], "Batch size mismatch among arguments" + return B + def flatten_args(vector_arg_values, batch_axes): X = vector_arg_values[0] @@ -520,21 +530,15 @@ def flatten_args(vector_arg_values, batch_axes): else: raise ValueError("Unexpected input shape for X: {}".format(X.shape)) - B = None - for arg, ba in zip(vector_arg_values, batch_axes): - if ba is not None: - if B is None: - B = arg.shape[ba] - else: - assert B == arg.shape[ba], "Batch size mismatch among arguments" + B = find_batch_size(vector_arg_values, batch_axes) new_args = [] for i, (arg, batch_axis) in enumerate(zip(vector_arg_values, batch_axes)): if i != len(vector_arg_values) - 2: - if batch_axis is not None: - assert batch_axis == 0, "Batch axis must be 0 for all arguments" if batch_axis is None and arg is not None: arg = jnp.broadcast_to(arg, (B,) + arg.shape) + elif batch_axis is not None and batch_axis != 0: + arg = jnp.moveaxis(arg, batch_axis, 0) new_args.append(arg) vector_arg_values = new_args @@ -563,7 +567,7 @@ def unflatten_results(result, batch_size): def fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): - B = vector_arg_values[0].shape[0] + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_fwd_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash) return unflatten_results(result, B) @@ -572,7 +576,7 @@ def fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): def batch_bwd(vector_arg_values, batch_axes, *, kernel, hash): - B = vector_arg_values[0].shape[0] + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_bwd_p.bind(*new_args, kernel=kernel, hash=hash) return unflatten_results(result, B) @@ -581,9 +585,26 @@ def batch_bwd(vector_arg_values, batch_axes, *, kernel, hash): def batch_dbwd(vector_arg_values, batch_axes, *, kernel, hash): - B = vector_arg_values[0].shape[0] + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_dbwd_p.bind(*new_args, kernel=kernel, hash=hash) return unflatten_results(result, B) -batching.primitive_batchers[conv_dbwd_p] = batch_dbwd \ No newline at end of file +batching.primitive_batchers[conv_dbwd_p] = batch_dbwd + +def fwd_jvp_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): + B = find_batch_size(vector_arg_values, batch_axes) + new_args = flatten_args(vector_arg_values, batch_axes) + result = conv_fwd_jvp_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash) + return unflatten_results(result, B) + +batching.primitive_batchers[conv_fwd_jvp_p] = fwd_jvp_batch + + +def bwd_jvp_batch(vector_arg_values, batch_axes, *, kernel, hash): + B = find_batch_size(vector_arg_values, batch_axes) + new_args = flatten_args(vector_arg_values, batch_axes) + result = conv_bwd_jvp_p.bind(*new_args, kernel=kernel, hash=hash) + return unflatten_results(result, B) + +batching.primitive_batchers[conv_bwd_jvp_p] = bwd_jvp_batch \ No newline at end of file diff --git a/tests/vmap_test.py b/tests/vmap_test.py index dcf1277..e20dc19 100644 --- a/tests/vmap_test.py +++ b/tests/vmap_test.py @@ -5,60 +5,50 @@ def with_jax(request): return 
request.config.getoption("--jax") -def test_tutorial_vmap(with_jax): +@pytest.fixture +def ctx(with_jax): if not with_jax: - pytest.skip("Skipping JAX VMAP when testing PyTorch") - + pytest.skip("Skipping JAX tests") os.environ["OEQ_NOTORCH"] = "1" import openequivariance as oeq import jax import jax.numpy as jnp - seed = 42 - key = jax.random.PRNGKey(seed) - - vmap_dim = 10 - batch_size = 1000 - X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") - instructions = [(0, 0, 0, "uvu", True)] - + key = jax.random.PRNGKey(42) + dim, n_nodes, n_nz = 10, 3, 4 + problem = oeq.TPProblem( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e"), + [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False ) - - edge_index = jnp.array( - [ - [0, 1, 1, 2], - [1, 0, 2, 1], - ], - dtype=jnp.int32 - ) - - node_ct, nonzero_ct = 3, 4 - X = jax.random.uniform( - key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jnp.float32 - ) - Y = jax.random.uniform( - key, - shape=(vmap_dim, nonzero_ct, Y_ir.dim), - minval=0.0, - maxval=1.0, - dtype=jnp.float32, - ) - W = jax.random.uniform( - key, - shape=(vmap_dim, nonzero_ct, problem.weight_numel), - minval=0.0, - maxval=1.0, - dtype=jnp.float32, - ) - tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) - Z_vmap = jax.vmap(tp_conv.forward, in_axes=(None, 0, 0, None, None))(X, Y, W, edge_index[0], edge_index[1]) - - Z_loop = jnp.empty_like(Z_vmap) - for i in range(vmap_dim): - Z_loop = tp_conv.forward(X, Y[i], W[i], edge_index[0], edge_index[1]) - Z_vmap_i = Z_vmap[i] - - assert jnp.allclose(Z_vmap, Z_loop, atol=1e-5), "vmap and loop results do not match" - + edge = jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=jnp.int32) + + X = jax.random.uniform(key, (dim, n_nodes, problem.irreps_in1.dim), dtype=jnp.float32) + Y = jax.random.uniform(key, (dim, n_nz, problem.irreps_in2.dim), dtype=jnp.float32) + W = jax.random.uniform(key, (dim, n_nz, problem.weight_numel), dtype=jnp.float32) + + return { + "X": X, "Y": Y, "W": W, "r": edge[0], "c": edge[1], + "conv": oeq.jax.TensorProductConv(problem, deterministic=False), + "jax": jax, "jnp": jnp, "dim": dim + } + +def verify(ctx, in_axes, args): + jax, jnp = ctx["jax"], ctx["jnp"] + res_vmap = jax.vmap(ctx["conv"].forward, in_axes)(*args) + + res_loop = [] + for i in range(ctx["dim"]): + i_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)] + res_loop.append(ctx["conv"].forward(*i_args)) + + assert jnp.allclose(res_vmap, jnp.stack(res_loop), atol=1e-5) + +def test_vmap_std(ctx): + verify(ctx, (0, 0, 0, None, None), (ctx["X"], ctx["Y"], ctx["W"], ctx["r"], ctx["c"])) + +def test_vmap_bcast_X(ctx): + verify(ctx, (None, 0, 0, None, None), (ctx["X"][0], ctx["Y"], ctx["W"], ctx["r"], ctx["c"])) + +def test_vmap_bcast_XW(ctx): + verify(ctx, (None, 0, None, None, None), (ctx["X"][0], ctx["Y"], ctx["W"][0], ctx["r"], ctx["c"])) \ No newline at end of file From d3fc4a5b697ad2fe3f4aab09c33fbd33005d28f9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 6 Feb 2026 21:37:22 -0800 Subject: [PATCH 5/7] More testing. 
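Replace the kernel hash used as a primitive attribute: builtins.hash on a string
is randomized per interpreter process (PYTHONHASHSEED), so it is not a stable
identifier, while the new hash_str_64 in core/utils.py derives a deterministic
56-bit integer from the sha256 digest of the kernel source. The rest of the diff
is formatting cleanup in conv_prim.py and the vmap tests. A short check of the
new helper (illustrative; the kernel string is a placeholder):

    import hashlib

    def hash_str_64(s: str) -> int:    # same definition as core/utils.py in this patch
        return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")

    src = "__global__ void forward() {}"
    assert hash_str_64(src) == hash_str_64(src)   # reproducible across runs and processes
    assert hash_str_64(src) < 2**56               # 7 bytes, so it fits in a signed 64-bit attribute
    # hash(src), in contrast, changes between interpreter runs.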
--- .../openequivariance/core/utils.py | 5 ++ .../openequivariance/jax/TensorProduct.py | 3 +- .../openequivariance/jax/TensorProductConv.py | 4 +- .../openequivariance/jax/jvp/conv_prim.py | 40 +++++++++----- tests/vmap_test.py | 54 ++++++++++++++----- 5 files changed, 78 insertions(+), 28 deletions(-) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 50f35bd..f3aa466 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -9,6 +9,7 @@ import tempfile from enum import IntEnum +import hashlib class DTypeEnum(IntEnum): @@ -199,3 +200,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis + + +def hash_str_64(s: str) -> int: + return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big") diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index bf7a144..a140674 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -3,6 +3,7 @@ from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP +from openequivariance.core.utils import hash_str_64 from openequivariance.jax.utils import reorder_jax from openequivariance.jax.jvp.tp_prim import tp_fwd_p import json @@ -30,7 +31,7 @@ def __init__(self, problem: TPProblem): "kernel_prop": self.kernelProp, } ) - self.hash = self.kernel.__hash__() + self.hash = hash_str_64(self.kernel) self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 5210229..2be0fe7 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -5,6 +5,8 @@ from typing import Optional from openequivariance.jax import extlib + +from openequivariance.core.utils import hash_str_64 from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax @@ -60,7 +62,7 @@ def __init__( "kernel_prop": self.kernel_prop, } ) - self.hash = self.kernel.__hash__() + self.hash = hash_str_64(self.kernel) self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 5167c92..ca91014 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -507,9 +507,10 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose # ============================================================================== -# 14. Batching rules for all primitives +# 14. 
Batching rules for all primitives # ============================================================================== + def find_batch_size(vector_arg_values, batch_axes): B = None for arg, ba in zip(vector_arg_values, batch_axes): @@ -520,6 +521,7 @@ def find_batch_size(vector_arg_values, batch_axes): assert B == arg.shape[ba], "Batch size mismatch among arguments" return B + def flatten_args(vector_arg_values, batch_axes): X = vector_arg_values[0] @@ -551,60 +553,72 @@ def flatten_args(vector_arg_values, batch_axes): cols_offset = cols + batch_offsets[:, None] if sender_perm is not None: - sender_perm_offset = (sender_perm + batch_offsets[:, None]) + sender_perm_offset = sender_perm + batch_offsets[:, None] - new_args = [ - arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4] - ] + [jnp.ravel(rows_offset), jnp.ravel(cols_offset), workspace, jnp.ravel(sender_perm_offset)] + new_args = [arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4]] + [ + jnp.ravel(rows_offset), + jnp.ravel(cols_offset), + workspace, + jnp.ravel(sender_perm_offset), + ] return new_args + def unflatten_results(result, batch_size): if not isinstance(result, tuple) and not isinstance(result, list): return result.reshape(batch_size, -1, result.shape[-1]), 0 else: - return tuple(r.reshape(batch_size, -1, r.shape[-1]) for r in result), (0 for _ in result) + return tuple(r.reshape(batch_size, -1, r.shape[-1]) for r in result), ( + 0 for _ in result + ) def fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): - B = find_batch_size(vector_arg_values, batch_axes) + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_fwd_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash) - return unflatten_results(result, B) + return unflatten_results(result, B) + batching.primitive_batchers[conv_fwd_p] = fwd_batch def batch_bwd(vector_arg_values, batch_axes, *, kernel, hash): - B = find_batch_size(vector_arg_values, batch_axes) + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_bwd_p.bind(*new_args, kernel=kernel, hash=hash) return unflatten_results(result, B) + batching.primitive_batchers[conv_bwd_p] = batch_bwd def batch_dbwd(vector_arg_values, batch_axes, *, kernel, hash): - B = find_batch_size(vector_arg_values, batch_axes) + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_dbwd_p.bind(*new_args, kernel=kernel, hash=hash) return unflatten_results(result, B) + batching.primitive_batchers[conv_dbwd_p] = batch_dbwd + def fwd_jvp_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash): - B = find_batch_size(vector_arg_values, batch_axes) + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_fwd_jvp_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash) return unflatten_results(result, B) + batching.primitive_batchers[conv_fwd_jvp_p] = fwd_jvp_batch def bwd_jvp_batch(vector_arg_values, batch_axes, *, kernel, hash): - B = find_batch_size(vector_arg_values, batch_axes) + B = find_batch_size(vector_arg_values, batch_axes) new_args = flatten_args(vector_arg_values, batch_axes) result = conv_bwd_jvp_p.bind(*new_args, kernel=kernel, hash=hash) return unflatten_results(result, B) -batching.primitive_batchers[conv_bwd_jvp_p] = bwd_jvp_batch \ No newline at end of file + +batching.primitive_batchers[conv_bwd_jvp_p] = bwd_jvp_batch 
diff --git a/tests/vmap_test.py b/tests/vmap_test.py index e20dc19..e165cf2 100644 --- a/tests/vmap_test.py +++ b/tests/vmap_test.py @@ -1,10 +1,12 @@ import pytest import os + @pytest.fixture def with_jax(request): return request.config.getoption("--jax") + @pytest.fixture def ctx(with_jax): if not with_jax: @@ -16,39 +18,65 @@ def ctx(with_jax): key = jax.random.PRNGKey(42) dim, n_nodes, n_nz = 10, 3, 4 - + problem = oeq.TPProblem( - oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e"), - [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False + oeq.Irreps("1x2e"), + oeq.Irreps("1x3e"), + oeq.Irreps("1x2e"), + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, ) edge = jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=jnp.int32) - - X = jax.random.uniform(key, (dim, n_nodes, problem.irreps_in1.dim), dtype=jnp.float32) + + X = jax.random.uniform( + key, (dim, n_nodes, problem.irreps_in1.dim), dtype=jnp.float32 + ) Y = jax.random.uniform(key, (dim, n_nz, problem.irreps_in2.dim), dtype=jnp.float32) W = jax.random.uniform(key, (dim, n_nz, problem.weight_numel), dtype=jnp.float32) - + return { - "X": X, "Y": Y, "W": W, "r": edge[0], "c": edge[1], + "X": X, + "Y": Y, + "W": W, + "r": edge[0], + "c": edge[1], "conv": oeq.jax.TensorProductConv(problem, deterministic=False), - "jax": jax, "jnp": jnp, "dim": dim + "jax": jax, + "jnp": jnp, + "dim": dim, } + def verify(ctx, in_axes, args): jax, jnp = ctx["jax"], ctx["jnp"] res_vmap = jax.vmap(ctx["conv"].forward, in_axes)(*args) - + res_loop = [] for i in range(ctx["dim"]): i_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)] res_loop.append(ctx["conv"].forward(*i_args)) - + assert jnp.allclose(res_vmap, jnp.stack(res_loop), atol=1e-5) + def test_vmap_std(ctx): - verify(ctx, (0, 0, 0, None, None), (ctx["X"], ctx["Y"], ctx["W"], ctx["r"], ctx["c"])) + verify( + ctx, (0, 0, 0, None, None), (ctx["X"], ctx["Y"], ctx["W"], ctx["r"], ctx["c"]) + ) + def test_vmap_bcast_X(ctx): - verify(ctx, (None, 0, 0, None, None), (ctx["X"][0], ctx["Y"], ctx["W"], ctx["r"], ctx["c"])) + verify( + ctx, + (None, 0, 0, None, None), + (ctx["X"][0], ctx["Y"], ctx["W"], ctx["r"], ctx["c"]), + ) + def test_vmap_bcast_XW(ctx): - verify(ctx, (None, 0, None, None, None), (ctx["X"][0], ctx["Y"], ctx["W"][0], ctx["r"], ctx["c"])) \ No newline at end of file + verify( + ctx, + (None, 0, None, None, None), + (ctx["X"][0], ctx["Y"], ctx["W"][0], ctx["r"], ctx["c"]), + ) From 086057cf88de31713de43ef4caca841e5f66c7e4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 6 Feb 2026 21:39:40 -0800 Subject: [PATCH 6/7] Updated docs. --- docs/tests_and_benchmarks.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst index f602ab4..e0be94b 100644 --- a/docs/tests_and_benchmarks.rst +++ b/docs/tests_and_benchmarks.rst @@ -42,6 +42,7 @@ To set up an editable install and run our tests, use the following code: pytest --jax tests/example_test.py pytest --jax tests/batch_test.py pytest --jax tests/conv_test.py + pytest --jax tests/vmap_test.py Browse the ``tests`` directory to run specific components. From b9c5fbf64052fb62a80b03771cd74e5899f2b17e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 8 Feb 2026 20:39:26 -0800 Subject: [PATCH 7/7] Fixed the stream issues. 
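The C++ JIT wrappers previously stored one KernelLaunchConfig per kernel and
mutated its hStream member on every call, so concurrent launches on different
streams raced on the shared config; double_backward in convolution.hpp also set
the stream on double_backward_config but then launched kernel 4 with
forward_config, i.e. with whatever stream the last forward() call had used.
The configs are now kept as *_config_ref members and every launch builds a
per-call copy through a new with_stream(config, stream) helper defined in the
CUDA and HIP backends.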
--- .../extension/convolution.hpp | 78 ++++++++++--------- .../extension/tensorproducts.hpp | 38 ++++----- .../extension/util/backend_cuda.hpp | 8 +- .../extension/util/backend_hip.hpp | 8 +- 4 files changed, 76 insertions(+), 56 deletions(-) diff --git a/openequivariance/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp index 3b2ce1e..92aa688 100644 --- a/openequivariance/openequivariance/extension/convolution.hpp +++ b/openequivariance/openequivariance/extension/convolution.hpp @@ -15,9 +15,10 @@ template class __attribute__ ((visibility ("default"))) JITConvImpl { public: JIT_IMPL jit; - KernelLaunchConfig forward_config; - KernelLaunchConfig backward_config; - KernelLaunchConfig double_backward_config; + + KernelLaunchConfig forward_config_ref; + KernelLaunchConfig backward_config_ref; + KernelLaunchConfig double_backward_config_ref; int opt_level; JITConvImpl( @@ -27,25 +28,25 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { KernelLaunchConfig double_backward_config_i, int opt_level_i) : jit(jit_kernel), - forward_config(forward_config_i), - backward_config(backward_config_i), - double_backward_config(double_backward_config_i), + forward_config_ref(forward_config_i), + backward_config_ref(backward_config_i), + double_backward_config_ref(double_backward_config_i), opt_level(opt_level_i) { vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"}; jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level); - if(forward_config.smem > 0) { - jit.set_max_smem(0, forward_config.smem); - jit.set_max_smem(4, forward_config.smem); + if(forward_config_ref.smem > 0) { + jit.set_max_smem(0, forward_config_ref.smem); + jit.set_max_smem(4, forward_config_ref.smem); } - if(backward_config.smem > 0) { - jit.set_max_smem(1, backward_config.smem); + if(backward_config_ref.smem > 0) { + jit.set_max_smem(1, backward_config_ref.smem); } - if(double_backward_config.smem > 0) { - jit.set_max_smem(5, double_backward_config.smem); + if(double_backward_config_ref.smem > 0) { + jit.set_max_smem(5, double_backward_config_ref.smem); } } @@ -89,16 +90,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { ConvData conv_data = {rows, cols, nnz, node_count}; void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace}; - forward_config.hStream = stream; - jit.execute(0, args, forward_config); + jit.execute(0, args, with_stream(forward_config_ref, stream)); if(reinterpret_cast(workspace) != 0) { void *fixup_args[] = {&workspace, &L3_out}; - KernelLaunchConfig fixup_config; - fixup_config.num_blocks = forward_config.num_blocks; - fixup_config.num_threads = forward_config.num_threads; - fixup_config.smem = 0; + KernelLaunchConfig fixup_config( + forward_config_ref.num_blocks, + forward_config_ref.num_threads, + 0 + ); fixup_config.hStream = stream; jit.execute(2, fixup_args, fixup_config); @@ -118,16 +119,17 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { ConvData conv_data = {rows, cols, nnz, node_count}; void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm}; - backward_config.hStream = stream; - jit.execute(1, args, backward_config); + jit.execute(1, args, with_stream(backward_config_ref, stream)); if(reinterpret_cast(workspace) != 0) { void *fixup_args[] = {&workspace, &L1_grad}; - KernelLaunchConfig fixup_config; - fixup_config.num_blocks = 
backward_config.num_blocks; - fixup_config.num_threads = backward_config.num_threads; - fixup_config.smem = 0; fixup_config.hStream = stream; + KernelLaunchConfig fixup_config( + backward_config_ref.num_blocks, + backward_config_ref.num_threads, + 0 + ); + fixup_config.hStream = stream; jit.execute(3, fixup_args, fixup_config); } @@ -147,24 +149,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm }; - double_backward_config.hStream = stream; - jit.execute(4, args, forward_config); + + jit.execute(4, args, with_stream(forward_config_ref, stream)); if(reinterpret_cast(wspace) != 0) { void *fixup_args[] = {&wspace, &L3_dgrad}; - KernelLaunchConfig fixup_config; - fixup_config.num_blocks = forward_config.num_blocks; - fixup_config.num_threads = forward_config.num_threads; - fixup_config.smem = 0; fixup_config.hStream = stream; + KernelLaunchConfig fixup_config( + forward_config_ref.num_blocks, + forward_config_ref.num_threads, + 0 + ); + fixup_config.hStream = stream; jit.execute(2, fixup_args, fixup_config); } - jit.execute(5, args, double_backward_config); + jit.execute(5, args, with_stream(double_backward_config_ref, stream)); if(reinterpret_cast(wspace) != 0) { void *fixup_args[] = {&wspace, &L1_grad}; - KernelLaunchConfig fixup_config; - fixup_config.num_blocks = double_backward_config.num_blocks; - fixup_config.num_threads = double_backward_config.num_threads; - fixup_config.smem = 0; fixup_config.hStream = stream; + KernelLaunchConfig fixup_config( + double_backward_config_ref.num_blocks, + double_backward_config_ref.num_threads, + 0 + ); + fixup_config.hStream = stream; jit.execute(6, fixup_args, fixup_config); } } diff --git a/openequivariance/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp index d904086..ee8def6 100644 --- a/openequivariance/openequivariance/extension/tensorproducts.hpp +++ b/openequivariance/openequivariance/extension/tensorproducts.hpp @@ -10,7 +10,11 @@ template class __attribute__ ((visibility ("default"))) JITTPImpl { public: JIT_IMPL jit; - KernelLaunchConfig forward_config, backward_config, double_backward_config; + + // Configs are suffixed with _ref because they + // need to be copied and modified with the stream. In-place + // modification not possible due to concurrency requirements. 
+ KernelLaunchConfig forward_config_ref, backward_config_ref, double_backward_config_ref; int opt_level; JITTPImpl( @@ -20,25 +24,25 @@ class __attribute__ ((visibility ("default"))) JITTPImpl { KernelLaunchConfig double_backward_config_i, int opt_level_i) : jit(jit_kernel), - forward_config(forward_config_i), - backward_config(backward_config_i), - double_backward_config(double_backward_config_i), + forward_config_ref(forward_config_i), + backward_config_ref(backward_config_i), + double_backward_config_ref(double_backward_config_i), opt_level(opt_level_i) { vector kernels = {"forward", "backward", "double_backward_A", "double_backward_B"}; jit.compile(kernels, {{}, {}, {}, {}}, opt_level); - if(forward_config.smem > 0) { - jit.set_max_smem(0, forward_config.smem); - jit.set_max_smem(2, forward_config.smem); + if(forward_config_ref.smem > 0) { + jit.set_max_smem(0, forward_config_ref.smem); + jit.set_max_smem(2, forward_config_ref.smem); } - if(backward_config.smem > 0) { - jit.set_max_smem(1, backward_config.smem); + if(backward_config_ref.smem > 0) { + jit.set_max_smem(1, backward_config_ref.smem); } - if(double_backward_config.smem > 0) { - jit.set_max_smem(3, double_backward_config.smem); + if(double_backward_config_ref.smem > 0) { + jit.set_max_smem(3, double_backward_config_ref.smem); } } @@ -77,8 +81,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl { Stream stream) { void *args[] = { &num_products, &L1_in, &L2_in, &L3_out, &weights}; - forward_config.hStream = stream; - jit.execute(0, args, forward_config); + jit.execute(0, args, with_stream(forward_config_ref, stream)); } void backward( @@ -88,8 +91,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl { void* weight, void* weight_grad, void* L3_grad, Stream stream) { void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad}; - backward_config.hStream = stream; - jit.execute(1, args, backward_config); + jit.execute(1, args, with_stream(backward_config_ref, stream)); } void double_backward( @@ -102,9 +104,9 @@ class __attribute__ ((visibility ("default"))) JITTPImpl { &num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, &L1_grad, &L2_grad, &W_grad, &L3_dgrad }; - double_backward_config.hStream = stream; - jit.execute(2, args, forward_config); - jit.execute(3, args, double_backward_config); + double_backward_config_ref.hStream = stream; + jit.execute(2, args, with_stream(forward_config_ref, stream)); + jit.execute(3, args, with_stream(double_backward_config_ref, stream)); } ~JITTPImpl() = default; diff --git a/openequivariance/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp index 4c79fae..84a5e6f 100644 --- a/openequivariance/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/openequivariance/extension/util/backend_cuda.hpp @@ -358,4 +358,10 @@ class __attribute__((visibility("default"))) CUJITKernel { } NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); } -}; \ No newline at end of file +}; + +KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) { + KernelLaunchConfig new_config = config; + new_config.hStream = stream; + return new_config; +} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/util/backend_hip.hpp index 5ee1a11..ae5cc00 100644 --- a/openequivariance/openequivariance/extension/util/backend_hip.hpp +++ 
b/openequivariance/openequivariance/extension/util/backend_hip.hpp @@ -310,4 +310,10 @@ class __attribute__((visibility("default"))) HIPJITKernel { ~HIPJITKernel() { HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog)); } -}; \ No newline at end of file +}; + +KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) { + KernelLaunchConfig new_config = config; + new_config.hStream = stream; + return new_config; +} \ No newline at end of file