Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !` (global memory), `S@ S!` (shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref<Nxi64, #gpu.address_space<workgroup>>`). Using the shared name in code pushes its base address onto the stack. Cannot be referenced inside word definitions.
- **Shared Memory**: `\! shared <name> i64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref<Nxi64, #gpu.address_space<workgroup>>`). Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for shared accesses. Cannot be referenced inside word definitions.
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`
Expand Down
62 changes: 62 additions & 0 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,68 @@ def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None:
assert result == [12, 6, 9, 28, 14, 29]


def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None:
    """Multiply two 4x4 i64 matrices with a tiled, shared-memory kernel.

    Each outer-loop pass stages one 2x2 tile of A and one of B into the shared
    buffers SA/SB via S!, hits a BARRIER, accumulates the partial dot product
    from the tiles via S@, then hits a second BARRIER before the next pass.
    Launch shape is grid (2,2,1) x block (2,2,1): 4 blocks of 4 threads.
    """
    kernel_source = (
        "\\! kernel main\n"
        "\\! param A i64[16]\n"
        "\\! param B i64[16]\n"
        "\\! param C i64[16]\n"
        "\\! shared SA i64[4]\n"
        "\\! shared SB i64[4]\n"
        "BID-Y 2 * TID-Y +\n"
        "BID-X 2 * TID-X +\n"
        "0\n"
        "2 0 DO\n"
        " 2 PICK 4 * I 2 * + TID-X + CELLS A + @\n"
        " TID-Y 2 * TID-X + CELLS SA + S!\n"
        " I 2 * TID-Y + 4 * 2 PICK + CELLS B + @\n"
        " TID-Y 2 * TID-X + CELLS SB + S!\n"
        " BARRIER\n"
        " 2 0 DO\n"
        " TID-Y 2 * I + CELLS SA + S@\n"
        " I 2 * TID-X + CELLS SB + S@\n"
        " * +\n"
        " LOOP\n"
        " BARRIER\n"
        "LOOP\n"
        "ROT 4 * ROT + CELLS C + !"
    )

    # Row-major inputs: A holds 1..16 and B holds 17..32.
    matrix_a = list(range(1, 17))
    matrix_b = list(range(17, 33))

    # Row-major A x B, written one matrix row per line for readability.
    expected_rows = [
        [250, 260, 270, 280],
        [618, 644, 670, 696],
        [986, 1028, 1070, 1112],
        [1354, 1412, 1470, 1528],
    ]
    expected = [value for row in expected_rows for value in row]

    result = kernel_runner.run(
        forth_source=kernel_source,
        params={"A": matrix_a, "B": matrix_b},
        grid=(2, 2, 1),
        block=(2, 2, 1),
        output_param=2,
        output_count=16,
    )

    assert result == expected


# --- User-Defined Words ---


Expand Down
18 changes: 18 additions & 0 deletions include/warpforth/Dialect/Forth/ForthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@ def Forth_StoreOp : Forth_StackOpBase<"store"> {
}];
}

// Stack-to-stack op for the Forth word `S@` (shared/workgroup memory load).
// Carries no operands beyond the stack itself: the address to read is taken
// from the top of the stack and is replaced by the loaded value.
def Forth_SharedLoadOp : Forth_StackOpBase<"shared_load"> {
let summary = "Load value from shared memory buffer";
let description = [{
Pops an address from the stack, loads a value from shared/workgroup memory at
that address, and pushes the loaded value onto the stack.
Forth semantics: ( addr -- value )
}];
}

// Stack-to-stack op for the Forth word `S!` (shared/workgroup memory store).
// Consumes two stack entries: the address on top and the value beneath it.
def Forth_SharedStoreOp : Forth_StackOpBase<"shared_store"> {
let summary = "Store value to shared memory buffer";
let description = [{
Pops an address and value from the stack, stores the value to shared/workgroup
memory at the specified address.
Forth semantics: ( x addr -- )
}];
}

def Forth_ParamRefOp : Forth_Op<"param_ref", [Pure]> {
let summary = "Push kernel parameter address onto stack";
let description = [{
Expand Down
85 changes: 83 additions & 2 deletions lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
Expand All @@ -27,6 +28,8 @@ namespace {

// Stack configuration constants
constexpr int64_t kStackSize = 256;
constexpr unsigned kWorkgroupAddressSpace =
static_cast<unsigned>(gpu::AddressSpace::Workgroup);

/// Type converter for forth.stack -> memref + index
class ForthToMemRefTypeConverter : public TypeConverter {
Expand Down Expand Up @@ -725,6 +728,83 @@ struct StoreOpConversion : public OpConversionPattern<forth::StoreOp> {
}
};

/// Lowers forth.shared_load (S@).
/// The address held in the top stack slot is replaced in place by the i64
/// read through a workgroup-address-space pointer: ( addr -- value )
struct SharedLoadOpConversion
    : public OpConversionPattern<forth::SharedLoadOp> {
  SharedLoadOpConversion(const TypeConverter &typeConverter,
                         MLIRContext *context)
      : OpConversionPattern<forth::SharedLoadOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::SharedLoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The converted stack operand arrives as a (buffer, stack pointer) pair.
    ValueRange stack = adaptor.getOperands()[0];
    Value stackBuffer = stack[0];
    Value top = stack[1];

    // The top slot holds the shared-memory address as an integer.
    Value address = rewriter.create<memref::LoadOp>(loc, stackBuffer, top);

    // Reinterpret it as a workgroup pointer and read one i64 through it.
    Value sharedPtr = rewriter.create<LLVM::IntToPtrOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   kWorkgroupAddressSpace),
        address);
    Value loaded =
        rewriter.create<LLVM::LoadOp>(loc, rewriter.getI64Type(), sharedPtr);

    // Overwrite the address slot with the result; the stack pointer is
    // handed back unchanged.
    rewriter.create<memref::StoreOp>(loc, loaded, stackBuffer, top);

    rewriter.replaceOpWithMultiple(op, {{stackBuffer, top}});
    return success();
  }
};

/// Lowers forth.shared_store (S!).
/// Reads the address from the top stack slot and the value from the slot
/// beneath it, writes the value through a workgroup-address-space pointer,
/// and returns the stack pointer lowered by two: ( x addr -- )
struct SharedStoreOpConversion
    : public OpConversionPattern<forth::SharedStoreOp> {
  SharedStoreOpConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern<forth::SharedStoreOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::SharedStoreOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The converted stack operand arrives as a (buffer, stack pointer) pair.
    ValueRange stack = adaptor.getOperands()[0];
    Value stackBuffer = stack[0];
    Value top = stack[1];

    // Top slot: destination address.
    Value address = rewriter.create<memref::LoadOp>(loc, stackBuffer, top);

    // Slot below the top: the value to be written.
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value belowTop = rewriter.create<arith::SubIOp>(loc, top, step);
    Value payload =
        rewriter.create<memref::LoadOp>(loc, stackBuffer, belowTop);

    // Reinterpret the address as a workgroup pointer and store through it.
    Value sharedPtr = rewriter.create<LLVM::IntToPtrOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   kWorkgroupAddressSpace),
        address);
    rewriter.create<LLVM::StoreOp>(loc, payload, sharedPtr);

    // Both entries are consumed, so the resulting stack pointer is SP-2.
    Value newTop = rewriter.create<arith::SubIOp>(loc, belowTop, step);
    rewriter.replaceOpWithMultiple(op, {{stackBuffer, newTop}});
    return success();
  }
};

/// Template for converting GPU indexing ops to intrinsic ops.
/// Creates an intrinsic op with the specified name and pushes the value onto
/// the stack.
Expand Down Expand Up @@ -1026,8 +1106,9 @@ struct ConvertForthToMemRefPass
NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion,
LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion,
GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion,
LoadOpConversion, StoreOpConversion, PopFlagOpConversion,
PopOpConversion, PushValueOpConversion>(typeConverter, context);
LoadOpConversion, StoreOpConversion, SharedLoadOpConversion,
SharedStoreOpConversion, PopFlagOpConversion, PopOpConversion,
PushValueOpConversion>(typeConverter, context);

// Add GPU indexing op conversion patterns
patterns.add<IntrinsicOpConversion<forth::ThreadIdXOp>>(typeConverter,
Expand Down
6 changes: 6 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
} else if (word == "!") {
return builder.create<forth::StoreOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "S@") {
return builder.create<forth::SharedLoadOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "S!") {
return builder.create<forth::SharedStoreOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "TID-X") {
return builder.create<forth::ThreadIdXOp>(loc, stackType, inputStack)
.getResult();
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/ForthToMemRef/memory-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr
// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr

// shared load (S@): pop address, inttoptr shared addrspace, llvm.load
// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<{{[1-9][0-9]*}}>
// CHECK: llvm.load %{{.*}} : !llvm.ptr<{{[1-9][0-9]*}}> -> i64

// shared store (S!): pop address + value, inttoptr shared addrspace, llvm.store
// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<{{[1-9][0-9]*}}>
// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<{{[1-9][0-9]*}}>

module {
func.func private @main() {
%0 = forth.stack !forth.stack
Expand All @@ -23,6 +31,11 @@ module {
%3 = forth.literal %2 42 : !forth.stack -> !forth.stack
%4 = forth.literal %3 100 : !forth.stack -> !forth.stack
%5 = forth.store %4 : !forth.stack -> !forth.stack
%6 = forth.literal %5 2 : !forth.stack -> !forth.stack
%7 = forth.shared_load %6 : !forth.stack -> !forth.stack
%8 = forth.literal %7 9 : !forth.stack -> !forth.stack
%9 = forth.literal %8 3 : !forth.stack -> !forth.stack
%10 = forth.shared_store %9 : !forth.stack -> !forth.stack
return
}
}
8 changes: 7 additions & 1 deletion test/Translation/Forth/memory-ops.forth
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
\ Test ! produces forth.store
\ CHECK: forth.store %{{.*}} : !forth.stack -> !forth.stack

\ Test S@ produces forth.shared_load
\ CHECK: forth.shared_load %{{.*}} : !forth.stack -> !forth.stack

\ Test S! produces forth.shared_store
\ CHECK: forth.shared_store %{{.*}} : !forth.stack -> !forth.stack

\ Test CELLS produces literal 8 + mul
\ CHECK: forth.literal %{{.*}} 8
\ CHECK-NEXT: forth.mul
\! kernel main
1 @ 2 3 !
1 @ 2 3 ! 4 S@ 5 6 S!
4 CELLS