diff --git a/CLAUDE.md b/CLAUDE.md index 4bd427f..a763c1b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,9 +87,9 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !` (global memory), `S@ S!` (shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). - **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). -- **Shared Memory**: `\! shared <name> i64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref<Nxi64, #gpu.address_space<workgroup>>`). Using the shared name in code pushes its base address onto the stack. Cannot be referenced inside word definitions. +- **Shared Memory**: `\! shared <name> i64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref<Nxi64, #gpu.address_space<workgroup>>`). Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for shared accesses. Cannot be referenced inside word definitions. 
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion - **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`
diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py
index 4f161ab..564dc82 100644
--- a/gpu_test/test_kernels.py
+++ b/gpu_test/test_kernels.py
@@ -337,6 +337,57 @@ def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None:
     assert result == [12, 6, 9, 28, 14, 29]
 
 
+def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None:
+    """Check a shared-memory tiled product of two 4x4 i64 matrices.
+
+    The kernel runs on a (2,2,1) grid of (2,2,1) blocks. It works on 2x2
+    tiles of A and B held in the shared buffers SA/SB and syncs with BARRIER.
+    """
+    kernel_source = (
+        "\\! kernel main\n"
+        "\\! param A i64[16]\n"
+        "\\! param B i64[16]\n"
+        "\\! param C i64[16]\n"
+        "\\! shared SA i64[4]\n"
+        "\\! shared SB i64[4]\n"
+        "BID-Y 2 * TID-Y +\n"
+        "BID-X 2 * TID-X +\n"
+        "0\n"
+        "2 0 DO\n"
+        " 2 PICK 4 * I 2 * + TID-X + CELLS A + @\n"
+        " TID-Y 2 * TID-X + CELLS SA + S!\n"
+        " I 2 * TID-Y + 4 * 2 PICK + CELLS B + @\n"
+        " TID-Y 2 * TID-X + CELLS SB + S!\n"
+        " BARRIER\n"
+        " 2 0 DO\n"
+        " TID-Y 2 * I + CELLS SA + S@\n"
+        " I 2 * TID-X + CELLS SB + S@\n"
+        " * +\n"
+        " LOOP\n"
+        " BARRIER\n"
+        "LOOP\n"
+        "ROT 4 * ROT + CELLS C + !"
+    )
+    matrix_a = list(range(1, 17))
+    matrix_b = list(range(17, 33))
+    expected_rows = [
+        [250, 260, 270, 280],
+        [618, 644, 670, 696],
+        [986, 1028, 1070, 1112],
+        [1354, 1412, 1470, 1528],
+    ]
+
+    actual = kernel_runner.run(
+        forth_source=kernel_source,
+        params={"A": matrix_a, "B": matrix_b},
+        grid=(2, 2, 1),
+        block=(2, 2, 1),
+        output_param=2,
+        output_count=16,
+    )
+    assert actual == [cell for row in expected_rows for cell in row]
+
+
 # --- User-Defined Words ---
 
 
diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td
index 15a7d4e..82df7b3 100644
--- a/include/warpforth/Dialect/Forth/ForthOps.td
+++ b/include/warpforth/Dialect/Forth/ForthOps.td
@@ -250,6 +250,24 @@ def Forth_StoreOp : Forth_StackOpBase<"store"> {
   }];
 }
 
+def Forth_SharedLoadOp : Forth_StackOpBase<"shared_load"> {
+  let summary = "Load value from shared memory buffer";
+  let description = [{
+    Pops an address from the stack, loads a value from shared/workgroup memory at
+    that address, and pushes the loaded value onto the stack.
+    Forth semantics: ( addr -- value )
+  }];
+}
+
+def Forth_SharedStoreOp : Forth_StackOpBase<"shared_store"> {
+  let summary = "Store value to shared memory buffer";
+  let description = [{
+    Pops an address and value from the stack, stores the value to shared/workgroup
+    memory at the specified address.
+    Forth semantics: ( x addr -- )
+  }];
+}
+
 def Forth_ParamRefOp : Forth_Op<"param_ref", [Pure]> {
   let summary = "Push kernel parameter address onto stack";
   let description = [{
diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
index 080fe16..7932c87 100644
--- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
+++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -27,6 +28,8 @@ namespace {
 
 // Stack configuration constants
 constexpr int64_t kStackSize = 256;
+constexpr unsigned kWorkgroupAddressSpace =
+    static_cast<unsigned>(gpu::AddressSpace::Workgroup);
 
 /// Type converter for forth.stack -> memref + index
 class ForthToMemRefTypeConverter : public TypeConverter {
@@ -725,6 +728,83 @@ struct StoreOpConversion : public OpConversionPattern<forth::StoreOp> {
   }
 };
 
+/// Conversion pattern for forth.shared_load operation (S@).
+/// ( addr -- value ): loads through a kWorkgroupAddressSpace LLVM pointer.
+/// NOTE(review): confirm that constant is the backend's shared addrspace.
+struct SharedLoadOpConversion
+    : public OpConversionPattern<forth::SharedLoadOp> {
+  SharedLoadOpConversion(const TypeConverter &typeConverter,
+                         MLIRContext *context)
+      : OpConversionPattern(typeConverter, context) {}
+  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(forth::SharedLoadOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    ValueRange inputStack = adaptor.getOperands()[0];
+    Value memref = inputStack[0];
+    Value stackPtr = inputStack[1];
+
+    auto i64Type = rewriter.getI64Type();
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                              kWorkgroupAddressSpace);
+
+    // Load address from stack.
+    Value addrValue = rewriter.create<memref::LoadOp>(loc, memref, stackPtr);
+
+    // Load value from shared memory via address-space-qualified pointer.
+    Value ptr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, addrValue);
+    Value loadedValue = rewriter.create<LLVM::LoadOp>(loc, i64Type, ptr);
+
+    // Store loaded value back at same position (replaces address).
+    rewriter.create<memref::StoreOp>(loc, loadedValue, memref, stackPtr);
+
+    rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}});
+    return success();
+  }
+};
+
+/// Conversion pattern for forth.shared_store operation (S!).
+/// ( x addr -- ): stores x through a kWorkgroupAddressSpace LLVM pointer.
+/// NOTE(review): confirm that constant is the backend's shared addrspace.
+struct SharedStoreOpConversion
+    : public OpConversionPattern<forth::SharedStoreOp> {
+  SharedStoreOpConversion(const TypeConverter &typeConverter,
+                          MLIRContext *context)
+      : OpConversionPattern(typeConverter, context) {}
+  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(forth::SharedStoreOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    ValueRange inputStack = adaptor.getOperands()[0];
+    Value memref = inputStack[0];
+    Value stackPtr = inputStack[1];
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                              kWorkgroupAddressSpace);
+
+    // Pop address from stack.
+    Value addrValue = rewriter.create<memref::LoadOp>(loc, memref, stackPtr);
+
+    // Pop value from stack.
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value spMinus1 = rewriter.create<arith::SubIOp>(loc, stackPtr, one);
+    Value value = rewriter.create<memref::LoadOp>(loc, memref, spMinus1);
+
+    // Store value to shared memory via address-space-qualified pointer.
+    Value ptr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, addrValue);
+    rewriter.create<LLVM::StoreOp>(loc, value, ptr);
+
+    // New stack pointer is SP-2 (popped both address and value).
+    Value spMinus2 = rewriter.create<arith::SubIOp>(loc, spMinus1, one);
+    rewriter.replaceOpWithMultiple(op, {{memref, spMinus2}});
+    return success();
+  }
+};
+
 /// Template for converting GPU indexing ops to intrinsic ops.
 /// Creates an intrinsic op with the specified name and pushes the value onto
 /// the stack.
@@ -1026,8 +1106,9 @@ struct ConvertForthToMemRefPass
         NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion,
         LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion,
         GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion,
-        LoadOpConversion, StoreOpConversion, PopFlagOpConversion,
-        PopOpConversion, PushValueOpConversion>(typeConverter, context);
+        LoadOpConversion, StoreOpConversion, SharedLoadOpConversion,
+        SharedStoreOpConversion, PopFlagOpConversion, PopOpConversion,
+        PushValueOpConversion>(typeConverter, context);
 
     // Add GPU indexing op conversion patterns
     patterns.add>(typeConverter,
diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp
index b176843..fe4923c 100644
--- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp
+++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp
@@ -487,6 +487,12 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
   } else if (word == "!") {
     return builder.create<forth::StoreOp>(loc, stackType, inputStack)
         .getResult();
+  } else if (word == "S@") {
+    return builder.create<forth::SharedLoadOp>(loc, stackType, inputStack)
+        .getResult();
+  } else if (word == "S!") {
+    return builder.create<forth::SharedStoreOp>(loc, stackType, inputStack)
+        .getResult();
   } else if (word == "TID-X") {
     return builder.create(loc, stackType, inputStack)
         .getResult();
diff --git a/test/Conversion/ForthToMemRef/memory-ops.mlir b/test/Conversion/ForthToMemRef/memory-ops.mlir
index 755597e..d4965ef 100644
--- a/test/Conversion/ForthToMemRef/memory-ops.mlir
+++ b/test/Conversion/ForthToMemRef/memory-ops.mlir
@@ -15,6 +15,14 @@
 // CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr
 // CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr
 
+// shared load (S@): pop address, inttoptr shared addrspace, llvm.load
+// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<{{[1-9][0-9]*}}>
+// CHECK: llvm.load %{{.*}} : !llvm.ptr<{{[1-9][0-9]*}}> -> i64
+
+// shared store (S!): pop address + value, inttoptr shared addrspace, llvm.store
+// CHECK: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<{{[1-9][0-9]*}}>
+// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<{{[1-9][0-9]*}}>
+
 module {
   func.func private @main() {
     %0 = forth.stack !forth.stack
@@ -23,6 +31,11 @@ module {
     %3 = forth.literal %2 42 : !forth.stack -> !forth.stack
     %4 = forth.literal %3 100 : !forth.stack -> !forth.stack
     %5 = forth.store %4 : !forth.stack -> !forth.stack
+    %6 = forth.literal %5 2 : !forth.stack -> !forth.stack
+    %7 = forth.shared_load %6 : !forth.stack -> !forth.stack
+    %8 = forth.literal %7 9 : !forth.stack -> !forth.stack
+    %9 = forth.literal %8 3 : !forth.stack -> !forth.stack
+    %10 = forth.shared_store %9 : !forth.stack -> !forth.stack
     return
   }
 }
diff --git a/test/Translation/Forth/memory-ops.forth b/test/Translation/Forth/memory-ops.forth
index e340df8..449cc42 100644
--- a/test/Translation/Forth/memory-ops.forth
+++ b/test/Translation/Forth/memory-ops.forth
@@ -6,9 +6,15 @@
 \ Test ! produces forth.store
 \ CHECK: forth.store %{{.*}} : !forth.stack -> !forth.stack
 
+\ Test S@ produces forth.shared_load
+\ CHECK: forth.shared_load %{{.*}} : !forth.stack -> !forth.stack
+
+\ Test S! produces forth.shared_store
+\ CHECK: forth.shared_store %{{.*}} : !forth.stack -> !forth.stack
+
 \ Test CELLS produces literal 8 + mul
 \ CHECK: forth.literal %{{.*}} 8
 \ CHECK-NEXT: forth.mul
 \! kernel main
-1 @ 2 3 !
+1 @ 2 3 ! 4 S@ 5 6 S!
 4 CELLS