diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ce4adef..e11e074f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,8 @@ option(WITH_TORCH "Enable PyTorch C++ backend" OFF)
 
 option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF)
 
+option(WITH_TRITON "Enable Triton-generated kernels" OFF)
+
 # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
 # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
 # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
@@ -302,6 +304,14 @@ if(WITH_NINETOOTHED)
     set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation")
 endif()
 
+if(WITH_TRITON AND NOT WITH_NVIDIA)
+    message(FATAL_ERROR "`WITH_TRITON` temporarily requires `WITH_NVIDIA=ON` because Triton AOT temporarily targets CUDA.")
+endif()
+
+if(WITH_TRITON)
+    set(TRITON_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run Triton AOT code generation")
+endif()
+
 if(WITH_NVIDIA)
     add_compile_definitions(WITH_NVIDIA=1)
     enable_language(CUDA)
diff --git a/scripts/generate_triton_ops.py b/scripts/generate_triton_ops.py
new file mode 100644
index 00000000..917981b6
--- /dev/null
+++ b/scripts/generate_triton_ops.py
@@ -0,0 +1,122 @@
+"""Generate Triton AOT operator sources and a CMake manifest for InfiniOps."""
+
+import argparse
+import importlib.util
+import pathlib
+import shutil
+import sys
+
+_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1]
+_OPS_DIR = _PROJECT_DIR / "src" / "triton" / "ops"
+
+
+def _find_op_modules():
+    """Map each op name to the `build.py` found in `src/triton/ops/*/`."""
+    return {
+        path.parent.name: path
+        for path in sorted(_OPS_DIR.glob("*/build.py"))
+        if path.is_file()
+    }
+
+
+def _build_manifest(output_dir):
+    """Return the sorted paths of every `.c` file below `output_dir`."""
+    return sorted(str(path) for path in pathlib.Path(output_dir).rglob("*.c"))
+
+
+def _write_cmake_manifest(output_dir, sources):
+    """Write `manifest.cmake` listing the generated sources and include dir."""
+    output_dir = pathlib.Path(output_dir)
+    # CMake treats `\` inside a quoted argument as an escape character, so
+    # always emit forward-slash paths.
+    lines = ["set(INFINIOPS_TRITON_SOURCES"]
+    lines.extend(f'    "{pathlib.Path(source).as_posix()}"' for source in sources)
+    lines.append(")")
+    lines.append("")
+    lines.append(f'set(INFINIOPS_TRITON_INCLUDE_DIRS "{output_dir.as_posix()}")')
+    lines.append("")
+    (output_dir / "manifest.cmake").write_text(
+        "\n".join(lines) + "\n", encoding="utf-8"
+    )
+
+
+def _load_op_module(op):
+    """Import and return the `build.py` module of `op`."""
+    path = _find_op_modules()[op]
+    op_dir = str(path.parent)
+
+    # Let `build.py` import helpers that sit next to it, without stacking a
+    # duplicate `sys.path` entry on every call.
+    if op_dir not in sys.path:
+        sys.path.insert(0, op_dir)
+
+    # Every op's script is named `build.py`; register each one under its own
+    # name so they do not overwrite each other in `sys.modules`.
+    spec = importlib.util.spec_from_file_location(
+        f"_infiniops_triton_{op}_{path.stem}", path
+    )
+
+    if spec is None or spec.loader is None:
+        raise ImportError(f"cannot load Triton op module from `{path}`")
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)
+
+    return module
+
+
+def generate(ops, *, output_dir):
+    """Build `ops` into `output_dir` and write the CMake manifest there.
+
+    Raises `ValueError` for an op without a `build.py`. Returns the sorted
+    list of generated `.c` source paths.
+    """
+    op_modules = _find_op_modules()
+    # Drop repeated names so that no op is built twice.
+    ops = tuple(dict.fromkeys(ops))
+    unknown_ops = tuple(op for op in ops if op not in op_modules)
+
+    if unknown_ops:
+        raise ValueError(f"unsupported Triton ops: {', '.join(unknown_ops)}")
+
+    # CMake resolves relative source paths against its own source directory,
+    # so hand it absolute paths.
+    output_dir = pathlib.Path(output_dir).resolve()
+
+    # Start from an empty tree and fail loudly when it cannot be cleared:
+    # leftover `.c` files would otherwise end up in the manifest.
+    if output_dir.exists():
+        shutil.rmtree(output_dir)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for op in ops:
+        module = _load_op_module(op)
+        module.build(output_dir)
+
+    sources = _build_manifest(output_dir)
+    _write_cmake_manifest(output_dir, sources)
+
+    return sources
+
+
+def _parse_args():
+    """Parse the command line (`--output-dir`, optional `--ops`)."""
+    parser = argparse.ArgumentParser(
+        description="Generate Triton operator sources for InfiniOps."
+    )
+    parser.add_argument("--output-dir", required=True)
+    parser.add_argument("--ops", nargs="+", default=tuple(_find_op_modules()))
+
+    return parser.parse_args()
+
+
+def main():
+    """Command-line entry point."""
+    args = _parse_args()
+    generate(args.ops, output_dir=args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py
index e5ebbc0c..832212af 100644
--- a/scripts/generate_wrappers.py
+++ b/scripts/generate_wrappers.py
@@ -1154,13 +1154,15 @@ def _index_impl_headers(impl_roots, scan_dirs):
     return by_operator
 
 
-def _get_all_ops(devices, with_torch=False, with_ninetoothed=False):
+def _get_all_ops(devices, with_torch=False, with_ninetoothed=False, with_triton=False):
     scan_dirs = set(devices)
 
     if with_torch:
         scan_dirs.add("torch")
     if with_ninetoothed:
         scan_dirs.add("ninetoothed")
+    if with_triton:
+        scan_dirs.add("triton")
 
     ops = {}
 
@@ -1287,6 +1289,12 @@ def _dispatch_gen_batch_size():
         help="Include NineToothed backend implementations.",
     )
 
