diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst
index f602ab44..e0be94b6 100644
--- a/docs/tests_and_benchmarks.rst
+++ b/docs/tests_and_benchmarks.rst
@@ -42,6 +42,7 @@ To set up an editable install and run our tests, use the following code:
     pytest --jax tests/example_test.py
     pytest --jax tests/batch_test.py
     pytest --jax tests/conv_test.py
+    pytest --jax tests/vmap_test.py
 
 Browse the ``tests`` directory to run specific components.
 
diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py
index 50f35bd4..f3aa466c 100644
--- a/openequivariance/openequivariance/core/utils.py
+++ b/openequivariance/openequivariance/core/utils.py
@@ -9,6 +9,7 @@
 import tempfile
 
 from enum import IntEnum
+import hashlib
 
 
 class DTypeEnum(IntEnum):
@@ -199,3 +200,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
             time_millis[i] = kernel_time
 
     return time_millis
+
+
+def hash_str_64(s: str) -> int:
+    return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")
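A brief aside on hash_str_64 (the rationale is inferred, not stated in the patch): hashlib.sha256 produces the same digest in every process, whereas Python's builtin str hash is salted per interpreter run, and truncating the digest to 7 bytes keeps the value below 2**56, so it also fits in a signed 64-bit integer. A self-contained sketch:

    import hashlib

    def hash_str_64(s: str) -> int:  # mirrors the helper added above
        return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")

    # Stable across processes, unlike hash("..."), and small enough for int64.
    assert hash_str_64("kernel source") == hash_str_64("kernel source")
    assert 0 <= hash_str_64("kernel source") < 2**56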
diff --git a/openequivariance/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp
index 3b2ce1e6..92aa6880 100644
--- a/openequivariance/openequivariance/extension/convolution.hpp
+++ b/openequivariance/openequivariance/extension/convolution.hpp
@@ -15,9 +15,10 @@ template <typename JIT_IMPL>
 class __attribute__ ((visibility ("default"))) JITConvImpl {
 public:
     JIT_IMPL jit;
-    KernelLaunchConfig forward_config;
-    KernelLaunchConfig backward_config;
-    KernelLaunchConfig double_backward_config;
+
+    KernelLaunchConfig forward_config_ref;
+    KernelLaunchConfig backward_config_ref;
+    KernelLaunchConfig double_backward_config_ref;
     int opt_level;
 
     JITConvImpl(
@@ -27,25 +28,25 @@
         KernelLaunchConfig double_backward_config_i,
         int opt_level_i) :
             jit(jit_kernel),
-            forward_config(forward_config_i),
-            backward_config(backward_config_i),
-            double_backward_config(double_backward_config_i),
+            forward_config_ref(forward_config_i),
+            backward_config_ref(backward_config_i),
+            double_backward_config_ref(double_backward_config_i),
             opt_level(opt_level_i) {
 
         vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward",
                 "double_backward_A", "double_backward_B", "fixup_double_backwardB"};
         jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);
 
-        if(forward_config.smem > 0) {
-            jit.set_max_smem(0, forward_config.smem);
-            jit.set_max_smem(4, forward_config.smem);
+        if(forward_config_ref.smem > 0) {
+            jit.set_max_smem(0, forward_config_ref.smem);
+            jit.set_max_smem(4, forward_config_ref.smem);
         }
 
-        if(backward_config.smem > 0) {
-            jit.set_max_smem(1, backward_config.smem);
+        if(backward_config_ref.smem > 0) {
+            jit.set_max_smem(1, backward_config_ref.smem);
         }
 
-        if(double_backward_config.smem > 0) {
-            jit.set_max_smem(5, double_backward_config.smem);
+        if(double_backward_config_ref.smem > 0) {
+            jit.set_max_smem(5, double_backward_config_ref.smem);
         }
     }
@@ -89,16 +90,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
         ConvData conv_data = {rows, cols, nnz, node_count};
         void *args[] = {&L1_in, &L2_in, &weights, &L3_out,
                 &conv_data, &workspace};
-        forward_config.hStream = stream;
-        jit.execute(0, args, forward_config);
+        jit.execute(0, args, with_stream(forward_config_ref, stream));
 
         if(reinterpret_cast<uint64_t>(workspace) != 0) {
             void *fixup_args[] = {&workspace, &L3_out};
 
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = forward_config.num_blocks;
-            fixup_config.num_threads = forward_config.num_threads;
-            fixup_config.smem = 0;
+            KernelLaunchConfig fixup_config(
+                forward_config_ref.num_blocks,
+                forward_config_ref.num_threads,
+                0
+            );
             fixup_config.hStream = stream;
 
             jit.execute(2, fixup_args, fixup_config);
@@ -118,16 +119,17 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
         ConvData conv_data = {rows, cols, nnz, node_count};
         void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad,
                 &L3_grad, &conv_data, &workspace, &transpose_perm};
-        backward_config.hStream = stream;
-        jit.execute(1, args, backward_config);
+        jit.execute(1, args, with_stream(backward_config_ref, stream));
 
         if(reinterpret_cast<uint64_t>(workspace) != 0) {
             void *fixup_args[] = {&workspace, &L1_grad};
 
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = backward_config.num_blocks;
-            fixup_config.num_threads = backward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                backward_config_ref.num_blocks,
+                backward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
 
             jit.execute(3, fixup_args, fixup_config);
         }
@@ -147,24 +149,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
             &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
             &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
         };
-        double_backward_config.hStream = stream;
-        jit.execute(4, args, forward_config);
+
+        jit.execute(4, args, with_stream(forward_config_ref, stream));
         if(reinterpret_cast<uint64_t>(wspace) != 0) {
             void *fixup_args[] = {&wspace, &L3_dgrad};
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = forward_config.num_blocks;
-            fixup_config.num_threads = forward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                forward_config_ref.num_blocks,
+                forward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
             jit.execute(2, fixup_args, fixup_config);
         }
 
-        jit.execute(5, args, double_backward_config);
+        jit.execute(5, args, with_stream(double_backward_config_ref, stream));
         if(reinterpret_cast<uint64_t>(wspace) != 0) {
             void *fixup_args[] = {&wspace, &L1_grad};
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = double_backward_config.num_blocks;
-            fixup_config.num_threads = double_backward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                double_backward_config_ref.num_blocks,
+                double_backward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
             jit.execute(6, fixup_args, fixup_config);
         }
     }
diff --git a/openequivariance/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp
index d9040866..ee8def66 100644
--- a/openequivariance/openequivariance/extension/tensorproducts.hpp
+++ b/openequivariance/openequivariance/extension/tensorproducts.hpp
@@ -10,7 +10,11 @@ template <typename JIT_IMPL>
 class __attribute__ ((visibility ("default"))) JITTPImpl {
 public:
     JIT_IMPL jit;
-    KernelLaunchConfig forward_config, backward_config, double_backward_config;
+
+    // Configs are suffixed with _ref because they need to be copied
+    // and modified with the stream; in-place modification is not
+    // possible due to concurrency requirements.
+    KernelLaunchConfig forward_config_ref, backward_config_ref, double_backward_config_ref;
     int opt_level;
 
     JITTPImpl(
@@ -20,25 +24,25 @@
         KernelLaunchConfig double_backward_config_i,
         int opt_level_i) :
             jit(jit_kernel),
-            forward_config(forward_config_i),
-            backward_config(backward_config_i),
-            double_backward_config(double_backward_config_i),
+            forward_config_ref(forward_config_i),
+            backward_config_ref(backward_config_i),
+            double_backward_config_ref(double_backward_config_i),
             opt_level(opt_level_i) {
 
         vector<string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};
 
        jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
 
-        if(forward_config.smem > 0) {
-            jit.set_max_smem(0, forward_config.smem);
-            jit.set_max_smem(2, forward_config.smem);
+        if(forward_config_ref.smem > 0) {
+            jit.set_max_smem(0, forward_config_ref.smem);
+            jit.set_max_smem(2, forward_config_ref.smem);
         }
 
-        if(backward_config.smem > 0) {
-            jit.set_max_smem(1, backward_config.smem);
+        if(backward_config_ref.smem > 0) {
+            jit.set_max_smem(1, backward_config_ref.smem);
         }
 
-        if(double_backward_config.smem > 0) {
-            jit.set_max_smem(3, double_backward_config.smem);
+        if(double_backward_config_ref.smem > 0) {
+            jit.set_max_smem(3, double_backward_config_ref.smem);
         }
     }
@@ -77,8 +81,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
             Stream stream) {
         void *args[] = { &num_products, &L1_in,
             &L2_in, &L3_out, &weights};
-        forward_config.hStream = stream;
-        jit.execute(0, args, forward_config);
+        jit.execute(0, args, with_stream(forward_config_ref, stream));
     }
 
     void backward(
@@ -88,8 +91,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
         void* weight, void* weight_grad, void* L3_grad, Stream stream) {
         void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad,
             &weight, &weight_grad, &L3_grad};
-        backward_config.hStream = stream;
-        jit.execute(1, args, backward_config);
+        jit.execute(1, args, with_stream(backward_config_ref, stream));
     }
 
     void double_backward(
@@ -102,9 +104,8 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
             &num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad,
             &w_dgrad, &L1_grad, &L2_grad, &W_grad, &L3_dgrad
         };
-        double_backward_config.hStream = stream;
-        jit.execute(2, args, forward_config);
-        jit.execute(3, args, double_backward_config);
+        jit.execute(2, args, with_stream(forward_config_ref, stream));
+        jit.execute(3, args, with_stream(double_backward_config_ref, stream));
     }
 
     ~JITTPImpl() = default;
diff --git a/openequivariance/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp
index 4c79faed..84a5e6ff 100644
--- a/openequivariance/openequivariance/extension/util/backend_cuda.hpp
+++ b/openequivariance/openequivariance/extension/util/backend_cuda.hpp
@@ -358,4 +358,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
         }
         NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
     }
-};
\ No newline at end of file
+};
+
+inline KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
+    KernelLaunchConfig new_config = config;
+    new_config.hStream = stream;
+    return new_config;
+}
\ No newline at end of file
diff --git a/openequivariance/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/util/backend_hip.hpp
index 5ee1a11a..ae5cc008 100644
--- a/openequivariance/openequivariance/extension/util/backend_hip.hpp
+++ b/openequivariance/openequivariance/extension/util/backend_hip.hpp
@@ -310,4 +310,10 @@ class __attribute__((visibility("default"))) HIPJITKernel {
     ~HIPJITKernel() {
         HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog));
     }
-};
\ No newline at end of file
+};
+
+inline KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
+    KernelLaunchConfig new_config = config;
+    new_config.hStream = stream;
+    return new_config;
+}
\ No newline at end of file
diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
index bf7a1445..a140674b 100644
--- a/openequivariance/openequivariance/jax/TensorProduct.py
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -3,6 +3,7 @@
 from openequivariance.jax import extlib
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+from openequivariance.core.utils import hash_str_64
 from openequivariance.jax.utils import reorder_jax
 from openequivariance.jax.jvp.tp_prim import tp_fwd_p
 import json
@@ -30,7 +31,7 @@ def __init__(self, problem: TPProblem):
                 "kernel_prop": self.kernelProp,
             }
         )
-        self.hash = self.kernel.__hash__()
+        self.hash = hash_str_64(self.kernel)
 
         self.weight_numel = problem.weight_numel
         self.L3_dim = self.config.irreps_out.dim
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 52102294..2be0fe7a 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -5,6 +5,8 @@
 from typing import Optional
 
 from openequivariance.jax import extlib
+
+from openequivariance.core.utils import hash_str_64
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollConv import LoopUnrollConv
 from openequivariance.jax.utils import reorder_jax
@@ -60,7 +62,7 @@ def __init__(
                 "kernel_prop": self.kernel_prop,
             }
         )
-        self.hash = self.kernel.__hash__()
+        self.hash = hash_str_64(self.kernel)
 
         self.weight_numel = config.weight_numel
         self.L3_dim = self.config.irreps_out.dim
diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py
index 9324b785..ca910147 100644
--- a/openequivariance/openequivariance/jax/jvp/conv_prim.py
+++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 from jax.extend import core
-from jax.interpreters import mlir, ad
+from jax.interpreters import mlir, ad, batching
 from openequivariance.jax.utils import clean_tensors
 
 # ==============================================================================
@@ -505,3 +505,120 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp):
 
 
 ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose
+
+# ==============================================================================
+# 14. Batching rules for all primitives
+# ==============================================================================
+
+
+def find_batch_size(vector_arg_values, batch_axes):
+    B = None
+    for arg, ba in zip(vector_arg_values, batch_axes):
+        if ba is not None:
+            if B is None:
+                B = arg.shape[ba]
+            else:
+                assert B == arg.shape[ba], "Batch size mismatch among arguments"
+    return B
+
+
+def flatten_args(vector_arg_values, batch_axes):
+    X = vector_arg_values[0]
+
+    if len(X.shape) == 2:
+        num_nodes = X.shape[0]
+    elif len(X.shape) == 3:
+        num_nodes = X.shape[1]
+    else:
+        raise ValueError("Unexpected input shape for X: {}".format(X.shape))
+
+    B = find_batch_size(vector_arg_values, batch_axes)
+
+    new_args = []
+    for i, (arg, batch_axis) in enumerate(zip(vector_arg_values, batch_axes)):
+        if i != len(vector_arg_values) - 2:
+            if batch_axis is None and arg is not None:
+                arg = jnp.broadcast_to(arg, (B,) + arg.shape)
+            elif batch_axis is not None and batch_axis != 0:
+                arg = jnp.moveaxis(arg, batch_axis, 0)
+        new_args.append(arg)
+
+    vector_arg_values = new_args
+
+    rows, cols, workspace, sender_perm = vector_arg_values[-4:]
+    rows_offset, cols_offset, sender_perm_offset = rows, cols, sender_perm
+    if B > 1:
+        batch_offsets = (jnp.arange(B) * num_nodes).astype(rows.dtype)
+        rows_offset = rows + batch_offsets[:, None]
+        cols_offset = cols + batch_offsets[:, None]
+
+        if sender_perm is not None:
+            sender_perm_offset = sender_perm + batch_offsets[:, None]
+
+    new_args = [arg.reshape(-1, *arg.shape[2:]) for arg in vector_arg_values[:-4]] + [
+        jnp.ravel(rows_offset),
+        jnp.ravel(cols_offset),
+        workspace,
+        jnp.ravel(sender_perm_offset),
+    ]
+
+    return new_args
+
+
+def unflatten_results(result, batch_size):
+    if not isinstance(result, tuple) and not isinstance(result, list):
+        return result.reshape(batch_size, -1, result.shape[-1]), 0
+    else:
+        return tuple(r.reshape(batch_size, -1, r.shape[-1]) for r in result), (
+            0 for _ in result
+        )
+
+
+def fwd_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash):
+    B = find_batch_size(vector_arg_values, batch_axes)
+    new_args = flatten_args(vector_arg_values, batch_axes)
+    result = conv_fwd_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash)
+    return unflatten_results(result, B)
+
+
+batching.primitive_batchers[conv_fwd_p] = fwd_batch
+
+
+def batch_bwd(vector_arg_values, batch_axes, *, kernel, hash):
+    B = find_batch_size(vector_arg_values, batch_axes)
+    new_args = flatten_args(vector_arg_values, batch_axes)
+    result = conv_bwd_p.bind(*new_args, kernel=kernel, hash=hash)
+    return unflatten_results(result, B)
+
+
+batching.primitive_batchers[conv_bwd_p] = batch_bwd
+
+
+def batch_dbwd(vector_arg_values, batch_axes, *, kernel, hash):
+    B = find_batch_size(vector_arg_values, batch_axes)
+    new_args = flatten_args(vector_arg_values, batch_axes)
+    result = conv_dbwd_p.bind(*new_args, kernel=kernel, hash=hash)
+    return unflatten_results(result, B)
+
+
+batching.primitive_batchers[conv_dbwd_p] = batch_dbwd
+
+
+def fwd_jvp_batch(vector_arg_values, batch_axes, L3_dim, kernel, hash):
+    B = find_batch_size(vector_arg_values, batch_axes)
+    new_args = flatten_args(vector_arg_values, batch_axes)
+    result = conv_fwd_jvp_p.bind(*new_args, L3_dim=L3_dim, kernel=kernel, hash=hash)
+    return unflatten_results(result, B)
+
+
+batching.primitive_batchers[conv_fwd_jvp_p] = fwd_jvp_batch
+
+
+def bwd_jvp_batch(vector_arg_values, batch_axes, *, kernel, hash):
+    B = find_batch_size(vector_arg_values, batch_axes)
+    new_args = flatten_args(vector_arg_values, batch_axes)
+    result = conv_bwd_jvp_p.bind(*new_args, kernel=kernel, hash=hash)
+    return unflatten_results(result, B)
+
+
+batching.primitive_batchers[conv_bwd_jvp_p] = bwd_jvp_batch
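The batching rules above make vmap work by folding a batch of B graphs into one large graph: the dense per-batch tensors are collapsed onto a single leading axis, the workspace argument is passed through untouched, and the edge arrays (rows, cols, and sender_perm when present) are shifted by b * num_nodes so each copy of the graph addresses its own block of nodes. A small self-contained sketch of that index-offset step, with made-up sizes and only jax.numpy:

    import jax.numpy as jnp

    # Two graphs with 3 nodes each become one graph with 6 nodes; the second
    # graph's edge indices are shifted by num_nodes, exactly as in flatten_args.
    B, num_nodes = 2, 3
    rows = jnp.array([[0, 1, 1, 2], [0, 1, 1, 2]], dtype=jnp.int32)  # per-graph rows
    batch_offsets = (jnp.arange(B) * num_nodes).astype(rows.dtype)   # [0, 3]
    rows_flat = jnp.ravel(rows + batch_offsets[:, None])             # [0 1 1 2 3 4 4 5]
    print(rows_flat)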
diff --git a/tests/vmap_test.py b/tests/vmap_test.py
new file mode 100644
index 00000000..e165cf25
--- /dev/null
+++ b/tests/vmap_test.py
@@ -0,0 +1,82 @@
+import pytest
+import os
+
+
+@pytest.fixture
+def with_jax(request):
+    return request.config.getoption("--jax")
+
+
+@pytest.fixture
+def ctx(with_jax):
+    if not with_jax:
+        pytest.skip("Skipping JAX tests")
+    os.environ["OEQ_NOTORCH"] = "1"
+    import openequivariance as oeq
+    import jax
+    import jax.numpy as jnp
+
+    key = jax.random.PRNGKey(42)
+    dim, n_nodes, n_nz = 10, 3, 4
+
+    problem = oeq.TPProblem(
+        oeq.Irreps("1x2e"),
+        oeq.Irreps("1x3e"),
+        oeq.Irreps("1x2e"),
+        [(0, 0, 0, "uvu", True)],
+        shared_weights=False,
+        internal_weights=False,
+    )
+    edge = jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=jnp.int32)
+
+    X = jax.random.uniform(
+        key, (dim, n_nodes, problem.irreps_in1.dim), dtype=jnp.float32
+    )
+    Y = jax.random.uniform(key, (dim, n_nz, problem.irreps_in2.dim), dtype=jnp.float32)
+    W = jax.random.uniform(key, (dim, n_nz, problem.weight_numel), dtype=jnp.float32)
+
+    return {
+        "X": X,
+        "Y": Y,
+        "W": W,
+        "r": edge[0],
+        "c": edge[1],
+        "conv": oeq.jax.TensorProductConv(problem, deterministic=False),
+        "jax": jax,
+        "jnp": jnp,
+        "dim": dim,
+    }
+
+
+def verify(ctx, in_axes, args):
+    jax, jnp = ctx["jax"], ctx["jnp"]
+    res_vmap = jax.vmap(ctx["conv"].forward, in_axes)(*args)
+
+    res_loop = []
+    for i in range(ctx["dim"]):
+        i_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]
+        res_loop.append(ctx["conv"].forward(*i_args))
+
+    assert jnp.allclose(res_vmap, jnp.stack(res_loop), atol=1e-5)
+
+
+def test_vmap_std(ctx):
+    verify(
+        ctx, (0, 0, 0, None, None), (ctx["X"], ctx["Y"], ctx["W"], ctx["r"], ctx["c"])
+    )
+
+
+def test_vmap_bcast_X(ctx):
+    verify(
+        ctx,
+        (None, 0, 0, None, None),
+        (ctx["X"][0], ctx["Y"], ctx["W"], ctx["r"], ctx["c"]),
+    )
+
+
+def test_vmap_bcast_XW(ctx):
+    verify(
+        ctx,
+        (None, 0, None, None, None),
+        (ctx["X"][0], ctx["Y"], ctx["W"][0], ctx["r"], ctx["c"]),
+    )
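For completeness, a minimal usage sketch distilled from tests/vmap_test.py; it is illustrative only, assumes a working OpenEquivariance JAX install with a GPU, and uses the same public API the test exercises:

    import os
    os.environ["OEQ_NOTORCH"] = "1"  # as in the test fixture

    import jax
    import jax.numpy as jnp
    import openequivariance as oeq

    problem = oeq.TPProblem(
        oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e"),
        [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False,
    )
    conv = oeq.jax.TensorProductConv(problem, deterministic=False)

    dim, n_nodes, n_nz = 10, 3, 4
    key = jax.random.PRNGKey(0)
    X = jax.random.uniform(key, (dim, n_nodes, problem.irreps_in1.dim), dtype=jnp.float32)
    Y = jax.random.uniform(key, (dim, n_nz, problem.irreps_in2.dim), dtype=jnp.float32)
    W = jax.random.uniform(key, (dim, n_nz, problem.weight_numel), dtype=jnp.float32)
    rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
    cols = jnp.array([1, 0, 2, 1], dtype=jnp.int32)

    # X, Y, W carry a leading batch axis; the edge list is shared across the batch.
    out = jax.vmap(conv.forward, in_axes=(0, 0, 0, None, None))(X, Y, W, rows, cols)
    # out has shape (dim, n_nodes, problem.irreps_out.dim)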