1 change: 1 addition & 0 deletions docs/tests_and_benchmarks.rst
@@ -42,6 +42,7 @@ To set up an editable install and run our tests, use the following code:
pytest --jax tests/example_test.py
pytest --jax tests/batch_test.py
pytest --jax tests/conv_test.py
pytest --jax tests/vmap_test.py

Browse the ``tests`` directory to run specific components.
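The hunk above adds ``tests/vmap_test.py`` to the JAX test commands. Note that ``--jax`` is not a built-in pytest flag; it is an option the test suite registers itself. The conftest.py sketch below is hypothetical wiring shown only to clarify where the flag comes from; the repository's real option handling may differ.

```python
# conftest.py (sketch) -- hypothetical wiring for the --jax flag used above.
import pytest


def pytest_addoption(parser):
    # Registers the custom --jax command-line flag for the test suite.
    parser.addoption(
        "--jax",
        action="store_true",
        default=False,
        help="run the tests against the JAX backend",
    )


@pytest.fixture
def use_jax(request):
    # Tests and fixtures can read the flag through the pytest config.
    return request.config.getoption("--jax")
```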

5 changes: 5 additions & 0 deletions openequivariance/openequivariance/core/utils.py
@@ -9,6 +9,7 @@
import tempfile

from enum import IntEnum
import hashlib


class DTypeEnum(IntEnum):
@@ -199,3 +200,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
time_millis[i] = kernel_time

return time_millis


def hash_str_64(s: str) -> int:
return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")
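The new ``hash_str_64`` helper hashes a string with SHA-256 and keeps the leading 7 bytes of the digest, so it returns a deterministic non-negative integer below 2**56 that fits in a signed 64-bit value. Unlike Python's built-in string hash, which is salted per interpreter process, the result is stable across runs; the JAX ``TensorProduct`` and ``TensorProductConv`` modules further down switch their kernel hash to this helper. A small usage sketch follows (the kernel string is made up for illustration):

```python
import hashlib


def hash_str_64(s: str) -> int:
    # Leading 7 bytes of the SHA-256 digest: deterministic across processes
    # and platforms, and always < 2**56, so it fits in a signed 64-bit int.
    return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")


kernel_src = 'extern "C" __global__ void forward() {}'  # hypothetical kernel string
h = hash_str_64(kernel_src)
assert 0 <= h < 2**56
assert h == hash_str_64(kernel_src)  # stable in every process, unlike built-in hash()
```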
78 changes: 42 additions & 36 deletions openequivariance/openequivariance/extension/convolution.hpp
@@ -15,9 +15,10 @@ template<typename JIT_IMPL>
class __attribute__ ((visibility ("default"))) JITConvImpl {
public:
JIT_IMPL jit;
KernelLaunchConfig forward_config;
KernelLaunchConfig backward_config;
KernelLaunchConfig double_backward_config;

KernelLaunchConfig forward_config_ref;
KernelLaunchConfig backward_config_ref;
KernelLaunchConfig double_backward_config_ref;
int opt_level;

JITConvImpl(
@@ -27,25 +28,25 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
KernelLaunchConfig double_backward_config_i,
int opt_level_i) :
jit(jit_kernel),
forward_config(forward_config_i),
backward_config(backward_config_i),
double_backward_config(double_backward_config_i),
forward_config_ref(forward_config_i),
backward_config_ref(backward_config_i),
double_backward_config_ref(double_backward_config_i),
opt_level(opt_level_i) {

vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"};
jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);

if(forward_config.smem > 0) {
jit.set_max_smem(0, forward_config.smem);
jit.set_max_smem(4, forward_config.smem);
if(forward_config_ref.smem > 0) {
jit.set_max_smem(0, forward_config_ref.smem);
jit.set_max_smem(4, forward_config_ref.smem);
}

if(backward_config.smem > 0) {
jit.set_max_smem(1, backward_config.smem);
if(backward_config_ref.smem > 0) {
jit.set_max_smem(1, backward_config_ref.smem);
}

if(double_backward_config.smem > 0) {
jit.set_max_smem(5, double_backward_config.smem);
if(double_backward_config_ref.smem > 0) {
jit.set_max_smem(5, double_backward_config_ref.smem);
}
}

@@ -89,16 +90,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
ConvData conv_data = {rows, cols, nnz, node_count};

void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace};
forward_config.hStream = stream;
jit.execute(0, args, forward_config);
jit.execute(0, args, with_stream(forward_config_ref, stream));

if(reinterpret_cast<uint64_t>(workspace) != 0) {
void *fixup_args[] = {&workspace, &L3_out};

KernelLaunchConfig fixup_config;
fixup_config.num_blocks = forward_config.num_blocks;
fixup_config.num_threads = forward_config.num_threads;
fixup_config.smem = 0;
KernelLaunchConfig fixup_config(
forward_config_ref.num_blocks,
forward_config_ref.num_threads,
0
);
fixup_config.hStream = stream;

jit.execute(2, fixup_args, fixup_config);
@@ -118,16 +119,17 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {

ConvData conv_data = {rows, cols, nnz, node_count};
void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm};
backward_config.hStream = stream;
jit.execute(1, args, backward_config);
jit.execute(1, args, with_stream(backward_config_ref, stream));

if(reinterpret_cast<uint64_t>(workspace) != 0) {
void *fixup_args[] = {&workspace, &L1_grad};

KernelLaunchConfig fixup_config;
fixup_config.num_blocks = backward_config.num_blocks;
fixup_config.num_threads = backward_config.num_threads;
fixup_config.smem = 0; fixup_config.hStream = stream;
KernelLaunchConfig fixup_config(
backward_config_ref.num_blocks,
backward_config_ref.num_threads,
0
);
fixup_config.hStream = stream;

jit.execute(3, fixup_args, fixup_config);
}
@@ -147,24 +149,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
&L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
&L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
};
double_backward_config.hStream = stream;
jit.execute(4, args, forward_config);

jit.execute(4, args, with_stream(forward_config_ref, stream));
if(reinterpret_cast<uint64_t>(wspace) != 0) {
void *fixup_args[] = {&wspace, &L3_dgrad};
KernelLaunchConfig fixup_config;
fixup_config.num_blocks = forward_config.num_blocks;
fixup_config.num_threads = forward_config.num_threads;
fixup_config.smem = 0; fixup_config.hStream = stream;
KernelLaunchConfig fixup_config(
forward_config_ref.num_blocks,
forward_config_ref.num_threads,
0
);
fixup_config.hStream = stream;
jit.execute(2, fixup_args, fixup_config);
}

jit.execute(5, args, double_backward_config);
jit.execute(5, args, with_stream(double_backward_config_ref, stream));
if(reinterpret_cast<uint64_t>(wspace) != 0) {
void *fixup_args[] = {&wspace, &L1_grad};
KernelLaunchConfig fixup_config;
fixup_config.num_blocks = double_backward_config.num_blocks;
fixup_config.num_threads = double_backward_config.num_threads;
fixup_config.smem = 0; fixup_config.hStream = stream;
KernelLaunchConfig fixup_config(
double_backward_config_ref.num_blocks,
double_backward_config_ref.num_threads,
0
);
fixup_config.hStream = stream;
jit.execute(6, fixup_args, fixup_config);
}
}
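In the convolution hunks above, each fixup launch now builds its ``KernelLaunchConfig`` through a three-argument constructor (num_blocks, num_threads, smem) instead of default-constructing the object and assigning the fields one by one; the stream is still assigned afterwards. That constructor is not part of this diff, so the sketch below is only an assumption about its shape, with the field names taken from the call sites:

```cpp
#include <cstdint>

// Stand-in for the backend stream handle (cudaStream_t / hipStream_t);
// the alias is an assumption made for this sketch only.
using Stream = void*;

// Minimal sketch of the launch-config shape implied by the call sites above.
// The defaults and exact integer types are assumptions, not taken from this PR.
struct KernelLaunchConfig {
    uint32_t num_blocks  = 0;
    uint32_t num_threads = 0;
    uint32_t smem        = 0;         // dynamic shared memory in bytes
    Stream   hStream     = nullptr;

    KernelLaunchConfig() = default;
    KernelLaunchConfig(uint32_t blocks, uint32_t threads, uint32_t smem_bytes)
        : num_blocks(blocks), num_threads(threads), smem(smem_bytes) {}
};
```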
38 changes: 20 additions & 18 deletions openequivariance/openequivariance/extension/tensorproducts.hpp
@@ -10,7 +10,11 @@ template<typename JIT_IMPL>
class __attribute__ ((visibility ("default"))) JITTPImpl {
public:
JIT_IMPL jit;
KernelLaunchConfig forward_config, backward_config, double_backward_config;

// Configs are suffixed with _ref because they
// need to be copied and modified with the stream. In-place
// modification not possible due to concurrency requirements.
KernelLaunchConfig forward_config_ref, backward_config_ref, double_backward_config_ref;
int opt_level;

JITTPImpl(
@@ -20,25 +24,25 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
KernelLaunchConfig double_backward_config_i,
int opt_level_i) :
jit(jit_kernel),
forward_config(forward_config_i),
backward_config(backward_config_i),
double_backward_config(double_backward_config_i),
forward_config_ref(forward_config_i),
backward_config_ref(backward_config_i),
double_backward_config_ref(double_backward_config_i),
opt_level(opt_level_i) {

vector<string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};
jit.compile(kernels, {{}, {}, {}, {}}, opt_level);

if(forward_config.smem > 0) {
jit.set_max_smem(0, forward_config.smem);
jit.set_max_smem(2, forward_config.smem);
if(forward_config_ref.smem > 0) {
jit.set_max_smem(0, forward_config_ref.smem);
jit.set_max_smem(2, forward_config_ref.smem);
}

if(backward_config.smem > 0) {
jit.set_max_smem(1, backward_config.smem);
if(backward_config_ref.smem > 0) {
jit.set_max_smem(1, backward_config_ref.smem);

}
if(double_backward_config.smem > 0) {
jit.set_max_smem(3, double_backward_config.smem);
if(double_backward_config_ref.smem > 0) {
jit.set_max_smem(3, double_backward_config_ref.smem);
}
}

@@ -77,8 +81,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
Stream stream) {

void *args[] = { &num_products, &L1_in, &L2_in, &L3_out, &weights};
forward_config.hStream = stream;
jit.execute(0, args, forward_config);
jit.execute(0, args, with_stream(forward_config_ref, stream));
}

void backward(
@@ -88,8 +91,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
void* weight, void* weight_grad,
void* L3_grad, Stream stream) {
void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad};
backward_config.hStream = stream;
jit.execute(1, args, backward_config);
jit.execute(1, args, with_stream(backward_config_ref, stream));
}

void double_backward(
@@ -102,9 +104,9 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
&num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
&L1_grad, &L2_grad, &W_grad, &L3_dgrad
};
double_backward_config.hStream = stream;
jit.execute(2, args, forward_config);
jit.execute(3, args, double_backward_config);
double_backward_config_ref.hStream = stream;
jit.execute(2, args, with_stream(forward_config_ref, stream));
jit.execute(3, args, with_stream(double_backward_config_ref, stream));
}

~JITTPImpl() = default;
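The comment added to ``JITTPImpl`` above explains the ``_ref`` suffix: the stored launch configs are shared templates that must not be mutated per call, because concurrent callers may launch on different streams. The ``with_stream`` helper introduced in the CUJITKernel and HIPJITKernel hunks below returns a per-call copy with only the stream filled in. A self-contained sketch of the pattern, using hypothetical stand-ins for the real types:

```cpp
#include <cstdint>

using Stream = void*;                     // stand-in stream handle (assumption)

struct KernelLaunchConfig {               // minimal stand-in; the real struct lives elsewhere
    uint32_t num_blocks = 0, num_threads = 0, smem = 0;
    Stream hStream = nullptr;
};

// Same shape as the helper added by this PR: copy the shared template and
// set the stream on the copy, so the *_config_ref members stay read-only
// after construction.
KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
    KernelLaunchConfig new_config = config;
    new_config.hStream = stream;
    return new_config;
}

// Before: every launch wrote the shared member, which races when two host
// threads submit work on different streams at the same time.
//     forward_config.hStream = stream;
//     jit.execute(0, args, forward_config);
//
// After: the shared config is copied per call and never mutated.
//     jit.execute(0, args, with_stream(forward_config_ref, stream));
```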
@@ -358,4 +358,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
}
NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
}
};
};

KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
KernelLaunchConfig new_config = config;
new_config.hStream = stream;
return new_config;
}
@@ -310,4 +310,10 @@ class __attribute__((visibility("default"))) HIPJITKernel {
~HIPJITKernel() {
HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog));
}
};
};

KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
KernelLaunchConfig new_config = config;
new_config.hStream = stream;
return new_config;
}
3 changes: 2 additions & 1 deletion openequivariance/openequivariance/jax/TensorProduct.py
@@ -3,6 +3,7 @@
from openequivariance.jax import extlib
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
from openequivariance.core.utils import hash_str_64
from openequivariance.jax.utils import reorder_jax
from openequivariance.jax.jvp.tp_prim import tp_fwd_p
import json
@@ -30,7 +31,7 @@ def __init__(self, problem: TPProblem):
"kernel_prop": self.kernelProp,
}
)
self.hash = self.kernel.__hash__()
self.hash = hash_str_64(self.kernel)

self.weight_numel = problem.weight_numel
self.L3_dim = self.config.irreps_out.dim
4 changes: 3 additions & 1 deletion openequivariance/openequivariance/jax/TensorProductConv.py
@@ -5,6 +5,8 @@
from typing import Optional
from openequivariance.jax import extlib


from openequivariance.core.utils import hash_str_64
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.core.LoopUnrollConv import LoopUnrollConv
from openequivariance.jax.utils import reorder_jax
@@ -60,7 +62,7 @@ def __init__(
"kernel_prop": self.kernel_prop,
}
)
self.hash = self.kernel.__hash__()
self.hash = hash_str_64(self.kernel)

self.weight_numel = config.weight_numel
self.L3_dim = self.config.irreps_out.dim