+    parser.add_argument(
+        "--with-triton",
+        action="store_true",
+        help="Include Triton backend implementations.",
+    )
+
     args = parser.parse_args()
 
     # Wipe previous outputs so files for ops that have since been removed
@@ -1307,6 +1315,7 @@ def _dispatch_gen_batch_size():
         args.devices,
         with_torch=args.with_torch,
         with_ninetoothed=args.with_ninetoothed,
+        with_triton=args.with_triton,
     )
 
     bind_func_names = []
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d3784e68..1006b9fb 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -84,6 +84,41 @@ if(WITH_NINETOOTHED)
     target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
 endif()
 
+if(WITH_TRITON)
+    find_package(Python COMPONENTS Interpreter REQUIRED)
+
+    if(TRITON_PYTHON_EXECUTABLE)
+        set(_triton_python "${TRITON_PYTHON_EXECUTABLE}")
+    elseif(_TORCH_PYTHON)
+        set(_triton_python "${_TORCH_PYTHON}")
+    else()
+        set(_triton_python "${Python_EXECUTABLE}")
+    endif()
+    message(STATUS "Triton codegen Python: ${_triton_python}")
+
+    set(_triton_output_dir "${CMAKE_CURRENT_BINARY_DIR}/triton")
+    set(_triton_generator_args
+        "${PROJECT_SOURCE_DIR}/scripts/generate_triton_ops.py"
+        --output-dir "${_triton_output_dir}")
+
+    execute_process(
+        COMMAND "${_triton_python}" ${_triton_generator_args}
+        WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
+        RESULT_VARIABLE _triton_generation_result
+    )
+
+    if(NOT _triton_generation_result EQUAL 0)
+        message(FATAL_ERROR "Generating Triton AOT operator sources failed with `${_triton_python}`. Set `TRITON_PYTHON_EXECUTABLE` to a Python with `triton` and CUDA dependencies installed.")
+    endif()
+
+    enable_language(C)
+
+    include("${_triton_output_dir}/manifest.cmake")
+    target_compile_definitions(infiniops PUBLIC WITH_TRITON=1)
+    target_include_directories(infiniops PRIVATE ${INFINIOPS_TRITON_INCLUDE_DIRS})
+    target_sources(infiniops PRIVATE ${INFINIOPS_TRITON_SOURCES})
+endif()
+
 if(WITH_ILUVATAR)
     set(ILUVATAR_PATTERNS
         "native/cuda/*.cc"
@@ -525,6 +560,10 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS)
         list(APPEND GENERATOR_ARGS --with-ninetoothed)
     endif()
 
