Skip to content

Commit ad8d9e1

Browse files
authored
[mlir][gpu] Use arith dialect to lower gpu.global_id (#171614)
This PR lowers the`gpu.global_id` op using the arith dialect instead of the index dialect. Fixes #171303.
1 parent 9f5c963 commit ad8d9e1

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/Arith/IR/Arith.h"
1415
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1516
#include "mlir/Dialect/GPU/Transforms/Passes.h"
16-
#include "mlir/Dialect/Index/IR/IndexOps.h"
1717
#include "mlir/IR/PatternMatch.h"
1818

1919
using namespace mlir;
@@ -26,13 +26,15 @@ struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
2626
PatternRewriter &rewriter) const override {
2727
Location loc = op.getLoc();
2828
auto dim = op.getDimension();
29-
auto blockId = gpu::BlockIdOp::create(rewriter, loc, dim);
30-
auto blockDim = gpu::BlockDimOp::create(rewriter, loc, dim);
29+
Value blockId = gpu::BlockIdOp::create(rewriter, loc, dim);
30+
Value blockDim = gpu::BlockDimOp::create(rewriter, loc, dim);
31+
auto indexType = rewriter.getIndexType();
3132
// Compute blockId.x * blockDim.x
32-
auto tmp = index::MulOp::create(rewriter, op.getLoc(), blockId, blockDim);
33-
auto threadId = gpu::ThreadIdOp::create(rewriter, loc, dim);
33+
Value tmp =
34+
arith::MulIOp::create(rewriter, loc, indexType, blockId, blockDim);
35+
Value threadId = gpu::ThreadIdOp::create(rewriter, loc, dim);
3436
// Compute threadId.x + blockId.x * blockDim.x
35-
rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp);
37+
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, indexType, threadId, tmp);
3638
return success();
3739
}
3840
};
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=ROCDL
2+
// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=NVVM
3+
4+
gpu.module @kernel {
5+
gpu.func @gpu_global_id() -> (index) {
6+
%global_id_x = gpu.global_id x
7+
gpu.return %global_id_x : index
8+
}
9+
}
10+
11+
// ROCDL-LABEL: llvm.func @gpu_global_id() -> i64 {
12+
// ROCDL: %[[WORKGROUP_0:.*]] = rocdl.workgroup.id.x : i32
13+
// ROCDL: %[[SEXT_0:.*]] = llvm.sext %[[WORKGROUP_0]] : i32 to i64
14+
// ROCDL: %[[WORKGROUP_1:.*]] = rocdl.workgroup.dim.x : i32
15+
// ROCDL: %[[SEXT_1:.*]] = llvm.sext %[[WORKGROUP_1]] : i32 to i64
16+
// ROCDL: %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64
17+
// ROCDL: %[[WORKITEM_0:.*]] = rocdl.workitem.id.x : i32
18+
// ROCDL: %[[SEXT_2:.*]] = llvm.sext %[[WORKITEM_0]] : i32 to i64
19+
// ROCDL: %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64
20+
// ROCDL: llvm.return %[[ADD_0]] : i64
21+
// ROCDL: }
22+
23+
// NVVM-LABEL: llvm.func @gpu_global_id() -> i64 {
24+
// NVVM: %[[READ_0:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
25+
// NVVM: %[[SEXT_0:.*]] = llvm.sext %[[READ_0]] : i32 to i64
26+
// NVVM: %[[READ_1:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
27+
// NVVM: %[[SEXT_1:.*]] = llvm.sext %[[READ_1]] : i32 to i64
28+
// NVVM: %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64
29+
// NVVM: %[[READ_2:.*]] = nvvm.read.ptx.sreg.tid.x : i32
30+
// NVVM: %[[SEXT_2:.*]] = llvm.sext %[[READ_2]] : i32 to i64
31+
// NVVM: %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64
32+
// NVVM: llvm.return %[[ADD_0]] : i64
33+
// NVVM: }

mlir/test/Dialect/GPU/globalId-rewrite.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,27 @@ module {
88
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
99
// CHECK: %[[BIDY:.*]] = gpu.block_id x
1010
// CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim x
11-
// CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]]
11+
// CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]]
1212
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
13-
// CHECK-NEXT: %[[GIDX:.*]] = index.add %[[TIDX]], %[[TMPY]]
13+
// CHECK-NEXT: %[[GIDX:.*]] = arith.addi %[[TIDX]], %[[TMPY]]
1414
%idx = gpu.global_id x
1515
// CHECK: memref.store %[[GIDX]], %[[MEM]][] : memref<index, 1>
1616
memref.store %idx, %mem[] : memref<index, 1>
1717

1818
// CHECK: %[[BIDY:.*]] = gpu.block_id y
1919
// CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim y
20-
// CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]]
20+
// CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]]
2121
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
22-
// CHECK-NEXT: %[[GIDY:.*]] = index.add %[[TIDY]], %[[TMPY]]
22+
// CHECK-NEXT: %[[GIDY:.*]] = arith.addi %[[TIDY]], %[[TMPY]]
2323
%idy = gpu.global_id y
2424
// CHECK: memref.store %[[GIDY]], %[[MEM]][] : memref<index, 1>
2525
memref.store %idy, %mem[] : memref<index, 1>
2626

2727
// CHECK: %[[BIDZ:.*]] = gpu.block_id z
2828
// CHECK-NEXT: %[[BDIMZ:.*]] = gpu.block_dim z
29-
// CHECK-NEXT: %[[TMPZ:.*]] = index.mul %[[BIDZ]], %[[BDIMZ]]
29+
// CHECK-NEXT: %[[TMPZ:.*]] = arith.muli %[[BIDZ]], %[[BDIMZ]]
3030
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
31-
// CHECK-NEXT: %[[GIDZ:.*]] = index.add %[[TIDZ]], %[[TMPZ]]
31+
// CHECK-NEXT: %[[GIDZ:.*]] = arith.addi %[[TIDZ]], %[[TMPZ]]
3232
%idz = gpu.global_id z
3333
// CHECK: memref.store %[[GIDZ]], %[[MEM]][] : memref<index, 1>
3434
memref.store %idz, %mem[] : memref<index, 1>

0 commit comments

Comments
 (0)