diff --git a/.github/workflows/docker-base-image-2-8.yml b/.github/workflows/docker-base-image-2-8.yml index f8649303..3a1d97a1 100644 --- a/.github/workflows/docker-base-image-2-8.yml +++ b/.github/workflows/docker-base-image-2-8.yml @@ -2,7 +2,7 @@ name: Docker Base Image CI (PyTorch 2.8) on: push: - branches: [ "base" ] + branches: [ "base_v2.8" ] workflow_dispatch: repository_dispatch: types: [ build_base ] diff --git a/.github/workflows/docker-image-2-8.yml b/.github/workflows/docker-image-2-8.yml index cb5f73d1..4d511a1a 100644 --- a/.github/workflows/docker-image-2-8.yml +++ b/.github/workflows/docker-image-2-8.yml @@ -1,7 +1,7 @@ name: Docker image CI (PyTorch 2.8) on: - pull_request: + push: branches: [ "torch_v2.8" ] workflow_dispatch: diff --git a/Dockerfile.base b/Dockerfile.base index 897b8195..c5f200bc 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -34,7 +34,7 @@ RUN apt -y update && \ python3-dev python-is-python3 libboost-all-dev \ libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \ python3-venv black libssl-dev libasan5 libubsan1 curl device-tree-compiler wget ninja-build && \ - pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 && rm -rf /var/lib/apt/lists/* + pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 cmake==3.26.4 && rm -rf /var/lib/apt/lists/* # Download RISC-V tool chain RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2023.12.14/riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.12.14-nightly.tar.gz && \ diff --git a/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp b/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp new file mode 100644 index 00000000..a0b1395d --- /dev/null +++ b/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp @@ -0,0 +1,8 @@ +#include "ExtensionDeviceGuardImpl.h" +#include + +namespace c10::extension_device::impl { + +C10_REGISTER_GUARD_IMPL(extension_device, ExtensionDeviceGuardImpl); + +} // namespace c10::extension_device::impl diff --git a/PyTorchSimDevice/ExtensionDeviceGuardImpl.h b/PyTorchSimDevice/ExtensionDeviceGuardImpl.h new file mode 100644 index 00000000..6d35677b --- /dev/null +++ b/PyTorchSimDevice/ExtensionDeviceGuardImpl.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10::extension_device::impl { + +struct ExtensionDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::PrivateUse1; // ✅ your backend type + + ExtensionDeviceGuardImpl() = default; + + explicit ExtensionDeviceGuardImpl(DeviceType t) { + TORCH_CHECK( + t == static_type, + "ExtensionDeviceGuardImpl initialized with non-extension_device DeviceType: ", + t); + } + + // -------------------------------------------------------------------------- + // Basic device guard (behaves like a CPU device) + // -------------------------------------------------------------------------- + DeviceType type() const override { + return static_type; + } + + Device exchangeDevice(Device d) const override { + TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d); + return d; // nothing to exchange, CPU-like + } + + Device getDevice() const override { + return Device(static_type, 0); + } + + void setDevice(Device d) const override { + TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d); + } + + void uncheckedSetDevice(Device d) const noexcept override {} + + DeviceIndex deviceCount() const noexcept 
override { + return 1; // pretend single device + } + + // -------------------------------------------------------------------------- + // Stream handling (execution is synchronous, so only the default stream is used) + // -------------------------------------------------------------------------- + Stream getStream(Device d) const override { + return Stream(Stream::DEFAULT, d); + } + + Stream getNewStream(Device d, int priority = 0) const override { + return Stream(Stream::DEFAULT, d); + } + + Stream getStreamFromGlobalPool(Device d, bool = false) const override { + return Stream(Stream::DEFAULT, d); + } + + Stream exchangeStream(Stream s) const override { + return s; + } + + bool queryStream(const Stream& stream) const override { + (void)stream; + return true; + } + + void synchronizeStream(const Stream& stream) const override { + (void)stream; + } + + void synchronizeDevice(DeviceIndex device_index) const override { + (void)device_index; + } + + // -------------------------------------------------------------------------- + // Event handling (all no-ops) + // -------------------------------------------------------------------------- + void destroyEvent(void* event, const DeviceIndex device_index) const noexcept override { + (void)event; + (void)device_index; + } + + void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { + (void)event; + (void)stream; + (void)device_index; + (void)flag; + } + + void block(void* event, const Stream& stream) const override { + (void)event; + (void)stream; + } + + bool queryEvent(void* event) const override { + (void)event; + return true; + } + + void synchronizeEvent(void* event) const override { + (void)event; + } + + double elapsedTime(void* start_event, void* end_event, const DeviceIndex device_index) const override { + (void)start_event; + (void)end_event; + (void)device_index; + return 0.0; + } + + // -------------------------------------------------------------------------- + // Misc (allocator integration) + // -------------------------------------------------------------------------- + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override { + (void)data_ptr; + (void)stream; + } +}; + +} // namespace c10::extension_device::impl diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimDevice/extension_device.cpp similarity index 99% rename from PyTorchSimFrontend/extension_device.cpp rename to PyTorchSimDevice/extension_device.cpp index cfaecf2b..a1dcfcf4 100644 --- a/PyTorchSimFrontend/extension_device.cpp +++ b/PyTorchSimDevice/extension_device.cpp @@ -55,16 +55,12 @@ static inline at::MemoryFormat fix_memory_format(c10::optional return mf; } +#include "ExtensionDeviceGuardImpl.h" + static uint64_t op_counter = 0; static uint64_t last_saved_value = 0; -// register guard -namespace at { -namespace detail { - -C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); - -}} // namespace at::detail +C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::extension_device::impl::ExtensionDeviceGuardImpl); // basic dummy add function at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { @@ -159,7 +155,7 @@ at::Tensor custom_to_device( // A dummy allocator for our custom device, that secretly uses the CPU struct DummyCustomAllocator final : at::Allocator { DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { void* data = c10::alloc_cpu(nbytes); return 
{data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; } @@ -174,6 +170,10 @@ struct DummyCustomAllocator final : at::Allocator { at::DeleterFnPtr raw_deleter() const override { return &ReportAndDelete; } + + void copy_data(void* dest, const void* src, std::size_t count) const override { + std::memcpy(dest, src, count); + } }; // Register our dummy allocator diff --git a/PyTorchSimDevice/extension_device_interface.py b/PyTorchSimDevice/extension_device_interface.py new file mode 100644 index 00000000..e5875ab7 --- /dev/null +++ b/PyTorchSimDevice/extension_device_interface.py @@ -0,0 +1,63 @@ +import torch +from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties + +class _ExtensionDeviceProperties: # FIXME: Dummy property values + name: str = "Extension_device" + platform_name: str + vendor: str + driver_version: str + version: str + max_compute_units: int + gpu_eu_count: int + max_work_group_size: int + max_num_sub_groups: int + sub_group_sizes: list[int] + has_fp16: bool + has_fp64: bool + has_atomic64: bool + has_bfloat16_conversions: bool + has_subgroup_matrix_multiply_accumulate: bool + has_subgroup_matrix_multiply_accumulate_tensor_float32: bool + has_subgroup_2d_block_io: bool + total_memory: int + multi_processor_count: int = 128 # gpu_subslice_count, num_sm + architecture: int + type: str + +_ExtensionDeviceProperties = _ExtensionDeviceProperties + +class ExtensionDeviceInterface(DeviceInterface): + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["extension_device"] = device + + @staticmethod + def current_device() -> int: + if "extension_device" in caching_worker_current_devices: + return caching_worker_current_devices["extension_device"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties: + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "extension_device" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = ExtensionDeviceInterface.Worker.current_device() + + if "extension_device" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["extension_device"] = device_prop + + return _ExtensionDeviceProperties + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + return 36 \ No newline at end of file diff --git a/PyTorchSimDevice/extension_device_op_overrides.py b/PyTorchSimDevice/extension_device_op_overrides.py new file mode 100644 index 00000000..27a47357 --- /dev/null +++ b/PyTorchSimDevice/extension_device_op_overrides.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from textwrap import dedent + +from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides +from torch._inductor.codegen.cpu_device_op_overrides import CpuDeviceOpOverrides + +class ExtensionDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + +register_device_op_overrides("npu", 
ExtensionDeviceOpOverrides()) +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) \ No newline at end of file diff --git a/PyTorchSimDevice/extension_hooks.cpp b/PyTorchSimDevice/extension_hooks.cpp new file mode 100644 index 00000000..aadd6d2a --- /dev/null +++ b/PyTorchSimDevice/extension_hooks.cpp @@ -0,0 +1,48 @@ +#include "extension_hooks.h" + +bool ExtensionPU1Hooks::isBuilt() const { return true; } +bool ExtensionPU1Hooks::isAvailable() const { return true; } + +const at::Generator& ExtensionPU1Hooks::getDefaultGenerator(c10::DeviceIndex idx) const { + if (idx < 0) idx = 0; + static std::vector gens; + static std::mutex m; + std::lock_guard g(m); + if (gens.size() <= (size_t)idx) gens.resize((size_t)idx + 1); + if (!gens[idx].defined()) gens[idx] = at::GetGeneratorForPrivateuse1(idx); + return gens[idx]; // return a reference to a persistent object +} + +at::Generator ExtensionPU1Hooks::getNewGenerator(c10::DeviceIndex idx) const { + if (idx < 0) idx = 0; + return at::GetGeneratorForPrivateuse1(idx); +} + +at::Device ExtensionPU1Hooks::getDeviceFromPtr(void* data) const { + return at::Device(at::kPrivateUse1, 0); // MVP: assume a single device +} + +bool ExtensionPU1Hooks::isPinnedPtr(const void* data) const { + return false; +} + +at::Allocator* ExtensionPU1Hooks::getPinnedMemoryAllocator() const { + return at::getHostAllocator(at::kPrivateUse1); +} + +bool ExtensionPU1Hooks::hasPrimaryContext(c10::DeviceIndex device_index) const { return true; } + +void ExtensionPU1Hooks::resizePrivateUse1Bytes(const c10::Storage&, size_t) const { + TORCH_CHECK(false, "resizePrivateUse1Bytes not implemented"); +} + +// REGISTER_EXTENSION_HOOKS(ExtensionPU1Hooks); + +namespace { +struct AutoRegistrar { + AutoRegistrar() { + at::RegisterPrivateUse1HooksInterface(new ExtensionPU1Hooks()); + } +}; +static AutoRegistrar _auto_registrar; +} diff --git a/PyTorchSimDevice/extension_hooks.h b/PyTorchSimDevice/extension_hooks.h new file mode 100644 index 00000000..fdf3505a --- /dev/null +++ b/PyTorchSimDevice/extension_hooks.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +struct ExtensionPU1Hooks final : public at::PrivateUse1HooksInterface { + ExtensionPU1Hooks() {} + bool isBuilt() const; + bool isAvailable() const; + + const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override; + + at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override; + + at::Device getDeviceFromPtr(void* data) const override; + + bool isPinnedPtr(const void* data) const override; + + at::Allocator* getPinnedMemoryAllocator() const override; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override; + + void resizePrivateUse1Bytes(const c10::Storage& /*storage*/, size_t /*newsize*/) const override; +}; \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 2e35220c..5066d214 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -3,12 +3,16 @@ import shlex import subprocess -from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write +from torch._inductor.codecache import get_lock_dir, get_hash, write +from torch._inductor.async_compile import AsyncCompile from AsmParser.tog_generator import tog_generator from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen from PyTorchSimFrontend import extension_config from Simulator.simulator import FunctionalSimulator, 
CycleSimulator, TOGSimulator +# Configure logger for extension_codecache module (WARNING level by default) +logger = extension_config.setup_logger() + LOCK_TIMEOUT = 600 def hash_prefix(hash_value): @@ -165,8 +169,8 @@ def load(cls, source_code, subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) val_llvm_caller = MLIRKernelCallerCodeGen(extension_config.pytorchsim_functional_mode, arg_attributes) @@ -178,8 +182,10 @@ def load(cls, source_code, spad_size = val_llvm_caller.get_spad_size(target) spad_usage = stack_size + spad_size # Spad usage per lane if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: - print(f"[Warning] Scratchpad size exceeded: required {spad_usage} bytes, " - f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available.") + logger.debug( + f"Scratchpad size exceeded: required {spad_usage} bytes, " + f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available." + ) raise SpadOverflowError() # Launch tile graph generator @@ -196,8 +202,8 @@ def load(cls, source_code, subprocess.check_call(gem5_translate_cmd) subprocess.check_call(gem5_llc_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) if not extension_config.pytorchsim_timing_mode: diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 2b1b3102..b0bcac7f 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -2,6 +2,7 @@ import sys import importlib import yaml +import logging CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt") @@ -134,4 +135,43 @@ def load_plan_from_module(module_path): CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) -CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) \ No newline at end of file +CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) + + +def setup_logger(name=None, level=None): + """ + Setup a logger with consistent formatting across all modules. 
+ + Args: + name: Logger name (default: __name__ of calling module) + level: Logging level (default: DEBUG if CONFIG_DEBUG_MODE else INFO) + + Returns: + Logger instance + """ + if name is None: + import inspect + # Get the calling module's name + frame = inspect.currentframe().f_back + name = frame.f_globals.get('__name__', 'PyTorchSim') + + # Convert logger name to lowercase + name = name.lower() + logger = logging.getLogger(name) + + # Only configure if not already configured (avoid duplicate handlers) + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + fmt='[%(asctime)s.%(msecs)03d] [%(levelname)s] [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # Set log level + if level is None: + level = logging.DEBUG if CONFIG_DEBUG_MODE else logging.INFO + logger.setLevel(level) + + return logger \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_utils.py b/PyTorchSimFrontend/extension_utils.py new file mode 100644 index 00000000..0418cacd --- /dev/null +++ b/PyTorchSimFrontend/extension_utils.py @@ -0,0 +1,26 @@ +import sympy +import torch + +""" +NOTE: Temporary File + +This file contains functions that were removed or changed in newer versions +of PyTorch. It is kept here only to temporarily enable compatibility while +upgrading to PyTorch 2.8 from PyTorch 2.2. + +These functions will eventually be integrated into the appropriate source files +or removed once no longer needed. + +This file is not intended to be permanent and should be deleted in the future. +""" + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) + +def sympy_symbol(name: str) -> sympy.Symbol: + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 988408ea..138bec50 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -49,6 +49,9 @@ def __init__( self.extra_args = extra_args #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + def make_run_fn( self, input_tensors: torch.Tensor, output_tensors: torch.Tensor ) -> Callable[[], None]: @@ -84,5 +87,6 @@ def cached_run_fn(*args, **kwargs): *args, ) - def __str__(self) -> str: - return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" \ No newline at end of file + def update_workspace_size(self) -> None: + # FIXME: Not implemented yet. 
Checkout torch/_inductor/codegen/rocm/rocm_benchmark_request.py + return \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 297ea162..3d65c0a4 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -2,16 +2,17 @@ import sympy import re import os -import math from functools import reduce from operator import mul import torch +from typing import Optional from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from torch._dynamo.testing import rand_strided from torch._inductor.autotune_process import TensorMeta from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning +from torch._inductor.ir import GraphPartitionSignature from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic from torch._inductor.utils import ( @@ -27,6 +28,11 @@ from .mlir_ops import ExtensionOverrides from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest +# Configure logger for mlir_codegen_backend module +logger = extension_config.setup_logger() + +from Simulator.simulator import ProgressBar + def reduction_init(reduction_type, dtype): if dtype in cpp.DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -54,13 +60,28 @@ def reduction_partial_combine_vec(reduction_type, vector_value, init_value): if reduction_type == "min": return ops.minimum(vector_value, init_value) if reduction_type == "any": - return ops.logical_and(vector_value, init_value) + return ops.logical_or(vector_value, init_value) raise AssertionError(reduction_type) -class ExtensionWrapperCodegen(wrapper.WrapperCodeGen): +class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen): def __init__(self): super().__init__() + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[wrapper.PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None and parent_wrapper is not None + return wrapper.SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return cls() + def write_header(self): self.header.splice( f""" @@ -74,21 +95,27 @@ def write_header(self): from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align + from torch._inductor.async_compile import AsyncCompile from torch import device, empty, empty_strided from {extension_codecache.__name__} import CustomAsyncCompile - from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE + from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE, setup_logger from Simulator.simulator import TOGSimulator from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer from torch._inductor.select_algorithm import extern_kernels + # Configure logger for generated wrapper code + _logger = setup_logger("PyTorchSimFrontend.mlir.generated_wrapper") + aten = torch.ops.aten inductor_ops = torch.ops.inductor assert_size_stride = torch._C._dynamo.guards.assert_size_stride alloc_from_pool = torch.ops.inductor._alloc_from_pool reinterpret_tensor = torch.ops.aten._reinterpret_tensor custom_async_compile = 
CustomAsyncCompile() + async_compile = AsyncCompile() os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__ + _logger.info(f'Wrapper Codegen Path = {{__file__}}') """ ) self.header.splice( @@ -120,6 +147,7 @@ def device2host_memcpy(buffer): ) def write_prefix(self): + self.write_async_compile_wait() self.prefix.splice( """ def call(args): @@ -132,7 +160,7 @@ def call(args): self.prefix.writeline(f"{lhs} = args") self.prefix.writeline("args.clear()") - self.codegen_inputs(self.prefix, V.graph.graph_inputs) + self.codegen_inputs() self.codegen_input_size_asserts() self.codegen_sram_plan_prefix() @@ -152,35 +180,60 @@ def codegen_sram_plan_postfix(self, outputs): continue self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - @dynamo_timed + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + def generate(self, is_inference): result = IndentedBuffer() - result.splice(self.header) + # result.splice(self.header) with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) self.memory_plan_reuse() - for line in self.lines: - # Add buffer plan hook for dealloc - if isinstance(line, memory_planning.DeallocFromPoolLine): - self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") - elif isinstance(line, str) and "del" in line: - name = line.split(" ")[1] - self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - - if isinstance(line, wrapper.MemoryPlanningLine): - line.codegen(self.wrapper_call) - else: - self.wrapper_call.writeline(line) - # Add buffer plan hook for alloc - if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): - self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + # Add buffer plan hook for dealloc + if isinstance(line, memory_planning.DeallocFromPoolLine): + self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") + elif isinstance(line, str) and "del" in line: + name = line.split(" ")[1] + self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") + + if isinstance(line, wrapper.MemoryPlanningLine): + line.codegen(self.wrapper_call) + elif isinstance(line, wrapper.KernelCallLine): + self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args)) + else: + if isinstance(line, wrapper.WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + # Add buffer plan hook for alloc + if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): + self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") output_refs = self.get_output_refs() self.codegen_sram_plan_postfix(output_refs) self.mark_output_type() self.generate_return(output_refs) - self.append_precomputed_sizes_to_prefix() + # self.append_precomputed_sizes_to_prefix() # FIXME: Need to replace append_precomputed_sizes_to_prefix() + result.splice(self.header) + self.finalize_prefix() result.splice(self.prefix) @@ -189,7 +242,10 @@ def generate(self, is_inference): 
self.generate_end(result) self.add_benchmark_harness(result) - return result.getvaluewithlinemap() + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) def memory_plan(self): self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) @@ -286,6 +342,7 @@ def convert_index(self, expr, buffer): expr_str = expr_str.replace("//", " floordiv ") else: raise NotImplementedError("What is this case?") + first_arg = expr.args[0] if len(first_arg.free_symbols) != 1: raise NotImplementedError("What is this case?") @@ -622,16 +679,17 @@ def store_reduction(self, name, index, value): dram_shape, tile_shape, attribute) self.reductions_suffix.writeline(common.DeferredLine(name, code)) - def indirect_indexing(self, index_var, size, check=True): + def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): return str(index_var) def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): # In case of index expr, dimension size should be divisible by tile size if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges) + prior_tile_size, prior_ranges = self.kernel_group.tile_desc.get_tile_size(), self.ranges self.kernel_group.tile_desc.set_tile_size(new_tile_size) self.reset("recompile") - raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") + raise mlir_common.RecompileSignal(f"Index access (tile size {prior_tile_size} is not divisible by {prior_ranges})") tile_size = tile_desc.get_tile_size_per_lane() compute_vec_size = tile_desc.get_compute_vec_size() @@ -719,6 +777,7 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): return accum def index_expr(self, index, dtype): + index = self.rename_indexing(index) base_tile_desc = self.kernel_group.tile_desc if len(self.ranges) != self.reduction_depth: # FIXME. This is a temporary solution to get tile stride of the reduction case @@ -851,15 +910,14 @@ def make_choices(self, nodes, kernel_name): # Try initial tile size self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) while prevent_infinite_loop < 10 and candidate_axes: for axis in list(candidate_axes): @@ -881,7 +939,7 @@ def make_choices(self, nodes, kernel_name): continue self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) # FIXME. 
How to intergrate this constraint to tile system? @@ -898,11 +956,10 @@ def make_choices(self, nodes, kernel_name): # Add this choice search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) prevent_infinite_loop += 1 self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold return choices @@ -918,18 +975,24 @@ def get_cycle(choice): return float("inf") return float("inf") # Exceeded maximum number of autotuning attempts choices = self.make_choices(*args) - if len(choices) == 0: # Can't autotune - return [None, None] - with ThreadPoolExecutor(max_workers=8) as executor: - results = list(executor.map(get_cycle, choices)) - max_idx = results.index(min(results)) + return [None, None, None] + + # Get cycle time for each choice + # Show progress bar only when CONFIG_DEBUG_MODE is off + show_progress = not extension_config.CONFIG_DEBUG_MODE + with ProgressBar("[Auto-tune] Running benchmarks", silent_mode=not show_progress) if show_progress else contextlib.nullcontext(): + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(get_cycle, choices)) + + min_idx = results.index(min(results)) if min(results) == float("inf"): raise RuntimeError("Failed to find optimal tile size...") - if extension_config.CONFIG_DEBUG_MODE: - self._log_autotune_result(choices[max_idx], results[max_idx]) - optimal_src_code, loop_size = choices[max_idx][1], choices[max_idx][-1] - return optimal_src_code, loop_size + + self._log_autotune_result(choices[min_idx], results[min_idx]) + + optimal_src_code, meta_code, loop_size = choices[min_idx][1], choices[min_idx][2], choices[min_idx][-1] + return optimal_src_code, meta_code, loop_size def run_bench(self, nodes, kernel_name, src_code): _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -957,20 +1020,20 @@ def run_bench(self, nodes, kernel_name, src_code): return bmreq.make_run_fn(dummy_inputs, dummy_outputs) def _log_autotune_result(self, best_choice, best_cycle): - print( - f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, " - f"vlane_stride: {best_choice[3]}, " + logger.debug( + f"Auto-tune: Optimal tile size: {list(best_choice[3])}, " + f"vlane_stride: {best_choice[4]}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, nodes, kernel_name): - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) if "autotune" in extension_config.codegen_mapping_strategy and extension_config.pytorchsim_timing_mode: - optimal_src_code = self.autotune(nodes, kernel_name)[0] + optimal_src_code, meta_code = self.autotune(nodes, kernel_name)[:2] if optimal_src_code is not None: - return optimal_src_code - return src_code + return optimal_src_code, 
meta_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): write_path = extension_codecache.get_write_path(src_code) @@ -1110,7 +1173,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 for i in range(max_dim): target_dim = f"index{i}" - if target_dim not in str(index): + if sympy.Symbol(target_dim) not in index.free_symbols: dram_dict[target_dim] = [0] sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) @@ -1127,14 +1190,19 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe dim_idx = int((str(sub.args[0])[5:])) if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: # In this case, need to recompile - original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - divisor = sub.args[1] + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride new_size = ((original_size + divisor - 1) // divisor) * divisor new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) new_tile_sizes[dim_idx] = new_size self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + # Can't use dim_idx as vlane_split_axis + if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: + self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) + # Send recompile signal self.reset("recompile") raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") @@ -1150,6 +1218,57 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") offset = offset+1 + # Support ModularIndexing pattern + # This pattern can be used to broadcast ex) torch.cat([a,a]) + # ModularIndexing(x, y, z) means (x // y) % z + # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) + if index.has(ModularIndexing): + for sub in sympy.preorder_traversal(index): + if isinstance(sub, ModularIndexing): + if not str(sub.args[0]).startswith("index"): + continue + dim_idx = int((str(sub.args[0])[5:])) + floor_divisor = sub.args[1] # y: floorDiv divisor + mod_divisor = sub.args[2] # z: modular divisor + current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] + + # Check if tile_size is multiple of floorDiv divisor + if int(current_tile_size % floor_divisor) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride + new_size = ((original_size + divisor - 1) // divisor) * divisor + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") + + # Check if tile_size is a divisor of modular divisor + if int((mod_divisor * floor_divisor) % current_tile_size) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() 
+ original_size = original_tile[dim_idx] + # Find the largest divisor of mod_divisor that is <= original_size + # and is a multiple of floor_divisor + new_size = original_size + while new_size > 0: + if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: + break + new_size -= floor_divisor + + if new_size <= 0: + new_size = mod_divisor * floor_divisor + + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") + # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index b86607ea..e31555ba 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -14,7 +14,8 @@ from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod +from torch._inductor.codegen.wrapper import KernelDefinitionLine +from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity import sympy import contextlib @@ -22,18 +23,21 @@ import sympy -import torch.fx from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.utils import ( - free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, sympy_subs, - sympy_symbol, unique, ) from PyTorchSimFrontend import extension_config from PyTorchSimFrontend import extension_codecache + +from PyTorchSimFrontend.extension_utils import ( + free_symbol_startswith, + sympy_symbol +) + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_MLIR = { @@ -605,9 +609,9 @@ def __init__(self, kernel_group, reason=None): self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False - # Context var for codegen - self.target_buffer_override = contextvars.ContextVar("Handler_compute_override", default=self.compute) - self.target_cse_override = contextvars.ContextVar("Handler_cse_override", default=self.cse) + instance_id = id(self) + self.target_buffer_override = contextvars.ContextVar(f"Handler_compute_override_{instance_id}", default=self.compute) + self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse) def set_ranges(self, lengths, reduction_lengths): if self.call_ranges: @@ -641,7 +645,7 @@ def store(self, name, index, value, mode=None): def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() - def indirect_indexing(self, index_var, size, check): + def indirect_indexing(self, index_var, size, check, wrap_neg): raise NotImplementedError() def codegen_global_init(self): @@ -654,7 +658,7 @@ def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) + wrapper.generate_kernel_call(kernel_name, call_args, triton=False) def is_modular_indexing(self, 
expr): return "ModularIndexing" in str(expr) @@ -688,7 +692,9 @@ def extract_dividers(self, implicit_ops): } new_index = operand.index.subs(subs_map) for arg in new_index.args: - if len(arg.free_symbols) != 1: + if arg.is_number: + continue + if len(arg.free_symbols) > 1: raise NotImplementedError("Not supporting this view operation...!") if arg.is_Mul and arg.args[0].is_number: arg = arg.args[1] @@ -778,8 +784,8 @@ def codegen_nodes(self, nodes, kernel_name): V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def codegen_kernel(self, kernel_name): arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() @@ -797,12 +803,9 @@ def codegen_kernel(self, kernel_name): return code.getvalue() def meta_kernel(self): - wrapper = V.graph.wrapper_code _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - return arg_attributes + meta_code = arg_attributes + return meta_code def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] @@ -835,6 +838,21 @@ def rename_indexing(self, index) -> sympy.Expr: # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): return [self.rename_indexing(x) for x in index] + + # FIXME. This is a temporary solution to remove Identity wrappers from index expression. + # Remove Identity wrappers from index expression + # Check if index itself is Identity + if isinstance(index, Identity): + index = index.args[0] if index.args else index + + # Replace Identity arguments with Identity.args[0] + if hasattr(index, 'args') and len(index.args) > 0: + for arg in index.args: + if arg.is_Mul and arg.args[0].is_number and isinstance(arg.args[1], Identity): + index = index.replace(arg.args[1], arg.args[1].args[0] if arg.args[1].args else arg.args[1]) + if isinstance(arg, Identity): + index = index.replace(arg, arg.args[0] if arg.args else arg) + index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { @@ -846,18 +864,20 @@ def rename_indexing(self, index) -> sympy.Expr: @contextmanager def override_buffer_cse(self, *, buffer=None, cse=None): + buffer_override = self.target_buffer_override + cse_override = self.target_cse_override target_buffer = target_cse = None try: if buffer is not None: - target_buffer = self.target_buffer_override.set(buffer) + target_buffer = buffer_override.set(buffer) if cse is not None: - target_cse = self.target_cse_override.set(cse) + target_cse = cse_override.set(cse) yield self finally: if target_cse is not None: - self.target_cse_override.reset(target_cse) + cse_override.reset(target_cse) if target_buffer is not None: - self.target_buffer_override.reset(target_buffer) + buffer_override.reset(target_buffer) def __enter__(self): class CSEProxy: @@ -866,7 +886,7 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info, **kwargs) + code, ret_info = getattr(parent_handler, name)(*args, **kwargs) target_buffer = 
self.target_buffer_override.get() target_cse = self.target_cse_override.get() if isinstance(code, common.DeferredLine): @@ -887,9 +907,9 @@ def inner(*args, **kwargs): return inner @staticmethod - def indirect_indexing(index_var, size, check=True): + def indirect_indexing(index_var, size, check=True, wrap_neg=True): # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) + return self.indirect_indexing(index_var, size, check, wrap_neg) @staticmethod def load(name: str, index: sympy.Expr): @@ -903,10 +923,10 @@ def load(name: str, index: sympy.Expr): if name in store_cache: return store_cache[name] key = name+str(index) - if key not in self.cse.cache: + if key not in self.cse._cache: result = self.load(name, index) - self.cse.cache[key] = result - return self.cse.cache[key] + self.cse._cache[key] = result + return self.cse._cache[key] @staticmethod def store(name, index, value, mode=None): @@ -914,7 +934,7 @@ def store(name, index, value, mode=None): if mode is None: self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: return self.store(name, index, value, mode=mode) @@ -924,7 +944,7 @@ def store_reduction(name, index, value): self.store_buffer_names.add(name) self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: @@ -970,7 +990,7 @@ def bucketize( super().__enter__() assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) + parent_handler = self.overrides() self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) return self diff --git a/PyTorchSimFrontend/mlir/mlir_decomposition.py b/PyTorchSimFrontend/mlir/mlir_decomposition.py new file mode 100644 index 00000000..284d25d7 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_decomposition.py @@ -0,0 +1,167 @@ +import math +import torch +import torch.nn.functional as F +from torch._inductor.decomposition import register_decomposition + +aten = torch.ops.aten + +@register_decomposition(aten._native_multi_head_attention.default) +def decompose_native_multi_head_attention( + query, + key, + value, + embed_dim: int, + num_heads: int, + qkv_weight, + qkv_bias, + proj_weight, + proj_bias, + mask=None, + need_weights: bool = False, +): + """ + Decompose _native_multi_head_attention into scaled_dot_product_attention operations. + + Based on F.scaled_dot_product_attention and nn.MultiheadAttention implementation: + 1. QKV projection (if needed - but query/key/value may already be projected) + 2. Reshape to multi-head format + 3. Scaled dot product: Q @ K^T / sqrt(head_dim) + 4. Softmax + 5. Attention @ V + 6. 
Reshape back and output projection + """ + head_dim = embed_dim // num_heads + scale_factor = 1.0 / math.sqrt(head_dim) + + # Get input shapes - assuming [batch, seq_len, embed_dim] format + query_shape = query.shape + if len(query_shape) == 3: + # [batch, seq_len, embed_dim] format + batch_size = query_shape[0] + seq_len = query_shape[1] + elif len(query_shape) == 2: + # [seq_len, embed_dim] -> add batch dimension + batch_size = 1 + seq_len = query_shape[0] + query = query.unsqueeze(0) # [1, seq_len, embed_dim] + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + # Fallback: assume first dim is batch, second is seq_len + batch_size = query_shape[0] if len(query_shape) > 0 else 1 + seq_len = query_shape[1] if len(query_shape) > 1 else query_shape[0] + + # Step 1: QKV projection (if query/key/value are not already projected) + # In many cases, query/key/value are already projected, so we check if qkv_weight is used + # For now, assume they might need projection + # Note: In practice, _native_multi_head_attention often receives already projected inputs + + # Reshape for projection: [batch, seq_len, embed_dim] -> [batch*seq_len, embed_dim] + if len(query.shape) == 3: + query_flat = query.view(-1, embed_dim) + key_flat = key.view(-1, embed_dim) + value_flat = value.view(-1, embed_dim) + else: + query_flat = query + key_flat = key + value_flat = value + + # QKV projection using qkv_weight and qkv_bias + # Check if GQA (Grouped Query Attention) is used + # Standard MHA: qkv_weight shape = [3*embed_dim, embed_dim] + # GQA: qkv_weight shape = [embed_dim + 2*kv_embed_dim, embed_dim] where kv_embed_dim < embed_dim + qkv_weight_total = qkv_weight.shape[0] + + # Determine if GQA: if qkv_weight is not exactly 3*embed_dim, it might be GQA + if qkv_weight_total == 3 * embed_dim: + # Standard MHA: split equally + qkv_weight_q, qkv_weight_k, qkv_weight_v = torch.split(qkv_weight, embed_dim, dim=0) + if qkv_bias is not None: + qkv_bias_q, qkv_bias_k, qkv_bias_v = torch.split(qkv_bias, embed_dim, dim=0) + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = embed_dim + kv_heads = num_heads + else: + # GQA: Q has embed_dim, K and V share the rest + # Assume Q = embed_dim, K = V = (qkv_weight_total - embed_dim) / 2 + q_dim = embed_dim + kv_dim = (qkv_weight_total - embed_dim) // 2 + qkv_weight_q = qkv_weight[:q_dim] + qkv_weight_k = qkv_weight[q_dim:q_dim + kv_dim] + qkv_weight_v = qkv_weight[q_dim + kv_dim:] + if qkv_bias is not None: + qkv_bias_q = qkv_bias[:q_dim] + qkv_bias_k = qkv_bias[q_dim:q_dim + kv_dim] + qkv_bias_v = qkv_bias[q_dim + kv_dim:] + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = kv_dim + kv_heads = kv_embed_dim // head_dim # Number of KV heads + + # Project Q, K, V + q = torch.nn.functional.linear(query_flat, qkv_weight_q, qkv_bias_q) + k = torch.nn.functional.linear(key_flat, qkv_weight_k, qkv_bias_k) + v = torch.nn.functional.linear(value_flat, qkv_weight_v, qkv_bias_v) + + # Reshape back: [batch*seq_len, embed_dim] -> [batch, seq_len, embed_dim] + q = q.view(batch_size, seq_len, embed_dim) + k = k.view(batch_size, seq_len, kv_embed_dim) + v = v.view(batch_size, seq_len, kv_embed_dim) + + # Step 2: Reshape to multi-head format + # [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim] + q = q.view(batch_size, seq_len, num_heads, head_dim) + k = k.view(batch_size, seq_len, kv_heads, head_dim) + v = v.view(batch_size, seq_len, kv_heads, head_dim) + + # Transpose to [batch, num_heads, seq_len, head_dim] for bmm + q = 
q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + + # GQA: If key/value have fewer heads, repeat them to match query heads + if kv_heads < num_heads: + repeat_factor = num_heads // kv_heads + k = k.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + v = v.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + + # Step 3: Scaled dot product attention + # Scale Q + q_scaled = q * scale_factor + + # Q @ K^T: [batch, num_heads, seq_len, head_dim] @ [batch, num_heads, head_dim, seq_len] + # -> [batch, num_heads, seq_len, seq_len] + k_transposed = k.transpose(-2, -1) # [batch, num_heads, head_dim, seq_len] + scores = torch.matmul(q_scaled, k_transposed) # [batch, num_heads, seq_len, seq_len] + + # Step 4: Apply mask to the attention scores if provided + if mask is not None: + if mask.dtype == torch.bool: + scores = scores.masked_fill(mask.logical_not(), float("-inf")) + else: + scores = scores + mask + + # Step 5: Softmax along the last dimension (seq_len dimension) + attn_weights = F.softmax(scores, dim=-1) # [batch, num_heads, seq_len, seq_len] + + # Step 6: Attention @ V + # [batch, num_heads, seq_len, seq_len] @ [batch, num_heads, seq_len, head_dim] + # -> [batch, num_heads, seq_len, head_dim] + attn_output = torch.matmul(attn_weights, v) + + # Step 7: Reshape back to [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, embed_dim) + + # Step 8: Output projection + attn_output_flat = attn_output.view(-1, embed_dim) + output = torch.nn.functional.linear(attn_output_flat, proj_weight, proj_bias) + output = output.view(batch_size, seq_len, embed_dim) + + if need_weights: + # Return attention weights: [batch, num_heads, seq_len, seq_len] -> [batch, seq_len, seq_len] + attn_weights_mean = attn_weights.mean(dim=1) # Average over heads + return output, attn_weights_mean + else: + return (output, None) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 21995512..c3d3952e 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -1,10 +1,13 @@ import math import torch +import warnings from torch._inductor.codegen import common from torch._inductor.virtualized import V, _ops as ops from . 
import mlir_common +warnings.filterwarnings('ignore', message='undefined OpHandler\\..*, please add missing op schema') + def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): if reduction_type == "sum": return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" @@ -15,12 +18,12 @@ def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, if reduction_type == "min": return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" if reduction_type == "any": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" raise AssertionError(reduction_type) class ExtensionOverrides(common.OpOverrides): @staticmethod - def constant(value, src_type, *args, var_info=None, **kwargs): + def constant(value, src_type, *args, **kwargs): if isinstance(src_type, torch.dtype): src_type = mlir_common.DTYPE_TO_MLIR[src_type] @@ -37,8 +40,8 @@ def constant(value, src_type, *args, var_info=None, **kwargs): return f'arith.constant {value} : {src_type}', [1, src_type] @staticmethod - def broadcast(operand, target_size, *args, var_info=None, **kwargs): - src_size, dtype = var_info[operand] + def broadcast(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] src_shape = f"vector<{src_size}x{dtype}>" if src_size > 1 else dtype dst_shape = f"vector<{target_size}x{dtype}>" @@ -63,8 +66,8 @@ def broadcast(operand, target_size, *args, var_info=None, **kwargs): return op_str, [target_size, dtype] @staticmethod - def broadcast_unflat(operand, target_size, *args, var_info=None, **kwargs): - src_size, dtype = var_info[operand] + def broadcast_unflat(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] outer_dim = target_size // src_size src_shape = f"vector<{src_size}x{dtype}>" @@ -87,33 +90,33 @@ def randint64(self, *args, **kwargs): # Special operaitons @staticmethod - def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): + def masked(mask, body, other, *args, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): result = body() val = ops.constant(other, dtype, *args, **kwargs) result = ops.where(mask, result, val) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def where(condition, operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - cond_type = var_info[condition] - operand_type = var_info[operand1] + def where(condition, operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + cond_type = V.kernel.var_info[condition] + operand_type = V.kernel.var_info[operand1] condition = ops.to_bool(condition) if cond_type[0] < tile_size: condition = ops.broadcast(condition, tile_size) elif cond_type[0] > tile_size: operand1 = ops.broadcast(operand1, cond_type[0]) operand2 = ops.broadcast(operand2, cond_type[0]) - tile_size, ret_type = var_info[operand1] + tile_size, ret_type = V.kernel.var_info[operand1] shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type cond_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "" return 
f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape}, {shape}", [tile_size, ret_type] @staticmethod - def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): + def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Extract source information - src_mlir_dtype = var_info[operand][1] - tile_size = var_info[operand][0] + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] # Normalize destination type (Torch dtype -> MLIR string) if isinstance(dst_mlir_dtype, torch.dtype): @@ -172,13 +175,13 @@ def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): return op_str, [tile_size, dst_mlir_dtype] @staticmethod - def identity(operand, *args, var_info=None, **kwargs): - operand_info = var_info[operand] + def identity(operand, *args, **kwargs): + operand_info = V.kernel.var_info[operand] return operand, operand_info @staticmethod - def to_dtype_bitcast(operand, dtype, *args, var_info=None, **kwargs): - tile_size, current_src_type = var_info[operand] + def to_dtype_bitcast(operand, dtype, *args, **kwargs): + tile_size, current_src_type = V.kernel.var_info[operand] if isinstance(dtype, torch.dtype): dst_mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] @@ -201,11 +204,12 @@ def to_dtype_bitcast(operand, dtype, *args, var_info=None, **kwargs): # Binary element wise operations @staticmethod - def binary_elementwise_common(operand1, operand2, var_info): + def binary_elementwise_common(operand1, operand2): + V.kernel.var_info = V.kernel.var_info operand1.bounds = operand1.bounds.unknown() operand2.bounds = operand2.bounds.unknown() - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] + op_type1 = V.kernel.var_info[operand1] + op_type2 = V.kernel.var_info[operand2] # Tile size check if op_type1[0] != op_type2[0]: # Try to broad cast @@ -213,33 +217,47 @@ def binary_elementwise_common(operand1, operand2, var_info): rhs_tile_size, rhs_dtype = op_type2 if lhs_tile_size > rhs_tile_size: operand2 = ops.broadcast(operand2, lhs_tile_size) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif lhs_tile_size < rhs_tile_size: operand1 = ops.broadcast(operand1, rhs_tile_size) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] # Data type check if op_type1[1] != op_type2[1]: if op_type1[1] == "index" or op_type1 == "index": if op_type1[1] == "index": - operand1 = ops.index_cast(operand1, op_type2[1]) - op_type1 = var_info[operand1] + # index -> target type: 2-step casting if target is float + if op_type2[1][0] == "f": + operand1 = ops.index_cast(operand1, "i64") + operand1 = ops.to_dtype(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + else: + # index -> integer: direct casting + operand1 = ops.index_cast(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] if op_type2[1] == "index": - operand2 = ops.index_cast(operand2, op_type1[1]) - op_type2 = var_info[operand2] + # index -> target type: 2-step casting if target is float + if op_type1[1][0] == "f": + operand2 = ops.index_cast(operand2, "i64") + operand2 = ops.to_dtype(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + else: + # index -> integer: direct casting + operand2 = ops.index_cast(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] elif op_type1[1][0] == "i" and op_type2[1][0] == "f": operand1 = ops.to_dtype(operand1, op_type2[1]) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] elif op_type1[1][0] == "f" and op_type2[1][0] == "i": operand2 = 
ops.to_dtype(operand2, op_type1[1]) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif op_type1[1][0] == op_type2[1][0]: if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: operand2 = ops.ext(operand2, op_type1[1]) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: operand1 = ops.ext(operand1, op_type2[1]) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] else: raise NotImplementedError("Unsupported type converting") @@ -249,45 +267,45 @@ def binary_elementwise_common(operand1, operand2, var_info): return tile_size, ret_type, operand1, operand2 @staticmethod - def abs(operand, *args, var_info=None, **kwargs): + def abs(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def exp(operand, *args, var_info=None, **kwargs): + def exp(operand, *args, **kwargs): # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.exp(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype return f'math.exp %{operand} : {shape}', [tile_size, dtype] @staticmethod - def exp2(operand, *args, var_info=None, **kwargs): + def exp2(operand, *args, **kwargs): # Hands-on part: implement exp2 using math.exp2 - # var_info = {operand: [tile_size, dtype]} - # Ex) var_info[operand] = [8, "f32"] + # V.kernel.var_info = {operand: [tile_size, dtype]} + # Ex) V.kernel.var_info[operand] = [8, "f32"] ln2 = math.log(2) coeff = ops.constant(ln2, "f32") operand = ops.mul(operand, coeff) - return ops.exp(operand), var_info[operand] + return ops.exp(operand), V.kernel.var_info[operand] @staticmethod - def expm1(operand, *args, var_info=None, **kwargs): + def expm1(operand, *args, **kwargs): coeff = ops.constant(1.0, "f32") operand = ops.exp(operand) operand = ops.sub(operand, coeff) - return operand, var_info[operand] + return operand, V.kernel.var_info[operand] @staticmethod - def sqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -300,14 +318,14 @@ def sqrt(operand, *args, var_info=None, **kwargs): return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def relu(operand, *args, var_info=None, **kwargs): - src_mlir_dtype = var_info[operand][1] - tile_size = var_info[operand][0] + def relu(operand, *args, **kwargs): + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), [tile_size, src_mlir_dtype] @staticmethod - def minimum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def minimum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": opcode = f'arith.minimumf' @@ -316,8 +334,8 @@ def minimum(operand1, operand2, *args, var_info=None, 
**kwargs): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def maximum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def maximum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": opcode = f'arith.maximumf' @@ -326,17 +344,17 @@ def maximum(operand1, operand2, *args, var_info=None, **kwargs): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def cos(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def cos(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.cos(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -347,17 +365,17 @@ def cos(operand, *args, var_info=None, **kwargs): return f'math.cos %{operand} : {shape}', [tile_size, dtype] @staticmethod - def sin(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sin(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.sin(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -368,51 +386,51 @@ def sin(operand, *args, var_info=None, **kwargs): return f'math.sin %{operand} : {shape}', [tile_size, dtype] @staticmethod - def tan(operand, *args, var_info=None, **kwargs): + def tan(operand, *args, **kwargs): sin_res = ops.sin(operand) cos_res = ops.cos(operand) operand = ops.truediv(sin_res, cos_res) - return operand, var_info[operand] + return operand, V.kernel.var_info[operand] @staticmethod - def lgamma(operand, *args, var_info=None, **kwargs): + def lgamma(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def erf(operand, *args, var_info=None, **kwargs): + def erf(operand, *args, **kwargs): # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.erf(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype return f'math.erf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def cosh(operand, *args, var_info=None, **kwargs): + def cosh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def sinh(operand, *args, var_info=None, **kwargs): + def sinh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def tanh(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def tanh(operand, *args, **kwargs): + op_type 
= V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.tanh(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -423,80 +441,80 @@ def tanh(operand, *args, var_info=None, **kwargs): return f'math.tanh %{operand} : {shape}', [tile_size, dtype] @staticmethod - def acos(operand, *args, var_info=None, **kwargs): + def acos(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def acosh(operand, *args, var_info=None, **kwargs): + def acosh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def asin(operand, *args, var_info=None, **kwargs): + def asin(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def asinh(operand, *args, var_info=None, **kwargs): + def asinh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def atan2(operand1, operand2, *args, var_info=None, **kwargs): + def atan2(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def atan(operand, *args, var_info=None, **kwargs): + def atan(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def atanh(operand, *args, var_info=None, **kwargs): + def atanh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def copysign(operand1, operand2, *args, var_info=None, **kwargs): + def copysign(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def erfc(operand, *args, var_info=None, **kwargs): + def erfc(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def erfinv(operand, *args, var_info=None, **kwargs): + def erfinv(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def frexp(operand, *args, var_info=None, **kwargs): + def frexp(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def hypot(operand1, operand2, *args, var_info=None, **kwargs): + def hypot(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def log10(operand, *args, var_info=None, **kwargs): + def log10(operand, *args, **kwargs): val_ln = ops.log(operand) - tile_size, dtype = var_info[val_ln] + tile_size, dtype = V.kernel.var_info[val_ln] inv_ln10 = 1/math.log(10) const_op = ops.constant(inv_ln10, dtype) # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def log2(operand, *args, var_info=None, **kwargs): + def log2(operand, *args, **kwargs): val_ln = ops.log(operand) - tile_size, dtype = var_info[val_ln] + tile_size, dtype = V.kernel.var_info[val_ln] inv_ln10 = 1/math.log(2) const_op = ops.constant(inv_ln10, dtype) # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def log(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def log(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -508,109 +526,107 @@ def log(operand, *args, var_info=None, **kwargs): return f'math.log %{operand} : {shape}', [tile_size, dtype] @staticmethod - def log1p(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def log1p(operand, *args, 
**kwargs): + tile_size, dtype = V.kernel.var_info[operand] const_one = ops.constant(1, dtype) - # 3. 덧셈 연산: (x + 1) - # ops.add가 (result_ssa, result_info)를 반환한다고 가정 val_add = ops.add(operand, const_one) result = ops.log(val_add) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def nextafter(operand1, operand2, *args, var_info=None, **kwargs): + def nextafter(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def logical_and(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_and(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.and_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def logical_or(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_or(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.or_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_xor(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.xor(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def logical_not(operand, *args, var_info=None, **kwargs): - op_info = var_info[operand] + def logical_not(operand, *args, **kwargs): + op_info = V.kernel.var_info[operand] tile_size = op_info[0] dtype = op_info[1] zero_const = ops.constant(0, dtype) result = ops.eq(operand, zero_const) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_and(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_and(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.and_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_not(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def bitwise_not(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] # Float check - if var_info[operand][1].startswith("f"): + if V.kernel.var_info[operand][1].startswith("f"): raise ValueError("Bitwise NOT not supported for floats") neg_one = ops.constant(-1, dtype) result = ops.xor(operand, neg_one) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_or(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_or(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or 
var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.or_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_xor(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_xor(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.xor(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_left_shift(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_left_shift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def bitwise_right_shift(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_right_shift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def rsqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def rsqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -622,28 +638,28 @@ def rsqrt(operand, *args, var_info=None, **kwargs): return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def sigmoid(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sigmoid(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] one = ops.constant(1, dtype) return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, dtype] @staticmethod - def fmod(operand1, operand2, *args, var_info=None, **kwargs): + def fmod(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def isinf(operand, *args, var_info=None, **kwargs): + def isinf(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def isnan(operand, *args, var_info=None, **kwargs): + def isnan(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def round(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def round(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): @@ -652,8 +668,8 @@ def round(operand, *args, var_info=None, **kwargs): return operand, [tile_size, dtype] @staticmethod - def floor(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def floor(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): @@ -662,12 +678,12 @@ def floor(operand, *args, var_info=None, **kwargs): return operand, [tile_size, dtype] @staticmethod - def sign(operand, *args, var_info=None, **kwargs): + def sign(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def trunc(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def trunc(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): @@ -676,8 +692,8 @@ def trunc(operand, *args, 
var_info=None, **kwargs): return operand, [tile_size, dtype] @staticmethod - def ceil(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def ceil(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): @@ -687,8 +703,8 @@ def ceil(operand, *args, var_info=None, **kwargs): # Logical operations @staticmethod - def neg(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def neg(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -700,8 +716,8 @@ def neg(operand, *args, var_info=None, **kwargs): return f'arith.negf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def reciprocal(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def reciprocal(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -712,8 +728,8 @@ def reciprocal(operand, *args, var_info=None, **kwargs): return ops.truediv(ops.constant(1.0, dtype), operand), [tile_size, dtype] @staticmethod - def eq(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def eq(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "oeq" @@ -727,8 +743,8 @@ def eq(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def ne(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def ne(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "one" @@ -742,8 +758,8 @@ def ne(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def lt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def lt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "olt" @@ -757,8 +773,8 @@ def lt(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def gt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def gt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "ogt" @@ -772,8 +788,8 @@ def gt(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def le(operand1, operand2, *args, 
var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def le(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "ole" @@ -787,8 +803,8 @@ def le(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def ge(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def ge(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "oge" @@ -802,29 +818,29 @@ def ge(operand1, operand2, *args, var_info=None, **kwargs): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def add(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def add(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.add{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def sub(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def sub(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.sub{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def mul(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def mul(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.mul{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def pow(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def pow(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) # Type check & auto cast if ret_type.startswith("f"): operand1 = ops.to_dtype(operand1, "f32") @@ -837,37 +853,37 @@ def pow(operand1, operand2, *args, var_info=None, **kwargs): return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] @staticmethod - def and_(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def 
and_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def or_(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def or_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def xor(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def xor(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def lshift(operand1, operand2, *args, var_info=None, **kwargs): + def lshift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def rshift(operand1, operand2, *args, var_info=None, **kwargs): + def rshift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def truncdiv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def truncdiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): @@ -877,8 +893,8 @@ def truncdiv(operand1, operand2, *args, var_info=None, **kwargs): return f'arith.divsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def floordiv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def floordiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): @@ -889,8 +905,8 @@ def floordiv(operand1, operand2, *args, var_info=None, **kwargs): return f'arith.floordivsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def truediv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if not ret_type.startswith("f"): @@ -899,12 +915,12 @@ def truediv(operand1, operand2, *args, var_info=None, **kwargs): return f'arith.divf %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def 
int_truediv(operand1, operand2, *args, var_info=None, **kwargs): + def int_truediv(operand1, operand2, *args, **kwargs): """ True division for Integers (Int -> Float). Promotes integers to floats, then performs floating-point division. """ - tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if not src_type.startswith("f"): target_float_type = "f32" operand1 = ops.to_dtype(operand1, target_float_type) @@ -912,11 +928,11 @@ def int_truediv(operand1, operand2, *args, var_info=None, **kwargs): src_type = target_float_type result = ops.truediv(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def mod(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def mod(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": raise NotImplementedError("Not support remainder operation for floating point") @@ -925,8 +941,8 @@ def mod(operand1, operand2, *args, var_info=None, **kwargs): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def remainder(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def remainder(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): @@ -937,28 +953,34 @@ def remainder(operand1, operand2, *args, var_info=None, **kwargs): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def square(operand, *args, var_info=None, **kwargs): + def square(operand, *args, **kwargs): result = ops.mul(operand, operand) - return result, var_info[result] + return result, V.kernel.var_info[result] + + @staticmethod + def fma(operand1, operand2, operand3, *args, **kwargs): + result = ops.mul(operand1, operand2) + result = ops.add(result, operand3) + return result, V.kernel.var_info[result] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # PyTorchSim specific operations @staticmethod - def alloc(size, src_type, *args, var_info=None, **kwargs): + def alloc(size, src_type, *args, **kwargs): return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] @staticmethod - def extractelement(operand, idx, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def extractelement(operand, idx, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype return f"vector.extract %{operand}[{idx}]: {dtype} from {shape}", [1, dtype] @staticmethod - def ext(operand, dtype, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def ext(operand, dtype, *args, **kwargs): + op_type = V.kernel.var_info[operand] shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" target_type = f"vector<{op_type[0]}x{dtype}>" 
if op_type[0] > 1 else f"{dtype}" if op_type[0] == "f": @@ -968,15 +990,15 @@ def ext(operand, dtype, *args, var_info=None, **kwargs): return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] @staticmethod - def to_bool(operand, *args, var_info=None, **kwargs): - tile_size, ret_type = var_info[operand] + def to_bool(operand, *args, **kwargs): + tile_size, ret_type = V.kernel.var_info[operand] if ret_type == "i1": return operand, [tile_size, ret_type] - const_one = ops.constant(0, ret_type) + const_zero = ops.constant(0, ret_type) if tile_size > 1: - const_one = ops.broadcast(const_one, tile_size) - ret = ops.ne(operand, const_one) + const_zero = ops.broadcast(const_zero, tile_size) + ret = ops.ne(operand, const_zero) return ret, [tile_size, "i1"] @staticmethod def step(size, dtype, *args, **kwargs): @@ -984,15 +1006,15 @@ def step(size, dtype, *args, **kwargs): return f"vector.step : {index_shape}", [size, dtype] @staticmethod - def index_cast(operand, target_type, *args, var_info=None, **kwrags): - op_type = var_info[operand] + def index_cast(operand, target_type, *args, **kwargs): + op_type = V.kernel.var_info[operand] src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] @staticmethod - def shape_cast(operand, src_shape, dst_shape, *args, var_info=None, **kwargs): - operand_type = var_info[operand] + def shape_cast(operand, src_shape, dst_shape, *args, **kwargs): + operand_type = V.kernel.var_info[operand] return f"vector.shape_cast %{operand} : {src_shape} to {dst_shape}", operand_type @staticmethod @@ -1008,7 +1030,7 @@ def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_nam return line, [red_size, type_name] @staticmethod - def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, var_info=None, **kwargs): + def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): if compute_vec_size == 1: vshape = f"{mlir_dtype}" operation = "affine.load" @@ -1020,8 +1042,8 @@ def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, va return line, [compute_vec_size, mlir_dtype] @staticmethod - def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, var_info=None, **kwargs): - compute_vec_size, mlir_dtype = var_info[operand][0], var_info[operand][1] + def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kwargs): + compute_vec_size, mlir_dtype = V.kernel.var_info[operand][0], V.kernel.var_info[operand][1] if compute_vec_size == 1: vshape = f"{mlir_dtype}" diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 23be941c..f2bcba7e 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -7,23 +7,24 @@ from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from torch.utils._ordered_set import OrderedSet from torch._inductor import config from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V from torch._inductor.ir import LoopBody from torch._inductor import dependencies +from torch._inductor.codegen.common import BackendFeature from . 
import mlir_common from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering +from . import mlir_decomposition # DO NOT REMOVE THIS LINE, it is used for decomposition class MLIRScheduling(BaseScheduling): count = 0 target_kernel = MLIRKernel def __init__(self, scheduler): self.scheduler = scheduler - self.scheduler.can_fuse_origin = self.scheduler.can_fuse - self.scheduler.can_fuse = self.can_fuse_with_exceptions #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False @@ -31,81 +32,33 @@ def __init__(self, scheduler): config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! self.max_fusion_size = 5 - def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: - # Extract base template node - base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] - base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] - if node1.get_device() != node2.get_device(): - return False - if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): - return False - - if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): - # For matmul/bmm+reduction case - size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) - target_symbol = symbols("r0") - try: - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] - stride = int(sympify(stride).coeff(target_symbol)) - except: - return False - - # We can't fuse dim=-1 - layout_possible = stride != 1 - # Directed linked? - dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 - dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) - return size_match and layout_possible and dependency_check and dependency_size - - # For prologue fusion case - if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': - return False - if node1.is_reduction(): - return False - if len(node1.read_writes.writes) != 1: - return False - if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME - return False - - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): - return False - # We don't fuse this edge case... 
- if base_template_node2[0].group[1][0][0] == 1: - return False - - if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: - node1 = self.revert_group(node1) - return True - - return self.scheduler.can_fuse_origin(node1, node2) - def _set_flush_status(self, status: bool): self._ready_to_flush = status + def reset_kernel_group(self): + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + + def get_backend_features(self, device): + """Return a set of .codegen.common.BackendFeature()""" + return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT]) + def can_fuse_vertical(self, node1, node2): return self.can_fuse_horizontal(node1, node2) def can_fuse_horizontal(self, node1, node2): if not extension_config.CONFIG_FUSION: return False + if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: return False + _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group - - # Reduction is currently not supported - if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: - return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users - if node1.is_reduction() or node2.is_reduction(): - return False + # For input/dependency checks + reads1 = {dep.name for dep in node1.read_writes.reads} + reads2 = {dep.name for dep in node2.read_writes.reads} + writes1 = {dep.name for dep in node1.read_writes.writes} + writes2 = {dep.name for dep in node2.read_writes.writes} # Can't fuse two template node if node1.is_template() and node2.is_template(): @@ -114,17 +67,37 @@ def can_fuse_horizontal(self, node1, node2): if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: return False - # Check template node fusion - if node1.is_template() or node2.is_template(): + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + + # Case 0: Reduction fusion + if ( + node1.is_reduction() + and node2.is_reduction() + and not node1.is_template() + and not node2.is_template() + and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION + ): + # 1) Same loop/iteration domain + same_iter = vars1 == vars2 and reduce1 == reduce2 + # 2) No data dependency between the two reductions + no_dependency = not ( + writes1 & (reads2 | writes2) or writes2 & (reads1 | writes1) + ) + return same_iter and no_dependency + + # Case 1: Template + Pointwise fusion + if len(base_template_node1) == 1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) - template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) - if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ - template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): + template_node = base_template_node1[0] + epilogue_node = node2 + + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False # Pointwise 
check @@ -133,23 +106,77 @@ def can_fuse_horizontal(self, node1, node2): if v1_total != v2_total: return False - # Pattern check - template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) - has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) + # Pattern check: check data dependency between epilogue_node and template_node + template_sched_nodes = list(template_node.get_nodes()) + # Buffers produced by the template (its outputs) + template_writes = { + dep + for n in template_sched_nodes + for dep in n.read_writes.writes + } + # Buffers still required by the activation node (unmet) or read by it + epilogue_unmet = { dep for dep in epilogue_node.unmet_dependencies } + has_depedency = bool(template_writes) and epilogue_unmet.issubset(template_writes) if not has_depedency: return False # Revert act_node.group : simplify_and_reorder() modified _body, _size, group - if template_node.group != act_node.group: + if template_node.group != epilogue_node.group: # We don't fuse this case... if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: return False - if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): + if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): return False - self.revert_group(act_node) + self.revert_group(epilogue_node) return True + # Case 2: Template + Reduction fusion + if len(base_template_node1) == 1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + target_node = base_template_node1[0].node + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + target_symbol = symbols("r0") + try: + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] + stride = int(sympify(stride).coeff(target_symbol)) + except: + return False + + # We can't fuse dim=-1 + layout_possible = stride != 1 + # Directly linked? + dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check and dependency_size + + # Case 3: Prologue(Pointwise) + Template + if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + + target_node = base_template_node2[0].node + # Currently only BMM, MM support prologue fusion + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + + if len(node1.read_writes.writes) != 1: + return False + if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + return False + + # We don't fuse this edge case... 
+ if base_template_node2[0].group[1][0][0] == 1: + return False + + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + node1 = self.revert_group(node1) + return True + # Check elementwise fusion if vars1 == vars2 and reduce1 == reduce2: return True @@ -165,6 +191,8 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): act_node.node.get_store_function(), (args if act_node.node.get_reduction_type() else args[:1]), var_ranges, + args[0], + args[1] ) index_size = [] reduce_size = [] @@ -180,12 +208,13 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - def codegen_nodes(self, nodes): + def codegen_node(self, _node): + nodes = _node.get_nodes() _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group - # Note: We assume that ther is at least one loop in the nodes + # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop # In that case, we add dummy loop(size=1) to the group if len(group) == 0: @@ -210,8 +239,8 @@ def codegen_nodes(self, nodes): kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) - kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, + src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() @@ -230,57 +259,50 @@ def codegen_sync(self): pass def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() self._set_flush_status(False) def define_function(self, kernel): partial_code, function_name = kernel.def_function() if partial_code is not None and function_name not in self.outer_function: with V.set_kernel_handler(kernel): - code = partial_code.finalize() + code = partial_code.finalize_all() wrapper = V.graph.wrapper_code wrapper.header.writeline(code) self.outer_function.add(function_name) - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): + def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = wrapper.src_to_kernel[src_code] else: wrapper.src_to_kernel[src_code] = kernel_name - codecache_def = IndentedBuffer() codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") codecache_def.writeline(f"vectorlane_size={vector_lane},") codecache_def.writeline(f"loop_size={loop_size},") codecache_def.writeline(f"spad_info={spad_info},") codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes,") + codecache_def.writeline(f"arg_attributes={meta_code},") codecache_def.writeline(f"vlen={extension_config.vpu_vector_length_bits})") - 
wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), gpu=False) return kernel_name - def codegen_template(self, template_node, epilogue_nodes): - # Handle prologue pattern - prologue_nodes = [] - if not template_node.is_template(): - epilogue_nodes = [template_node] + epilogue_nodes - for i, node in enumerate(epilogue_nodes): - if node.is_template(): - template_node = node - prologue_nodes = epilogue_nodes[:i] - epilogue_nodes = epilogue_nodes[i+1:] - break - + def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): # Generate template code template_buffer = template_node.node kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) self.define_function(kernel) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a36bc907..304d0090 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -13,8 +13,8 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import KernelTemplate, ChoiceCaller, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer +from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -32,6 +32,9 @@ from PyTorchSimFrontend import extension_config from . 
import mlir_common +# Configure logger for mlir_template module +logger = extension_config.setup_logger() + class IndentedBufferGroup: def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): self.kernel = kernel @@ -386,7 +389,6 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio return tile_candidates def meta_kernel(self): - wrapper = V.graph.wrapper_code kernel_arg_attributes = self.kernel_arg_attributes _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() if kernel_arg_attributes is not None: @@ -394,18 +396,14 @@ def meta_kernel(self): for idx in range(len(arg_attributes)): if arg_attributes[idx][0] == name: arg_attributes[idx][1] = attr - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + return arg_attributes def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", - call_args, cuda=False) + kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -479,7 +477,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ src_code = ( partial_code if isinstance(partial_code, str) - else partial_code.finalize() + else partial_code.finalize_all() ) # For consistency, white space could make wrong write_path @@ -487,38 +485,36 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ buffer.splice(src_code) src_code = buffer.getvalue() self._prepare_simulator_headers(src_code) - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): choices = [] for tile_info in tile_candidates: - if extension_config.CONFIG_DEBUG_MODE: - # Compute Tile M, N, K DMA Tile M, N, K - print(f"[Auto-tune] Trying tile size: {list(tile_info)}") - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + # Compute Tile M, N, K DMA Tile M, N, K + logger.debug(f"Auto-tune: Trying tile size: {list(tile_info)}") + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) bench_runner = self.run_bench([template_node], self.kernel_name, src_code) - choices.append((bench_runner, src_code, tile_info, self.loop_size)) + choices.append((bench_runner, src_code, meta_code, tile_info, self.loop_size)) self.reset(reason=None) return choices def _log_autotune_result(self, best_choice, best_cycle): - tile_size = best_choice[2] - print( - f"[Auto-tune] Optimal tile size: {list(tile_size)}, " + tile_size = best_choice[3] + logger.debug( + f"Auto-tune: Optimal tile size: {list(tile_size)}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): if "autotune" in extension_config.codegen_mapping_strategy and len(tile_candidates): - src_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code, loop_size = 
self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) self.loop_size = loop_size else: tile_info = tile_candidates[0] if tile_candidates else None - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) - with V.set_kernel_handler(self): - self.meta_kernel() - return src_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" @@ -753,7 +749,7 @@ def hook(): return "" def def_function(self): - _, call_args, _ = self.kernel_group.args.python_argdefs() + _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) return PartialRender( @@ -929,7 +925,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): _, operand_type = self.var_info[value] if mlir_dtype != operand_type: - value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) + value = ops.to_dtype(value, mlir_dtype) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) # Generate vector load instruction buffer_name = name if not store_force else None @@ -1153,7 +1149,7 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): """ super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout @@ -1218,7 +1214,10 @@ def make_kernel_render( self.output_node.get_layout(), make_kernel_render, bmreq, + False, # supports_epilogue_fusion self, + kwargs, + "" # Currently Empty description ) def get_tile_candidates(self, **kwargs): diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index 8aa849b1..3f5673a8 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -1,5 +1,6 @@ from typing import List import os +import sys import numpy as np import torch from pathlib import Path @@ -7,6 +8,13 @@ from PyTorchSimFrontend.extension_codecache import hash_prefix from Simulator.simulator import TOGSimulator from PyTorchSimFrontend import extension_config +from PyTorchSimDevice.extension_device_interface import ExtensionDeviceInterface + +from torch._dynamo.device_interface import register_interface_for_device + +# Configure logger for Scheduler module +logger = extension_config.setup_logger() + def import_module_from_path(module_name, path): module_path = Path(path) # Convert to Path object for safety @@ -168,14 +176,16 @@ def setup_device(cls): return cls.NPU_MODULE source_file_path = os.path.dirname(os.path.abspath(__file__)) source_file = os.path.join( - source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimFrontend/extension_device.cpp" + source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_device.cpp" ) + hook_file = os.path.join(source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_hooks.cpp") import torch.utils.cpp_extension module = torch.utils.cpp_extension.load( name="npu", sources=[ str(source_file), + str(hook_file), ], extra_cflags=["-g"], verbose=True, @@ -194,17 +204,21 @@ def setup_device(cls): from 
PyTorchSimFrontend.mlir.mlir_scheduling import ( MLIRScheduling ) + register_backend_for_device( - "npu", MLIRScheduling, ExtensionWrapperCodegen - ) - assert( - get_scheduling_for_device("npu") == MLIRScheduling + "npu", + lambda scheduling: MLIRScheduling(scheduling), + ExtensionWrapperCodegen ) + import PyTorchSimDevice.extension_device_op_overrides + assert( get_wrapper_codegen_for_device("npu") == ExtensionWrapperCodegen ) cls.NPU_MODULE = module + sys.modules['torch.npu'] = module + register_interface_for_device(module.custom_device(), ExtensionDeviceInterface) return module def submit(self, batched_req, partition_idx) -> List[RequestReturn]: @@ -369,7 +383,7 @@ def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, elif engine_select == Scheduler.RR_ENGINE: self.execution_engine = RoundRobinRunner(self.tog_simulator, self.num_request_queue) else: - print(f"Not supporetd engine type {engine_select}") + logger.error(f"Not supported engine type {engine_select}") exit(1) def add_request(self, request: Request, request_time=-1): @@ -430,9 +444,11 @@ def finish_request(self, req : Request): self.finish_queue.append(req) self.request_queue[req.request_queue_idx].remove(req) turnaround_time, response_time, tbt_time = req.get_latency() - print(f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: " - f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, " - f"response time: {response_time} tbt_time: {tbt_time}") + logger.info( + f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: " + f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, " + f"response time: {response_time} tbt_time: {tbt_time}" + ) def per_schedule(self, request_queue_idx): # Wait partition is idle @@ -443,11 +459,13 @@ def per_schedule(self, request_queue_idx): if not request_list: return False - print(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}", flush=True) + logger.info(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}") for req in request_list: req.set_start(self.current_time()) - print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} " - f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}", flush=True) + logger.info( + f"[Request-{req.id} issue] partition: {req.request_queue_idx} " + f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}" + ) # Submit batched request self.execution_engine.submit(request_list, request_queue_idx) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 672ae6ec..7a4f7e0d 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -17,7 +17,46 @@ from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend import extension_config -print_lock = threading.Lock() +# Configure logger for Simulator module +logger = extension_config.setup_logger() +from tqdm import tqdm + + +class ProgressBar: + def __init__(self, desc, silent_mode=False, update_interval=0.5): + self.desc = desc + self.silent_mode = silent_mode + self.update_interval = update_interval + self.pbar = None + self.finished = False + self.progress_thread = None + + def __enter__(self): + if not self.silent_mode: + self.pbar = tqdm( + desc=self.desc, + bar_format='{desc}: {elapsed}', + leave=False, # Don't leave the bar when done (it will disappear) + ncols=80, + disable=False, + total=100, # Use a total for smooth animation + ) + # 
Update progress bar in a separate thread + def update_progress(): + while not self.finished: + self.pbar.update(1) + time.sleep(self.update_interval) + + self.progress_thread = threading.Thread(target=update_progress, daemon=True) + self.progress_thread.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.finished = True + if not self.silent_mode and self.pbar is not None: + self.pbar.close() + return False + TORCH_TO_NUMPY = { torch.float32: np.float32, @@ -105,9 +144,9 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[Spike] cmd> ", run) - print("[Spike] Running Spike simulator") + if not silent_mode: + logger.debug(f"[Spike] cmd> {run}") + logger.info("[Spike] Running Spike simulator") run_cmd = shlex.split(run) try: stdout_setting = subprocess.DEVNULL if silent_mode else None @@ -115,7 +154,7 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) except subprocess.CalledProcessError as e: if not silent_mode: - print("[Spike] Command failed with exit code", e.returncode) + logger.error(f"[Spike] Command failed with exit code {e.returncode}") error_msg = "" if e.returncode == 200: error_msg = "INVALID_SPAD_ACCESS" @@ -155,41 +194,23 @@ def __init__(self) -> None: pass def compile_and_simulate(self, target_binary, array_size, vectorlane_size, silent_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - with print_lock: - sys.stdout.write("\r[Gem5] Gem5 is running." 
+ tail) - sys.stdout.flush() - time.sleep(1) - with print_lock: - print("") - dir_path = os.path.join(os.path.dirname(target_binary), "m5out") gem5_script_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "gem5_script/script_systolic.py") gem5_cmd = [extension_config.CONFIG_GEM5_PATH, "-r", "--stdout-file=sto.log", "-d", dir_path, gem5_script_path, "-c", target_binary, "--vlane", str(vectorlane_size)] + + is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) or silent_mode + + if not is_dryrun: + logger.debug(f"[Gem5] cmd> {' '.join(gem5_cmd)}") + logger.info("[Gem5] Gem5 simulation started") + try: - # Create progress thread - is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) or silent_mode - if not is_dryrun: - if extension_config.CONFIG_DEBUG_MODE: - print("[Gem5] cmd> ", " ".join(gem5_cmd)) - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) - finished = True - progress_thread.join() - else: - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) + #with ProgressBar("[Gem5] Running simulation", silent_mode=is_dryrun): + output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) except subprocess.CalledProcessError as e: - print(f"[Gem5] Gem5 simulation failed with error: \"{e.output.decode()}\"") - if not is_dryrun: - finished = True - progress_thread.join() - raise RuntimeError(f"Gem5 Simulation Failed: \"{e.output.decode()}\"") + output_error = e.output.decode() if isinstance(e.output, bytes) else str(e.output) + logger.debug(f"[Gem5] Gem5 simulation failed with error: \"{output_error}\"") + raise RuntimeError(f"Gem5 Simulation Failed: \"{output_error}\"") with open(f"{dir_path}/stats.txt", "r") as stat_file: raw_list = stat_file.readlines() @@ -216,39 +237,21 @@ def get_togsim_command(self): return cmd def simulation(self, model_path, attribute_path="", silent_mode=False, autotune_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - sys.stdout.write("\r[TOGSim] TOGSim is running." 
+ tail) - time.sleep(1) - print("") cmd = f"{self.get_togsim_command()} --models_list {model_path}" if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" if attribute_path: cmd = f"{cmd} --attributes_list {attribute_path}" - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) - - # Create progress thread if not silent_mode: - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() + logger.debug(f"[TOGSim] cmd> {cmd}") + logger.info("[TOGSim] TOGSim simulation started") + try: - result = subprocess.check_output(shlex.split(cmd)) - if not silent_mode: - finished = True - progress_thread.join() + with ProgressBar("[TOGSim] Running simulation", silent_mode=silent_mode): + result = subprocess.check_output(shlex.split(cmd)) except subprocess.CalledProcessError as e: - if not silent_mode: - finished = True - progress_thread.join() - with print_lock: - print("[TOGSim] Command failed with exit code", e.returncode) - print("[TOGSim] Error output:", e.output) + logger.error(f"[TOGSim] Command failed with exit code {e.returncode}") + logger.error(f"[TOGSim] Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert 0 # Separate Autotune logs @@ -271,10 +274,10 @@ def show_progress(): f.flush() os.fsync(f.fileno()) - if not silent_mode or extension_config.CONFIG_DEBUG_MODE: - model_path_log = f' of "{model_path}" ' if extension_config.CONFIG_DEBUG_MODE else " " - with print_lock: - print(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') + if not silent_mode: + import logging as _logging + model_path_log = f' of "{model_path}" ' if logger.isEnabledFor(_logging.DEBUG) else " " + logger.info(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') return result_path def interactive_simulation(self): @@ -282,8 +285,7 @@ def interactive_simulation(self): if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" - if extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) + logger.debug(f"[TOGSim] cmd> {cmd}") if self.process is None: self.process = subprocess.Popen( shlex.split(cmd), @@ -292,28 +294,27 @@ def interactive_simulation(self): universal_newlines=True ) else: - print("[TOGSim] Simulator is already running.") + logger.warning("[TOGSim] Simulator is already running.") def stop(self): if self.process: self.process.terminate() self.process.wait() self.process = None - print("[TOGSim] Simulator stopped.") + logger.info("[TOGSim] Simulator stopped.") def wait(self): if self.process: - print("[TOGSim] Waiting for simulation to complete...") + logger.info("[TOGSim] Waiting for simulation to complete...") self.quit() self.process.wait() self.process = None - print("[TOGSim] Simulation completed.") + logger.info("[TOGSim] Simulation completed.") def send_command(self, command): if self.process: try: - if extension_config.CONFIG_TORCHSIM_DEBUG_MODE: - print(command, flush=True) + logger.debug(command) self.process.stdin.write(command + '\n') self.process.stdin.flush() ret = self.process.stderr.readline().strip() @@ -321,11 +322,11 @@ def send_command(self, command): except BrokenPipeError: err = self.process.stderr.readlines() for line in err: - print(line) + logger.error(line.strip()) self.process = None exit(1) else: - print("Simulator is not running.") + logger.warning("Simulator is not running.") return None def launch(self, 
onnx_path, attribute_path, arrival_time=0, partion_id=0): @@ -440,7 +441,7 @@ def get_result_from_file(result_path): break if simulation_finished_idx == -1: - print(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") + logger.warning(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time total_stat_lines = lines[simulation_finished_idx:] diff --git a/tests/Diffusion/test_diffusion.py b/tests/Diffusion/test_diffusion.py index c5170209..d6d740fe 100644 --- a/tests/Diffusion/test_diffusion.py +++ b/tests/Diffusion/test_diffusion.py @@ -557,14 +557,14 @@ def test_upsample2d( module = PyTorchSimRunner.setup_device() device = module.custom_device() - #test_upsample2d(device) - #test_groupnorm(device) - #test_groupnorm(device, stride=[1, 1, 320*32, 320]) - #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=320) - #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=1280) - #test_cross_attn_down_block2d(device) - #test_unet_mid_block2d_cross_attn(device) - #test_cross_attn_up_block2d(device) + test_upsample2d(device) + test_groupnorm(device) + test_groupnorm(device, stride=[1, 1, 320*32, 320]) + test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=320) + test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=1280) + test_cross_attn_down_block2d(device) + test_unet_mid_block2d_cross_attn(device) + test_cross_attn_up_block2d(device) test_unet2d_condition_model(device) #test_unet_conditional( # device=device, diff --git a/tests/Llama/test_llama.py b/tests/Llama/test_llama.py index 443f3fc2..889e5fa8 100644 --- a/tests/Llama/test_llama.py +++ b/tests/Llama/test_llama.py @@ -101,7 +101,8 @@ def run_rotary_embedding_test( vocab_size=8192, _attn_implementation = "sdpa" ) - base_rope = LlamaRotaryEmbedding(cfg) + # Pass dim explicitly to avoid config parsing issues + base_rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=cfg.max_position_embeddings, base=cfg.rope_theta, config=cfg) cpu_rope = copy.deepcopy(base_rope) @@ -375,14 +376,14 @@ def run_llama_model_test( torch.compiler.is_compiling = lambda: True # FIXME. How to fix this? #run_rmsnorm_test(device) #run_rotary_embedding_test(device) - #run_decoder_layer_test( - # device=device, - # batch=args.batch, - # seq_len=args.seq_len, - # dtype=args.dtype, - # rtol=args.rtol, - # atol=args.atol, - #) + run_decoder_layer_test( + device=device, + batch=args.batch, + seq_len=args.seq_len, + dtype=args.dtype, + rtol=args.rtol, + atol=args.atol, + ) run_llama_model_test(device) #run_custom_llama_test( # device=device, diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index ae16f0b0..9ebfb11e 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -4,7 +4,6 @@ import copy import matplotlib.pyplot as plt - import torch import torch.nn as nn from torch.distributions.normal import Normal @@ -17,6 +16,32 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +# FIXME. 
This is a temporary solution to avoid is_forward conflict during backward +def patch_compile_event_logger(): + """Patch CompileEventLogger.compilation_metric to avoid is_forward conflict during backward.""" + from torch._dynamo.utils import CompileEventLogger + from torch._dynamo.utils import get_metrics_context + + original_compilation_metric = CompileEventLogger.compilation_metric + + @staticmethod + def patched_compilation_metric(is_forward=True, **kwargs): + """Patched version that clears is_forward before setting it if there's a conflict.""" + try: + metrics_context = get_metrics_context() + if metrics_context.in_progress() and hasattr(metrics_context, '_metrics'): + # If is_forward is already set and we're trying to set it to a different value, clear it first + current_is_forward = metrics_context._metrics.get('is_forward') + if current_is_forward is not None and current_is_forward != is_forward: + metrics_context._metrics.pop('is_forward', None) + except: + pass + # Call the original function + return original_compilation_metric(is_forward=is_forward, **kwargs) + + # Patch the method + CompileEventLogger.compilation_metric = patched_compilation_metric + def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): pass_message = f"|{name} Test Passed|" fail_message = f"|{name} Test Failed|" @@ -64,6 +89,7 @@ class SparseDispatcher(object): `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. """ + @torch.compiler.disable(recursive=True) def __init__(self, num_experts, gates): """Create a SparseDispatcher.""" gates = gates.cpu() @@ -469,6 +495,9 @@ def test_moe(device): print("\n") def train_moe(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_compile_event_logger() + def perceptron(a, b, c): return a * b + c @@ -589,6 +618,9 @@ def weight_update(a, b, lr): plt.savefig('result.png') def train_moe_mnist(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_compile_event_logger() + torch.manual_seed(0) batch_size = 32 input_size = 28*28 @@ -670,6 +702,9 @@ def train(model, device, train_loader, optimizer, epochs): plt.savefig(f'{name}_result.png') def train_moe_single_iteration(device, iter_idx, is_evaluation=0): + # Patch CompileEventLogger to avoid metric conflicts + patch_compile_event_logger() + # Training moe with mnist dataset for sinlge iteration torch.manual_seed(0) batch_size = 128 diff --git a/tests/test_activation.py b/tests/test_activation.py index 575fc7e8..49a9467c 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -23,9 +23,10 @@ def test_ReLU(device, size=(128, 128)): input = torch.randn(size) x1 = input.to(device=device) x2 = input.to("cpu") - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu) + ReLU = torch.nn.ReLU() + opt_fn = torch.compile(dynamic=False)(ReLU) y = opt_fn(x1) - cpu_y = torch.nn.functional.relu(x2) + cpu_y = ReLU(x2) test_result("ReLU", y, cpu_y) def test_GeLU(device, size=(128, 128), approximate='none'): diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index e964319d..97e5cdea 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -44,15 +44,16 @@ def custom_conv2d(a, b, bias): module = PyTorchSimRunner.setup_device() device = module.custom_device() torch._dynamo.config.cache_size_limit = 64 - test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) - 
test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) + with torch.no_grad(): + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) diff --git a/tests/test_gqa.py b/tests/test_gqa.py new file mode 100644 index 00000000..c5f2f6f6 --- /dev/null +++ b/tests/test_gqa.py @@ -0,0 +1,335 @@ +import sys +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch._dynamo +import argparse + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + + +class GQAMultiheadAttention(nn.Module): + """ + Grouped Query Attention (GQA) implementation. 
+ Query has num_heads, but key/value have num_kv_heads (num_kv_heads < num_heads). + """ + def __init__(self, embed_dim, num_heads, num_kv_heads=None, head_dim=None, bias=True, dropout=0.0): + super().__init__() + assert embed_dim % num_heads == 0 + if head_dim is None: + head_dim = embed_dim // num_heads + assert embed_dim == num_heads * head_dim + + # If num_kv_heads is not specified, use num_heads (standard MHA) + if num_kv_heads is None: + num_kv_heads = num_heads + + assert num_kv_heads <= num_heads + assert embed_dim % num_kv_heads == 0 + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.dropout = dropout + + # QKV projection: Q has embed_dim, K and V have kv_embed_dim each + kv_embed_dim = num_kv_heads * head_dim + total_qkv_dim = embed_dim + 2 * kv_embed_dim + + self.qkv_proj = nn.Linear(embed_dim, total_qkv_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward(self, query, key=None, value=None, attn_mask=None, need_weights=False): + """ + Args: + query: [batch, seq_len, embed_dim] or [seq_len, batch, embed_dim] + key: optional, same shape as query + value: optional, same shape as query + attn_mask: optional attention mask + need_weights: whether to return attention weights + """ + # For compatibility with nn.MultiheadAttention API + if key is None: + key = query + if value is None: + value = query + + # Handle batch_first vs batch_second + if query.dim() == 3: + batch_first = True + batch_size, seq_len, _ = query.shape + else: + batch_first = False + seq_len, batch_size, _ = query.shape + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # Project QKV + # Use query for QKV projection (standard MHA/GQA pattern) + qkv = self.qkv_proj(query) # [batch, seq_len, total_qkv_dim] + + # Split into Q, K, V + kv_embed_dim = self.num_kv_heads * self.head_dim + q = qkv[:, :, :self.embed_dim] # [batch, seq_len, embed_dim] + k = qkv[:, :, self.embed_dim:self.embed_dim + kv_embed_dim] # [batch, seq_len, kv_embed_dim] + v = qkv[:, :, self.embed_dim + kv_embed_dim:] # [batch, seq_len, kv_embed_dim] + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) # [batch, seq_len, num_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + + # Transpose for attention: [batch, num_heads, seq_len, head_dim] + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + + # Scaled dot product attention with GQA support + # enable_gqa=True allows different number of heads for Q vs K/V + attn_output = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + enable_gqa=(self.num_kv_heads < self.num_heads) + ) # [batch, num_heads, seq_len, head_dim] + + # Reshape back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, self.embed_dim) + + # Output projection + output = self.out_proj(attn_output) # [batch, seq_len, embed_dim] + + if not batch_first: + 
output = output.transpose(0, 1) # [seq_len, batch, embed_dim] + + if need_weights: + # Compute attention weights for return + # This is simplified - in practice you'd want the actual attention weights + attn_weights = None + return output, attn_weights + else: + return output + + +def test_gqa_attention(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test Grouped Query Attention (GQA) where num_kv_heads < num_heads. + + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be <= num_heads) + """ + print(f"Testing GQA Attention (batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + # Create GQA model + gqa = GQAMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + bias=True, + dropout=0.0 + ).eval() + + # Initialize weights + torch.nn.init.normal_(gqa.qkv_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.qkv_proj.bias, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.bias, mean=0.0, std=0.02) + + # Create input + x = torch.randn(batch, seq_len, embed_dim) + query = x.clone() + key = x.clone() + value = x.clone() + + # Run on custom device + gqa_device = gqa.to(device) + q1, k1, v1 = query.to(device), key.to(device), value.to(device) + + compiled_gqa = torch.compile(gqa_device, dynamic=False) + with torch.no_grad(): + out_device = compiled_gqa(q1, k1, v1) + + # Run on CPU + gqa_cpu = gqa.cpu() + q2, k2, v2 = query.cpu(), key.cpu(), value.cpu() + with torch.no_grad(): + out_cpu = gqa_cpu(q2, k2, v2) + + test_result("GQA Attention", out_device, out_cpu) + print("Max diff > ", torch.max(torch.abs(out_device.cpu() - out_cpu))) + print("GQA Attention Simulation Done") + + +def test_standard_mha_via_gqa(device, batch=1, seq_len=32, embed_dim=768, num_heads=12): + """ + Test standard Multi-Head Attention using GQA with num_kv_heads == num_heads. + This should behave the same as standard MHA. + """ + print(f"Testing Standard MHA via GQA (batch={batch}, seq_len={seq_len}, " + f"embed_dim={embed_dim}, num_heads={num_heads})") + + test_gqa_attention(device, batch, seq_len, embed_dim, num_heads, num_kv_heads=num_heads) + + +def test_repeat_interleave_compilation(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test that repeat_interleave operation compiles and works correctly using scaled_dot_product_attention implementation. + + This test uses the exact implementation from F.scaled_dot_product_attention to verify + that repeat_interleave works correctly when enable_gqa=True. 
+ + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be < num_heads) + """ + import math + + print(f"Testing repeat_interleave compilation using scaled_dot_product_attention implementation " + f"(batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + head_dim = embed_dim // num_heads + assert num_kv_heads < num_heads, "num_kv_heads must be less than num_heads for GQA" + + # Create Q, K, V tensors + # Q: [batch, num_heads, seq_len, head_dim] + # K, V: [batch, num_kv_heads, seq_len, head_dim] + q = torch.randn(batch, num_heads, seq_len, head_dim) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim) + + # Move to device + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + # Implementation from F.scaled_dot_product_attention + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight, value, attn_weight @ value + + # Compile the function + compiled_attn = torch.compile(scaled_dot_product_attention, dynamic=False) + + # Run on custom device with enable_gqa=True + with torch.no_grad(): + output_device = compiled_attn(q_device, k_device, v_device, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Run on CPU for comparison + q_cpu = q.cpu() + k_cpu = k.cpu() + v_cpu = v.cpu() + with torch.no_grad(): + output_cpu = scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Compare results + test_result("repeat_interleave in scaled_dot_product_attention", output_device[0], output_cpu[0]) + print("Max diff > ", torch.max(torch.abs(output_device[0].cpu() - output_cpu[0]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[1], output_cpu[1]) + print("Max diff > ", torch.max(torch.abs(output_device[1].cpu() - output_cpu[1]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[2], output_cpu[2]) + print("Max diff > ", torch.max(torch.abs(output_device[2].cpu() - output_cpu[2]))) + print("repeat_interleave compilation test Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="npu", help="Device to use") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--seq_len", 
type=int, default=32, help="Sequence length") + parser.add_argument("--embed_dim", type=int, default=768, help="Embedding dimension") + parser.add_argument("--num_heads", type=int, default=8, help="Number of query heads") + parser.add_argument("--num_kv_heads", type=int, default=4, help="Number of key/value heads") + parser.add_argument("--test_standard", action="store_true", help="Also test standard MHA via GQA") + parser.add_argument("--test_repeat_interleave", action="store_true", help="Test repeat_interleave compilation") + + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + test_repeat_interleave_compilation( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Test GQA + test_gqa_attention( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Optionally test standard MHA via GQA + # if args.test_standard: + # test_standard_mha_via_gqa( + # device=args.device, + # batch=args.batch, + # seq_len=args.seq_len, + # embed_dim=args.embed_dim, + # num_heads=args.num_heads + # ) diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 28e38d37..a2e842d0 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -44,5 +44,6 @@ def test_LayerNorm(device, size=(64, 64)): from Scheduler.scheduler import PyTorchSimRunner module = PyTorchSimRunner.setup_device() device = module.custom_device() - #test_LayerNorm(device) - test_LayerNorm(device, shape) + with torch.no_grad(): + #test_LayerNorm(device) + test_LayerNorm(device, shape) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index e6e8cc1e..005c3ed2 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -42,8 +42,17 @@ def test_softmax(device, size=(128, 128), dim=1): #cpu_y = softmax3(x2, cpu_max, cpu_sum) #test_result("Softmax", y, cpu_y) - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax) - y = opt_fn(x1, dim=dim) + class SoftmaxModule(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.softmax(x, dim=self.dim) + + softmax_module = SoftmaxModule(dim=dim).to(device) + opt_fn = torch.compile(dynamic=False)(softmax_module) + y = opt_fn(x1) cpu_y = torch.nn.functional.softmax(x2, dim=dim) test_result("Softmax", y, cpu_y)
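
For reference, a minimal standalone sketch of the grouped-query attention expansion that the new tests/test_gqa.py exercises: key/value heads are replicated with repeat_interleave so that each of the num_heads query heads has a matching head, which is the effect of scaled_dot_product_attention(..., enable_gqa=True). The tensor shapes below are illustrative assumptions, not values from the patch, and enable_gqa assumes a reasonably recent PyTorch (2.5 or later, satisfied by the 2.8 base image this PR targets).

import torch
import torch.nn.functional as F

# Illustrative shapes only; not taken from the patch.
batch, seq_len, head_dim = 1, 4, 8
num_heads, num_kv_heads = 8, 4

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Manual GQA: replicate each K/V head so every query head gets a partner head.
group = num_heads // num_kv_heads
k_exp = k.repeat_interleave(group, dim=-3)
v_exp = v.repeat_interleave(group, dim=-3)
manual = torch.softmax(q @ k_exp.transpose(-2, -1) / head_dim ** 0.5, dim=-1) @ v_exp

# Fused path that the test compiles on the custom device.
fused = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True

The same expansion is what test_repeat_interleave_compilation checks end to end by comparing the compiled device run against the CPU reference implementation.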