+    if(WITH_TRITON)
+        list(APPEND GENERATOR_ARGS --with-triton)
+    endif()
+
     execute_process(
         COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
         WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
@@ -764,6 +803,12 @@ if(GENERATE_PYTHON_BINDINGS)
         target_include_directories(ops PRIVATE
             ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
     endif()
+
+    if(WITH_TRITON)
+        target_include_directories(ops PRIVATE
+            ${INFINIOPS_TRITON_INCLUDE_DIRS})
+    endif()
+
     target_link_libraries(ops PRIVATE infiniops)
 
     # Cambricon generated dispatch is compiled into the Python extension and
diff --git a/src/triton/ops/add/add.h b/src/triton/ops/add/add.h
new file mode 100644
index 00000000..70b0525c
--- /dev/null
+++ b/src/triton/ops/add/add.h
@@ -0,0 +1,208 @@
+#ifndef INFINI_OPS_TRITON_ADD_H_
+#define INFINI_OPS_TRITON_ADD_H_
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "base/add.h"
+#include "data_type.h"
+
+extern "C" {
+#include "add/infini_ops_triton_add_bf16.h"
+#include "add/infini_ops_triton_add_fp16.h"
+#include "add/infini_ops_triton_add_fp32.h"
+#include "add/infini_ops_triton_add_fp64.h"
+#include "add/infini_ops_triton_add_i16.h"
+#include "add/infini_ops_triton_add_i32.h"
+#include "add/infini_ops_triton_add_i64.h"
+#include "add/infini_ops_triton_add_i8.h"
+#include "add/infini_ops_triton_add_u16.h"
+#include "add/infini_ops_triton_add_u32.h"
+#include "add/infini_ops_triton_add_u64.h"
+#include "add/infini_ops_triton_add_u8.h"
+}
+
+namespace infini::ops {
+
+template <>
+class Operator : public Add {  // NOTE(review): no template argument list is visible here or on the casts below - confirm against the real patch
+ public:
+  using Add::operator();
+
+  Operator(const Tensor input, const Tensor other, Tensor out)
+      : Add{input, other, out} {
+    const int ndim = static_cast(ndim_);
+    std::vector h_metadata(4 * std::max(ndim, 1), 0);  // four runs of `ndim` entries: out shape, input/other/out strides
+    for (int i = 0; i < ndim; ++i) {
+      h_metadata[0 * ndim + i] = static_cast(out_shape_[i]);
+      h_metadata[1 * ndim + i] = static_cast(input_strides_[i]);
+      h_metadata[2 * ndim + i] = static_cast(other_strides_[i]);
+      h_metadata[3 * ndim + i] = static_cast(out_strides_[i]);
+    }
+
+    const size_t bytes = h_metadata.size() * sizeof(int64_t);
+    cuMemAlloc(&d_metadata_, bytes);  // NOTE(review): the returned status is not checked
+    cuMemcpyHtoD(d_metadata_, h_metadata.data(), bytes);  // NOTE(review): the returned status is not checked
+
+    const size_t stride_bytes = ndim * sizeof(int64_t);  // byte size of one run inside the device buffer
+    d_out_shape_ = d_metadata_ + 0 * stride_bytes;
+    d_input_stride_ = d_metadata_ + 1 * stride_bytes;
+    d_other_stride_ = d_metadata_ + 2 * stride_bytes;
+    d_out_stride_ = d_metadata_ + 3 * stride_bytes;
+  }
+
+  ~Operator() {
+    if (d_metadata_) {
+      cuMemFree(d_metadata_);
+    }
+  }
+
+  void operator()(const Tensor input, const Tensor other,
+                  Tensor out) const override {
+    EnsureLoaded(out.dtype());
+
+    CUstream stream = static_cast(stream_);
+    auto x = reinterpret_cast(const_cast(input.data()));
+    auto y = reinterpret_cast(const_cast(other.data()));
+    auto o = reinterpret_cast(out.data());
+
+    int32_t n = static_cast(out.numel());  // NOTE(review): held in 32 bits - confirm `numel()` always fits
+    int32_t ndim_val = static_cast(ndim_);
+    int32_t x_contig = static_cast(is_input_contiguous_);
+    int32_t y_contig = static_cast(is_other_contiguous_);
+    int32_t out_contig = static_cast(is_out_contiguous_);
+
+    CUresult rc = CUDA_ERROR_INVALID_VALUE;  // stays an error value when no `case` below assigns it
+    switch (out.dtype()) {
+      case DataType::kFloat16:
+        rc = infini_ops_triton_add_fp16_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kBFloat16:
+        rc = infini_ops_triton_add_bf16_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kFloat32:
+        rc = infini_ops_triton_add_fp32_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kFloat64:
+        rc = infini_ops_triton_add_fp64_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kInt8:
+        rc = infini_ops_triton_add_i8_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kInt16:
+        rc = infini_ops_triton_add_i16_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kInt32:
+        rc = infini_ops_triton_add_i32_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kInt64:
+        rc = infini_ops_triton_add_i64_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kUInt8:
+        rc = infini_ops_triton_add_u8_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kUInt16:
+        rc = infini_ops_triton_add_u16_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kUInt32:
+        rc = infini_ops_triton_add_u32_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      case DataType::kUInt64:
+        rc = infini_ops_triton_add_u64_default(
+            stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_,
+            d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n);
+        break;
+      default:
+        assert(false && "triton `Add` does not support this dtype");
+    }
+
+    assert(rc == CUDA_SUCCESS && "Triton `Add` launch failed");  // NOTE(review): `assert` is compiled out when NDEBUG is defined
+  }
+
+ private:
+  CUdeviceptr d_metadata_{0};  // single device allocation; the four pointers below are offsets into it
+
+  CUdeviceptr d_out_shape_{0};
+
+  CUdeviceptr d_input_stride_{0};
+
+  CUdeviceptr d_other_stride_{0};
+
+  CUdeviceptr d_out_stride_{0};
+
+  static void EnsureLoaded(DataType dtype) {
+    static std::once_flag fp16, bf16, fp32, fp64, i8, i16, i32, i64, u8, u16,
+        u32, u64;  // one flag per dtype, so each loader below is invoked at most once
+    switch (dtype) {
+      case DataType::kFloat16:
+        std::call_once(fp16, &load_infini_ops_triton_add_fp16);
+        break;
+      case DataType::kBFloat16:
+        std::call_once(bf16, &load_infini_ops_triton_add_bf16);
+        break;
+      case DataType::kFloat32:
+        std::call_once(fp32, &load_infini_ops_triton_add_fp32);
+        break;
+      case DataType::kFloat64:
+        std::call_once(fp64, &load_infini_ops_triton_add_fp64);
+        break;
+      case DataType::kInt8:
+        std::call_once(i8, &load_infini_ops_triton_add_i8);
+        break;
+      case DataType::kInt16:
+        std::call_once(i16, &load_infini_ops_triton_add_i16);
+        break;
+      case DataType::kInt32:
+        std::call_once(i32, &load_infini_ops_triton_add_i32);
+        break;
+      case DataType::kInt64:
+        std::call_once(i64, &load_infini_ops_triton_add_i64);
+        break;
+      case DataType::kUInt8:
+        std::call_once(u8, &load_infini_ops_triton_add_u8);
+        break;
+      case DataType::kUInt16:
+        std::call_once(u16, &load_infini_ops_triton_add_u16);
+        break;
+      case DataType::kUInt32:
+        std::call_once(u32, &load_infini_ops_triton_add_u32);
+        break;
+      case DataType::kUInt64:
+        std::call_once(u64, &load_infini_ops_triton_add_u64);
+        break;
+      default:
+        break;  // unsupported dtypes load nothing; `operator()` asserts on them
+    }
+  }
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/triton/ops/add/add.py b/src/triton/ops/add/add.py
new file mode 100644
index 00000000..eb3e3572
--- /dev/null
+++ b/src/triton/ops/add/add.py
@@ -0,0 +1,52 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def kernel(
+    x_ptr,
+    y_ptr,
+    out_ptr,
+    out_shape_ptr,
+    x_stride_ptr,
+    y_stride_ptr,
+    out_stride_ptr,
+    x_contig,
+    y_contig,
+    out_contig,
+    ndim,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offsets = (pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64)  # flat element indices of this program
+    mask = offsets < n_elements  # disables lanes past the last element
+
+    if (x_contig != 0) and (y_contig != 0) and (out_contig != 0):  # all three flagged contiguous: index flat
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        tl.store(out_ptr + offsets, x + y, mask=mask)
+    else:
+        x_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
+        y_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
+        out_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
+        tmp = offsets
+
+        for i in range(ndim):  # visits dimensions from the last to the first
+            s = tl.load(out_shape_ptr + (ndim - 1 - i))
+            d = tmp % s  # index along dimension `ndim - 1 - i`
+            tmp = tmp // s
+            x_offs += d * tl.load(x_stride_ptr + (ndim - 1 - i))
+            y_offs += d * tl.load(y_stride_ptr + (ndim - 1 - i))
+            out_offs += d * tl.load(out_stride_ptr + (ndim - 1 - i))
+
+        if x_contig != 0:  # an operand flagged contiguous keeps the flat offsets
+            x_offs = offsets
+        if y_contig != 0:
+            y_offs = offsets
+        if out_contig != 0:
+            out_offs = offsets
+
+        x = tl.load(x_ptr + x_offs, mask=mask)
+        y = tl.load(y_ptr + y_offs, mask=mask)
+        tl.store(out_ptr + out_offs, x + y, mask=mask)
diff --git a/src/triton/ops/add/build.py
b/src/triton/ops/add/build.py
new file mode 100644
index 00000000..28f77872
--- /dev/null
+++ b/src/triton/ops/add/build.py
@@ -0,0 +1,114 @@
+"""Ahead-of-time build script for the Triton `add` kernels."""
+
+import pathlib
+
+from triton.tools.compile import CompileArgs, compile_kernel
+from triton.tools import link
+
+_KERNEL_PATH = pathlib.Path(__file__).parent / "add.py"
+_KERNEL_NAME = "kernel"
+_DTYPES = (
+    "fp16",
+    "bf16",
+    "fp32",
+    "fp64",
+    "i8",
+    "i16",
+    "i32",
+    "i64",
+    "u8",
+    "u16",
+    "u32",
+    "u64",
+)
+_BLOCK_SIZES = (512, 1024)
+_NUM_WARPS = 4
+_NUM_STAGES = 3
+# Shape/stride pointers followed by the three contiguity flags, `ndim` and
+# `n_elements`; identical for every variant.
+_METADATA_SIGNATURE = "*i64, *i64, *i64, *i64, i32, i32, i32, i32, i32"
+
+
+def _compile_variants(variant_dir, dtype):
+    """Compile every `add` variant of `dtype` and return `(headers, out_name)`.
+
+    For each block size, a variant with `:16`-annotated data pointers is
+    compiled first, followed by one with plain data pointers.
+    """
+    out_name = f"infini_ops_triton_add_{dtype}"
+    pointer_signatures = (
+        ", ".join([f"*{dtype}:16"] * 3),
+        ", ".join([f"*{dtype}"] * 3),
+    )
+    headers = []
+
+    for block_size in _BLOCK_SIZES:
+        for pointer_signature in pointer_signatures:
+            _, files = compile_kernel(
+                CompileArgs(
+                    path=str(_KERNEL_PATH),
+                    kernel_name=_KERNEL_NAME,
+                    signature=(
+                        f"{pointer_signature}, {_METADATA_SIGNATURE}, {block_size}"
+                    ),
+                    grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1",
+                    num_warps=_NUM_WARPS,
+                    num_stages=_NUM_STAGES,
+                    out_name=out_name,
+                    out_path=variant_dir / out_name,
+                    target=None,
+                )
+            )
+            headers.extend(f for f in files if f.suffix == ".h")
+
+    return headers, out_name
+
+
+def _link_one_dtype(variant_dir, headers, out_name):
+    """Link the compiled variants into one `.h`/`.c` pair named `out_name`."""
+    parser = link.HeaderParser()
+
+    for h in headers:
+        parser.extract_linker_meta(h.read_text())
+
+    if not parser.kernels:
+        raise RuntimeError(f"no Triton kernels were compiled for `{out_name}`")
+
+    out_base = variant_dir / out_name
+    first_meta = next(iter(parser.kernels.values()))[0]
+    backend_prelude = (
+        pathlib.Path(link.__file__).parent / "extra" / parser.backend_name / "link.h"
+    ).read_text()
+
+    algo_decls = [link.make_algo_decls(name, m) for name, m in parser.kernels.items()]
+    out_base.with_suffix(".h").write_text(
+        backend_prelude
+        + "\n".join(algo_decls)
+        + "\n"
+        + link.make_get_num_algos_decl(first_meta)
+        + "\n"
+        + link.make_global_decl(first_meta)
+    )
+    defs = [
+        link.make_kernel_hints_dispatcher(name, m) for name, m in parser.kernels.items()
+    ]
+    names = list(parser.kernels.keys())
+    src = backend_prelude
+    src += "#include <stdint.h>\n#include <assert.h>\n\n"
+    src += "\n".join(defs) + "\n"
+    src += link.make_func_pointers(names, first_meta) + "\n"
+    src += link.make_get_num_algos_def(first_meta) + "\n"
+    src += link.make_kernel_meta_const_dispatcher(first_meta) + "\n"
+    src += link.make_kernel_load_def(names, first_meta) + "\n"
+    src += link.make_default_algo_kernel(first_meta)
+    out_base.with_suffix(".c").write_text(src)
+
+
+def build(output_dir: pathlib.Path):
+    """Generate the linked `add` sources for every dtype under `output_dir/add`."""
+    variant_dir = output_dir / "add"
+    variant_dir.mkdir(parents=True, exist_ok=True)
+
+    for dtype in _DTYPES:
+        headers, out_name = _compile_variants(variant_dir, dtype)
+        _link_one_dtype(variant_dir, headers, out_name)