Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF)

# Opt-in switch for the Triton backend (OFF by default). Configuration fails
# later in this file if this is ON while `WITH_NVIDIA` is not.
option(WITH_TRITON "Enable Triton-generated kernels" OFF)

# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
Expand Down Expand Up @@ -302,6 +304,14 @@ if(WITH_NINETOOTHED)
set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation")
endif()

# Triton backend configuration: validate the backend combination first, then
# declare the interpreter override used for Triton AOT code generation.
if(WITH_TRITON)
  if(NOT WITH_NVIDIA)
    message(FATAL_ERROR "`WITH_TRITON` temporarily requires `WITH_NVIDIA=ON` because Triton AOT temporarily targets CUDA.")
  endif()

  # Empty by default; users may point this at a specific Python executable.
  set(TRITON_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run Triton AOT code generation")
endif()

if(WITH_NVIDIA)
add_compile_definitions(WITH_NVIDIA=1)
enable_language(CUDA)
Expand Down
83 changes: 83 additions & 0 deletions scripts/generate_triton_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import argparse
import importlib.util
import pathlib
import shutil
import sys

# Two levels above this script's resolved path, i.e. the parent of the
# directory that contains this file.
_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1]
# Directory whose immediate subdirectories are searched for `build.py` scripts.
_OPS_DIR = _PROJECT_DIR / "src" / "triton" / "ops"


def _find_op_modules():
    """Map operator names to their `build.py` scripts.

    Every `<name>/build.py` file directly under `_OPS_DIR` contributes one
    entry keyed by `<name>`. Entries are inserted in sorted path order.

    Returns:
        dict: Operator directory name -> `pathlib.Path` of its `build.py`.
    """
    discovered = {}

    for build_script in sorted(_OPS_DIR.glob("*/build.py")):
        # Skip anything named `build.py` that is not a regular file.
        if not build_script.is_file():
            continue

        discovered[build_script.parent.name] = build_script

    return discovered


def _build_manifest(output_dir):
return sorted(str(path) for path in pathlib.Path(output_dir).rglob("*.c"))


def _write_cmake_manifest(output_dir, sources):
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
lines = ["set(INFINIOPS_TRITON_SOURCES"]
lines.extend(f' "{source}"' for source in sources)
lines.append(")")
lines.append("")
lines.append(f'set(INFINIOPS_TRITON_INCLUDE_DIRS "{output_dir}")')
lines.append("")
manifest_path.write_text("\n".join(lines) + "\n")


def _load_op_module(op):
    """Import and return the `build.py` module of operator `op`.

    Each module is registered in `sys.modules` under a name unique to the
    operator. All operator scripts share the file stem `build`, so using the
    stem as the module name would make every operator overwrite the previous
    `sys.modules["build"]` entry and shadow any installed `build` package.

    Args:
        op: Operator name, i.e. a key of `_find_op_modules()`.

    Returns:
        The executed module object.

    Raises:
        KeyError: If `op` has no `build.py` script.
        ImportError: If no import spec or loader can be created for the script.
    """
    path = _find_op_modules()[op]
    # Put the operator directory first on the import path; presumably this lets
    # `build.py` import files that sit next to it (confirm against the op scripts).
    sys.path.insert(0, str(path.parent))

    module_name = f"_infiniops_triton_build_{op}"
    spec = importlib.util.spec_from_file_location(module_name, path)

    # Validate before use: `module_from_spec` would otherwise fail with an
    # unrelated error, and an `assert` disappears under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load Triton op module from {path}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module

    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a partially initialised module registered.
        sys.modules.pop(spec.name, None)
        raise

    return module


def generate(ops, *, output_dir):
    """Generate Triton sources for `ops` and write the CMake manifest.

    `output_dir` is deleted and recreated before generation, so its previous
    contents are lost. Each operator's module is loaded and its
    `build(output_dir)` function is called in the order given by `ops`.

    Args:
        ops: Operator names to build; each must have a `build.py` script.
        output_dir: Destination directory for generated files and the manifest.

    Returns:
        list[str]: Sorted paths of the `*.c` files present after generation.

    Raises:
        ValueError: If any requested operator is unknown.
    """
    available = _find_op_modules()
    unknown_ops = tuple(name for name in ops if name not in available)

    if len(unknown_ops) > 0:
        raise ValueError(f"unsupported Triton ops: {', '.join(unknown_ops)}")

    destination = pathlib.Path(output_dir)
    # Start from an empty directory so stale outputs never reach the manifest.
    shutil.rmtree(destination, ignore_errors=True)
    destination.mkdir(parents=True, exist_ok=True)

    for name in ops:
        op_module = _load_op_module(name)
        op_module.build(destination)

    generated = _build_manifest(destination)
    _write_cmake_manifest(destination, generated)

    return generated


def _parse_args():
    """Parse the command-line arguments of this script.

    `--output-dir` is mandatory. `--ops` takes one or more operator names and
    defaults to every operator returned by `_find_op_modules()`.

    Returns:
        argparse.Namespace: With attributes `output_dir` and `ops`.
    """
    every_op = tuple(_find_op_modules())

    cli = argparse.ArgumentParser(
        description="Generate Triton operator sources for InfiniOps."
    )
    cli.add_argument("--output-dir", required=True)
    cli.add_argument("--ops", nargs="+", default=every_op)

    return cli.parse_args()


