From ef0a157c4a1e23b92d49586bc732b059958952b7 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Mon, 19 Jan 2026 06:20:00 +0000 Subject: [PATCH 1/2] [TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA targets --- src/tir/transforms/inject_ptx_ldg32.cc | 45 ++++++++--- .../test_tir_transform_inject_ptx_ldg32.py | 80 +++++++++++++++++++ 2 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index 8cdef1be44a5..39035739c318 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -35,16 +35,23 @@ namespace tir { class PTXRewriter : public StmtMutator { public: - Stmt VisitStmt_(const AllocateNode* allocate) final { - if (!has_buffer_1) { - has_buffer_1 = true; - // addr[0] -> global_addr / addr[1] -> local_addr - addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local"); - predicate_buffer = - decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local"); + Stmt AddAllocationsIfNeeded(Stmt body) { + if (!needs_buffer || has_buffer_2) { + return body; } + EnsureBuffers(); + body = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), body); + body = + Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape, + Bool(true), body); + has_buffer_2 = true; + return body; + } + + Stmt VisitStmt_(const AllocateNode* allocate) final { Stmt result = StmtMutator::VisitStmt_(allocate); - if (!has_buffer_2) { + if (needs_buffer && !has_buffer_2) { + EnsureBuffers(); has_buffer_2 = true; result = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), result); @@ -82,6 +89,8 @@ class PTXRewriter : public StmtMutator { if (ramp != nullptr) { return result; } + EnsureBuffers(); + needs_buffer = true; local_addr = store->indices[0]; 
BufferStore addr_store(addr_buffer, global_addr, {IntImm(DataType::Int(32), 0)}); BufferStore local_addr_store(addr_buffer, local_addr, {IntImm(DataType::Int(32), 1)}); @@ -104,7 +113,19 @@ class PTXRewriter : public StmtMutator { return result; } + void EnsureBuffers() { + if (has_buffer_1) { + return; + } + has_buffer_1 = true; + // addr[0] -> global_addr / addr[1] -> local_addr + addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local"); + predicate_buffer = + decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local"); + } + bool has_buffer_1 = false, has_buffer_2 = false; + bool needs_buffer = false; Buffer addr_buffer, predicate_buffer; }; @@ -113,8 +134,14 @@ namespace transform { Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) { auto pass_func = [enable_inject_ptx_intrin](PrimFunc f, IRModule m, PassContext ctx) { if (enable_inject_ptx_intrin) { + auto target = f->GetAttr<Target>("target"); + if (!target.defined() || target.value()->kind->name != "cuda") { + return f; + } auto* n = f.CopyOnWrite(); - n->body = PTXRewriter()(n->body); + PTXRewriter rewriter; + Stmt body = rewriter(n->body); + n->body = rewriter.AddAllocationsIfNeeded(body); // inject ptx } return f; diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py new file mode 100644 index 000000000000..55099f252cb7 --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T + + +def _count_alloc(stmt): + num_alloc = [0] + + def visit(n): + if isinstance(n, tvm.tir.Allocate): + num_alloc[0] += 1 + + tvm.tir.stmt_functor.post_order_visit(stmt, visit) + return num_alloc[0] + + +def _count_ptx_ldg32(stmt): + num_call = [0] + + def visit(n): + if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_ldg32": + num_call[0] += 1 + + tvm.tir.stmt_functor.post_order_visit(stmt, visit) + return num_call[0] + + +@T.prim_func +def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("cuda")}) + for i in range(4): + C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0)) + + +@T.prim_func +def where_no_alloc_cpu(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")}) + for i in range(4): + C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0)) + + +def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func(): + mod = tvm.IRModule.from_expr(where_no_alloc) + assert _count_alloc(mod["main"].body) == 0 + + mod = tvm.tir.transform.InjectPTXLDG32()(mod) + assert _count_alloc(mod["main"].body) > 0 + assert _count_ptx_ldg32(mod["main"].body) == 1 + + +def test_inject_ptx_ldg32_skip_non_cuda_target(): + mod = tvm.IRModule.from_expr(where_no_alloc_cpu) + cpu_target = tvm.target.Target("llvm") + mod = tvm.IRModule({"main": mod["main"].with_attr("target", 
cpu_target)}) + assert _count_alloc(mod["main"].body) == 0 + + mod = tvm.tir.transform.InjectPTXLDG32()(mod) + assert _count_alloc(mod["main"].body) == 0 + assert _count_ptx_ldg32(mod["main"].body) == 0 + + +if __name__ == "__main__": + tvm.testing.main() From 9ccca983f3fc0d0b4459abf0c982bcde444ff693 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Mon, 19 Jan 2026 06:44:12 +0000 Subject: [PATCH 2/2] fix lint --- src/tir/transforms/inject_ptx_ldg32.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index 39035739c318..f52539fa77b3 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -41,9 +41,8 @@ class PTXRewriter : public StmtMutator { } EnsureBuffers(); body = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), body); - body = - Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape, - Bool(true), body); + body = Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape, + Bool(true), body); has_buffer_2 = true; return body; }