30 commits
0abfffe
PyTorch version upgrade: tested on single-operator tests
wok1909 Sep 24, 2025
b7a275e
[Test] Add torch.no_grad(), change to use torch.nn.ReLU, fusion off
wok1909 Sep 24, 2025
5c5e61c
[Implement] Hook and GuardImpl for extension device
wok1909 Nov 6, 2025
74704b8
[CI] Change the trigger condition
YWHyuk Jan 6, 2026
d3f3298
[CI] Use CMake 3 to build pytorchsim
YWHyuk Jan 6, 2026
0763363
[CI] Separate base image
YWHyuk Jan 6, 2026
4591403
[Fix] PyTorch2.8 support (WIP)
YWHyuk Jan 7, 2026
b9d4144
[Fix] Use official prologue fusion path
YWHyuk Jan 7, 2026
9abc060
[Fix] Don't split a reduce kernel
YWHyuk Jan 7, 2026
2c7264b
[Fix] Add a missing reduction fusion condition
YWHyuk Jan 7, 2026
b951b95
[Fix] update indirect_index interface for v2.8
YWHyuk Jan 7, 2026
c6ba98c
[Fix] Allow cpp kernel code in the wrapper function
YWHyuk Jan 7, 2026
fd07eda
[Ops] Use V.kernel instead of argument passing
YWHyuk Jan 8, 2026
4bed31b
[Fix] Set epilogue fusion condition
YWHyuk Jan 8, 2026
758b5b3
[Fix] Support Identity indexing + Fix wrapper codegen
YWHyuk Jan 8, 2026
a7ab604
[Fix] Keep contextvar after reset()
YWHyuk Jan 8, 2026
cd52f57
[Frontend] Add decomposition of default attention
YWHyuk Jan 8, 2026
08e0c8b
[Fix] Add missing case
YWHyuk Jan 8, 2026
1d1508a
[Test] Add GQA test file
YWHyuk Jan 8, 2026
862ba44
[Fix+Log] Change logging system + Fix meta_code interface
YWHyuk Jan 9, 2026
75207a4
[Test] Wrap softmax module
YWHyuk Jan 9, 2026
8df5fef
[Log] Add progress bar for auto-tuning
YWHyuk Jan 9, 2026
d7c16b1
[Test/MoE] Disable compiling sparse dispatcher
YWHyuk Jan 9, 2026
c88cabc
[Fix] Support identity in the dram_stride extraction
YWHyuk Jan 12, 2026
67612bb
[Fix] index to float casting
YWHyuk Jan 12, 2026
50ceb58
[Fix] Change vlane_split_axis in case of group-dim
YWHyuk Jan 12, 2026
319fd6c
[Frontend] Fix any operation codegen
YWHyuk Jan 13, 2026
c223258
[Decompose] Use F.softmax for decomposed SDPA
YWHyuk Jan 13, 2026
07be94b
[Frontend] Add recompilation for ModularIndexing
YWHyuk Jan 13, 2026
e999bfc
[Test] Fix minor bugs in the test folder
YWHyuk Jan 13, 2026
2 changes: 1 addition & 1 deletion .github/workflows/docker-base-image-2-8.yml
@@ -2,7 +2,7 @@ name: Docker Base Image CI (PyTorch 2.8)
 
 on:
   push:
-    branches: [ "base" ]
+    branches: [ "base_v2.8" ]
   workflow_dispatch:
   repository_dispatch:
     types: [ build_base ]
2 changes: 1 addition & 1 deletion .github/workflows/docker-image-2-8.yml
@@ -1,7 +1,7 @@
 name: Docker image CI (PyTorch 2.8)
 
 on:
-  pull_request:
+  push:
     branches: [ "torch_v2.8" ]
   workflow_dispatch:
 
2 changes: 1 addition & 1 deletion Dockerfile.base
@@ -34,7 +34,7 @@ RUN apt -y update && \
     python3-dev python-is-python3 libboost-all-dev \
     libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \
     python3-venv black libssl-dev libasan5 libubsan1 curl device-tree-compiler wget ninja-build && \
-    pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 && rm -rf /var/lib/apt/lists/*
+    pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 cmake==3.26.4 && rm -rf /var/lib/apt/lists/*
 
 # Download RISC-V tool chain
 RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2023.12.14/riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.12.14-nightly.tar.gz && \
8 changes: 8 additions & 0 deletions PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp
@@ -0,0 +1,8 @@
#include "ExtensionDeviceGuardImpl.h"
#include <c10/core/impl/DeviceGuardImplRegistry.h>

namespace c10::extension_device::impl {

C10_REGISTER_GUARD_IMPL(extension_device, ExtensionDeviceGuardImpl);

} // namespace c10::extension_device::impl
127 changes: 127 additions & 0 deletions PyTorchSimDevice/ExtensionDeviceGuardImpl.h
@@ -0,0 +1,127 @@
#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/Stream.h>
#include <c10/core/Event.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Optional.h>

namespace c10::extension_device::impl {

struct ExtensionDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::PrivateUse1; // our backend type

  ExtensionDeviceGuardImpl() = default;

  explicit ExtensionDeviceGuardImpl(DeviceType t) {
    TORCH_CHECK(
        t == static_type,
        "ExtensionDeviceGuardImpl initialized with non-extension_device DeviceType: ",
        t);
  }

  // --------------------------------------------------------------------------
  // Basic device guard (behaves like the CPU)
  // --------------------------------------------------------------------------
  DeviceType type() const override {
    return static_type;
  }

  Device exchangeDevice(Device d) const override {
    TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
    return d; // nothing to exchange, CPU-like
  }

  Device getDevice() const override {
    return Device(static_type, 0);
  }

  void setDevice(Device d) const override {
    TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
  }

  void uncheckedSetDevice(Device d) const noexcept override {}

  DeviceIndex deviceCount() const noexcept override {
    return 1; // pretend single device
  }

  // --------------------------------------------------------------------------
  // Stream handling (the backend is synchronous, so only the default stream is used)
  // --------------------------------------------------------------------------
  Stream getStream(Device d) const override {
    return Stream(Stream::DEFAULT, d);
  }

  Stream getNewStream(Device d, int priority = 0) const override {
    return Stream(Stream::DEFAULT, d);
  }

  Stream getStreamFromGlobalPool(Device d, bool = false) const override {
    return Stream(Stream::DEFAULT, d);
  }

  Stream exchangeStream(Stream s) const override {
    return s;
  }

  bool queryStream(const Stream& stream) const override {
    (void)stream;
    return true;
  }

  void synchronizeStream(const Stream& stream) const override {
    (void)stream;
  }

  void synchronizeDevice(DeviceIndex device_index) const override {
    (void)device_index;
  }

  // --------------------------------------------------------------------------
  // Event handling (all no-ops)
  // --------------------------------------------------------------------------
  void destroyEvent(void* event, const DeviceIndex device_index) const noexcept override {
    (void)event;
    (void)device_index;
  }

  void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override {
    (void)event;
    (void)stream;
    (void)device_index;
    (void)flag;
  }

  void block(void* event, const Stream& stream) const override {
    (void)event;
    (void)stream;
  }

  bool queryEvent(void* event) const override {
    (void)event;
    return true;
  }

  void synchronizeEvent(void* event) const override {
    (void)event;
  }

  double elapsedTime(void* start_event, void* end_event, const DeviceIndex device_index) const override {
    (void)start_event;
    (void)end_event;
    (void)device_index;
    return 0.0;
  }

  // --------------------------------------------------------------------------
  // Misc (allocator integration)
  // --------------------------------------------------------------------------
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override {
    (void)data_ptr;
    (void)stream;
  }
};

} // namespace c10::extension_device::impl
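
For context, a minimal sketch of how this guard becomes reachable from Python once the extension is built and loaded. The shared-library path and the backend rename are illustrative assumptions; neither appears in this diff:

import torch

# Hypothetical library name/path; the actual build target is not shown in this PR.
torch.ops.load_library("build/libpytorchsim_device.so")

# Give PrivateUse1 a readable name so device="extension_device" works.
torch.utils.rename_privateuse1_backend("extension_device")

# Tensor creation on the device routes through the registered
# ExtensionDeviceGuardImpl (getDevice/getStream above) and the
# PrivateUse1 allocator registered in the open-registration file below.
x = torch.empty(4, device="extension_device")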
@@ -55,16 +55,12 @@ static inline at::MemoryFormat fix_memory_format(c10::optional<at::MemoryFormat>
   return mf;
 }
 
+#include "ExtensionDeviceGuardImpl.h"
+
 static uint64_t op_counter = 0;
 static uint64_t last_saved_value = 0;
 
 // register guard
-namespace at {
-namespace detail {
-
-C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
-
-}} // namespace at::detail
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::extension_device::impl::ExtensionDeviceGuardImpl);
 
 // basic dummy add function
 at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
@@ -159,7 +155,7 @@ at::Tensor custom_to_device(
 // A dummy allocator for our custom device, that secretly uses the CPU
 struct DummyCustomAllocator final : at::Allocator {
   DummyCustomAllocator() = default;
-  at::DataPtr allocate(size_t nbytes) const override {
+  at::DataPtr allocate(size_t nbytes) override {
     void* data = c10::alloc_cpu(nbytes);
     return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
   }
@@ -174,6 +170,10 @@ struct DummyCustomAllocator final : at::Allocator {
   at::DeleterFnPtr raw_deleter() const override {
     return &ReportAndDelete;
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const override {
+    std::memcpy(dest, src, count);
+  }
 };
 
 // Register our dummy allocator
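
The allocate() signature change and the new copy_data() override track the at::Allocator interface in PyTorch 2.8. A quick smoke test of the allocator from Python might look like this (assuming the extension is loaded and the backend renamed as in the earlier sketch):

import torch

# Assumes the extension module is loaded and PrivateUse1 has been renamed
# to "extension_device" (see the earlier sketch).
cpu_t = torch.arange(8, dtype=torch.float32)
dev_t = cpu_t.to("extension_device")  # allocate() on DummyCustomAllocator
back = dev_t.to("cpu")                # data round-trips through CPU memory
assert torch.equal(cpu_t, back)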
63 changes: 63 additions & 0 deletions PyTorchSimDevice/extension_device_interface.py
@@ -0,0 +1,63 @@
import torch
from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties

class _ExtensionDeviceProperties:  # FIXME: Dummy property values
    name: str = "Extension_device"
    platform_name: str
    vendor: str
    driver_version: str
    version: str
    max_compute_units: int
    gpu_eu_count: int
    max_work_group_size: int
    max_num_sub_groups: int
    sub_group_sizes: list[int]
    has_fp16: bool
    has_fp64: bool
    has_atomic64: bool
    has_bfloat16_conversions: bool
    has_subgroup_matrix_multiply_accumulate: bool
    has_subgroup_matrix_multiply_accumulate_tensor_float32: bool
    has_subgroup_2d_block_io: bool
    total_memory: int
    multi_processor_count: int = 128  # gpu_subslice_count, num_sm
    architecture: int
    type: str

_ExtensionDeviceProperties = _ExtensionDeviceProperties

class ExtensionDeviceInterface(DeviceInterface):
    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["extension_device"] = device

        @staticmethod
        def current_device() -> int:
            if "extension_device" in caching_worker_current_devices:
                return caching_worker_current_devices["extension_device"]
            return torch.xpu.current_device()

        @staticmethod
        def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties:
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "extension_device"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = ExtensionDeviceInterface.Worker.current_device()

            if "extension_device" not in caching_worker_device_properties:
                device_prop = [
                    torch.cuda.get_device_properties(i)
                    for i in range(torch.cuda.device_count())
                ]
                caching_worker_device_properties["extension_device"] = device_prop

            return _ExtensionDeviceProperties

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None):
        return 36
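
For Dynamo/Inductor to consult this interface, it has to be registered for the device. A sketch of that wiring, assuming the import path below (the registration call itself is not part of this diff):

from torch._dynamo.device_interface import register_interface_for_device

# Hypothetical import path; adjust to wherever the module actually lives.
from PyTorchSimDevice.extension_device_interface import ExtensionDeviceInterface

register_interface_for_device("extension_device", ExtensionDeviceInterface)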
27 changes: 27 additions & 0 deletions PyTorchSimDevice/extension_device_op_overrides.py
@@ -0,0 +1,27 @@
from __future__ import annotations

from textwrap import dedent

from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides
from torch._inductor.codegen.cpu_device_op_overrides import CpuDeviceOpOverrides

class ExtensionDeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name: str) -> str:
        return dedent(
            """
            def get_raw_stream(_):
                return 0
            """
        )

    def set_device(self, device_idx: int) -> str:
        return "pass"

    def synchronize(self) -> str:
        return "pass"

    def device_guard(self, device_idx: int) -> str:
        return "pass"

register_device_op_overrides("npu", ExtensionDeviceOpOverrides())
register_device_op_overrides("cpu", CpuDeviceOpOverrides())
48 changes: 48 additions & 0 deletions PyTorchSimDevice/extension_hooks.cpp
@@ -0,0 +1,48 @@
#include "extension_hooks.h"

bool ExtensionPU1Hooks::isBuilt() const { return true; }
bool ExtensionPU1Hooks::isAvailable() const { return true; }

const at::Generator& ExtensionPU1Hooks::getDefaultGenerator(c10::DeviceIndex idx) const {
if (idx < 0) idx = 0;
static std::vector<at::Generator> gens;
static std::mutex m;
std::lock_guard<std::mutex> g(m);
if (gens.size() <= (size_t)idx) gens.resize((size_t)idx + 1);
if (!gens[idx].defined()) gens[idx] = at::GetGeneratorForPrivateuse1(idx);
return gens[idx]; // 영속 객체 참조 반환
}

at::Generator ExtensionPU1Hooks::getNewGenerator(c10::DeviceIndex idx) const {
if (idx < 0) idx = 0;
return at::GetGeneratorForPrivateuse1(idx);
}

at::Device ExtensionPU1Hooks::getDeviceFromPtr(void* data) const {
return at::Device(at::kPrivateUse1, 0); // MVP: 단일 디바이스 가정
}

bool ExtensionPU1Hooks::isPinnedPtr(const void* data) const {
return false;
}

at::Allocator* ExtensionPU1Hooks::getPinnedMemoryAllocator() const {
return at::getHostAllocator(at::kPrivateUse1);
}

bool ExtensionPU1Hooks::hasPrimaryContext(c10::DeviceIndex device_index) const { return true; }

void ExtensionPU1Hooks::resizePrivateUse1Bytes(const c10::Storage&, size_t) const {
TORCH_CHECK(false, "resizePrivateUse1Bytes not implemented");
}

// REGISTER_EXTENSION_HOOKS(ExtensionPU1Hooks);

namespace {
struct AutoRegistrar {
AutoRegistrar() {
at::RegisterPrivateUse1HooksInterface(new ExtensionPU1Hooks());
}
};
static AutoRegistrar _auto_registrar;
}
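
On the Python side, these hooks back generator creation for the device. A rough usage sketch, assuming the backend is renamed as earlier and a PrivateUse1 generator factory is registered on the C++ side (e.g. via REGISTER_GENERATOR_PRIVATEUSE1, which this diff does not show):

import torch

# Assumes PrivateUse1 is renamed to "extension_device" and a PrivateUse1
# generator was registered in C++ (not shown in this PR).
gen = torch.Generator(device="extension_device")  # -> getNewGenerator(...)
gen.manual_seed(1234)
x = torch.rand(4, device="extension_device", generator=gen)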
30 changes: 30 additions & 0 deletions PyTorchSimDevice/extension_hooks.h
@@ -0,0 +1,30 @@
#pragma once

#include <ATen/core/CachingHostAllocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

#include <ATen/core/Generator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
#include <c10/util/Exception.h>

struct ExtensionPU1Hooks final : public at::PrivateUse1HooksInterface {
  ExtensionPU1Hooks() {}
  bool isBuilt() const;
  bool isAvailable() const;

  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override;

  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;

  at::Device getDeviceFromPtr(void* data) const override;

  bool isPinnedPtr(const void* data) const override;

  at::Allocator* getPinnedMemoryAllocator() const override;

  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;

  void resizePrivateUse1Bytes(const c10::Storage& /*storage*/, size_t /*newsize*/) const override;
};