def main():
    """Command-line entry point: parse arguments, then run the generator."""
    options = _parse_args()
    generate(options.ops, output_dir=options.output_dir)


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
11 changes: 10 additions & 1 deletion scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,13 +1154,15 @@ def _index_impl_headers(impl_roots, scan_dirs):
return by_operator


def _get_all_ops(devices, with_torch=False, with_ninetoothed=False):
def _get_all_ops(devices, with_torch=False, with_ninetoothed=False, with_triton=False):
scan_dirs = set(devices)

if with_torch:
scan_dirs.add("torch")
if with_ninetoothed:
scan_dirs.add("ninetoothed")
if with_triton:
scan_dirs.add("triton")

ops = {}

Expand Down Expand Up @@ -1287,6 +1289,12 @@ def _dispatch_gen_batch_size():
help="Include NineToothed backend implementations.",
)

parser.add_argument(
"--with-triton",
action="store_true",
help="Include Triton backend implementations.",
)

args = parser.parse_args()

# Wipe previous outputs so files for ops that have since been removed
Expand All @@ -1307,6 +1315,7 @@ def _dispatch_gen_batch_size():
args.devices,
with_torch=args.with_torch,
with_ninetoothed=args.with_ninetoothed,
with_triton=args.with_triton,
)

bind_func_names = []
Expand Down
45 changes: 45 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,41 @@ if(WITH_NINETOOTHED)
target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
endif()

# Triton backend: run the AOT generator at configure time, then compile the
# generated C sources into `infiniops`.
if(WITH_TRITON)
  find_package(Python COMPONENTS Interpreter REQUIRED)

  # Interpreter precedence: explicit `TRITON_PYTHON_EXECUTABLE`, then
  # `_TORCH_PYTHON` when it is set, then the interpreter found above.
  if(TRITON_PYTHON_EXECUTABLE)
    set(_triton_python "${TRITON_PYTHON_EXECUTABLE}")
  elseif(_TORCH_PYTHON)
    set(_triton_python "${_TORCH_PYTHON}")
  else()
    set(_triton_python "${Python_EXECUTABLE}")
  endif()
  message(STATUS "Triton codegen Python: ${_triton_python}")

  set(_triton_generator_script
    "${PROJECT_SOURCE_DIR}/scripts/generate_triton_ops.py")
  set(_triton_output_dir "${CMAKE_CURRENT_BINARY_DIR}/triton")
  set(_triton_generator_args
    "${_triton_generator_script}"
    --output-dir "${_triton_output_dir}")

  # The sources are produced during configuration, so make the build system
  # re-run CMake when the generator script changes; otherwise edits to it
  # leave stale generated sources until someone reconfigures by hand.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
    "${_triton_generator_script}")

  execute_process(
    COMMAND "${_triton_python}" ${_triton_generator_args}
    WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
    RESULT_VARIABLE _triton_generation_result
  )

  if(NOT _triton_generation_result EQUAL 0)
    message(FATAL_ERROR "Generating Triton AOT operator sources failed with `${_triton_python}`. Set `TRITON_PYTHON_EXECUTABLE` to a Python with `triton` and CUDA dependencies installed.")
  endif()

  # The generator emits `*.c` files, so the C language must be enabled.
  # NOTE(review): CMake recommends enabling languages in the highest directory
  # common to all targets that use them - confirm this location is sufficient.
  enable_language(C)

  # The manifest written by the generator defines `INFINIOPS_TRITON_SOURCES`
  # and `INFINIOPS_TRITON_INCLUDE_DIRS`.
  include("${_triton_output_dir}/manifest.cmake")
  target_compile_definitions(infiniops PUBLIC WITH_TRITON=1)
  target_include_directories(infiniops PRIVATE ${INFINIOPS_TRITON_INCLUDE_DIRS})
  target_sources(infiniops PRIVATE ${INFINIOPS_TRITON_SOURCES})
endif()

if(WITH_ILUVATAR)
set(ILUVATAR_PATTERNS
"native/cuda/*.cc"
Expand Down Expand Up @@ -525,6 +560,10 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS)
list(APPEND GENERATOR_ARGS --with-ninetoothed)
endif()

# Forward the Triton switch to `generate_wrappers.py` as `--with-triton`.
if(WITH_TRITON)
list(APPEND GENERATOR_ARGS --with-triton)
endif()

execute_process(
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
Expand Down Expand Up @@ -764,6 +803,12 @@ if(GENERATE_PYTHON_BINDINGS)
target_include_directories(ops PRIVATE
${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
endif()

# Add the Triton include directories (defined by the generated manifest) to
# the `ops` target.
if(WITH_TRITON)
target_include_directories(ops PRIVATE
${INFINIOPS_TRITON_INCLUDE_DIRS})
endif()

target_link_libraries(ops PRIVATE infiniops)

# Cambricon generated dispatch is compiled into the Python extension and
Expand Down
Loading
Loading