Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations.
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. `\! param <name> f64[<N>]` becomes a `memref<Nxf64>` argument; `\! param <name> f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` or `\! shared <name> f64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions.
Expand Down
3 changes: 2 additions & 1 deletion include/warpforth/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def ConvertForthToMemRef
let dependentDialects = ["mlir::memref::MemRefDialect",
"mlir::arith::ArithDialect",
"mlir::LLVM::LLVMDialect",
"mlir::cf::ControlFlowDialect"];
"mlir::cf::ControlFlowDialect",
"mlir::math::MathDialect"];
}

def ConvertForthToGPU : Pass<"convert-forth-to-gpu", "mlir::ModuleOp"> {
Expand Down
65 changes: 65 additions & 0 deletions include/warpforth/Dialect/Forth/ForthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,71 @@ def Forth_DivFOp : Forth_StackOpBase<"divf"> {
}];
}

//===----------------------------------------------------------------------===//
// Float math intrinsic operations.
//===----------------------------------------------------------------------===//

// Unary float intrinsic emitted by the Forth parser for the word FEXP.
// The ForthToMemRef conversion lowers it through UnaryFloatOpConversion to
// math.exp, rewriting the top-of-stack slot in place.
def Forth_ExpFOp : Forth_StackOpBase<"expf"> {
let summary = "Exponential of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes e^x,
bitcasts result back to i64.
Forth semantics: ( f -- exp(f) )
}];
}

// Unary float intrinsic emitted by the Forth parser for the word FSQRT.
// The ForthToMemRef conversion lowers it through UnaryFloatOpConversion to
// math.sqrt, rewriting the top-of-stack slot in place.
def Forth_SqrtFOp : Forth_StackOpBase<"sqrtf"> {
let summary = "Square root of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes sqrt(x),
bitcasts result back to i64.
Forth semantics: ( f -- sqrt(f) )
}];
}

// Unary float intrinsic emitted by the Forth parser for the word FLOG.
// The ForthToMemRef conversion lowers it through UnaryFloatOpConversion to
// math.log, rewriting the top-of-stack slot in place.
def Forth_LogFOp : Forth_StackOpBase<"logf"> {
let summary = "Natural logarithm of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes ln(x),
bitcasts result back to i64.
Forth semantics: ( f -- log(f) )
}];
}

// Unary float intrinsic emitted by the Forth parser for the word FABS.
// The ForthToMemRef conversion lowers it through UnaryFloatOpConversion to
// math.absf, rewriting the top-of-stack slot in place.
def Forth_AbsFOp : Forth_StackOpBase<"absf"> {
let summary = "Absolute value of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes |x|,
bitcasts result back to i64.
Forth semantics: ( f -- |f| )
}];
}

// Unary float intrinsic emitted by the Forth parser for the word FNEG.
// Unlike the other unary intrinsics it lowers to an arith op (arith.negf)
// rather than a math op, still through UnaryFloatOpConversion.
def Forth_NegFOp : Forth_StackOpBase<"negf"> {
let summary = "Negate top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, negates,
bitcasts result back to i64.
Forth semantics: ( f -- -f )
}];
}

// Binary float intrinsic emitted by the Forth parser for the word FMAX.
// The ForthToMemRef conversion lowers it through the float variant of
// BinaryArithOpConversion to arith.maximumf.
// NOTE(review): per the MLIR arith docs, arith.maximumf returns NaN when
// either operand is NaN (arith.maxnumf does not) — confirm this is the
// intended FMAX behaviour.
def Forth_MaxFOp : Forth_StackOpBase<"maxf"> {
let summary = "Maximum of top two stack elements (float)";
let description = [{
Pops two i64 values, bitcasts to f64, computes max, bitcasts result back to i64.
Forth semantics: ( f1 f2 -- max(f1,f2) )
}];
}

// Binary float intrinsic emitted by the Forth parser for the word FMIN.
// The ForthToMemRef conversion lowers it through the float variant of
// BinaryArithOpConversion to arith.minimumf.
// NOTE(review): per the MLIR arith docs, arith.minimumf returns NaN when
// either operand is NaN (arith.minnumf does not) — confirm this is the
// intended FMIN behaviour.
def Forth_MinFOp : Forth_StackOpBase<"minf"> {
let summary = "Minimum of top two stack elements (float)";
let description = [{
Pops two i64 values, bitcasts to f64, computes min, bitcasts result back to i64.
Forth semantics: ( f1 f2 -- min(f1,f2) )
}];
}

//===----------------------------------------------------------------------===//
// Bitwise operations.
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(MLIRConversionPasses
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUTransforms
MLIRMathToLLVM
MLIRReconcileUnrealizedCasts
MLIRTransforms
)
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/ForthToMemRef/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRForthToMemRefConversion
MLIRLLVMDialect
MLIRFuncDialect
MLIRControlFlowDialect
MLIRMathDialect
MLIRForth
)
61 changes: 60 additions & 1 deletion lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -514,6 +515,60 @@ using MulFOpConversion =
using DivFOpConversion =
BinaryArithOpConversion<forth::DivFOp, arith::DivFOp, true>;

// Float binary intrinsics (max/min)
using MaxFOpConversion =
BinaryArithOpConversion<forth::MaxFOp, arith::MaximumFOp, true>;
using MinFOpConversion =
BinaryArithOpConversion<forth::MinFOp, arith::MinimumFOp, true>;

/// Generic lowering for single-operand float stack words: ( f -- result ).
///
/// Stack cells hold f64 values as i64 bit patterns, so the top cell is
/// loaded, bitcast to f64, passed through `MathOp`, bitcast back to i64 and
/// written to the slot it was read from. The stack pointer is forwarded
/// unchanged because a unary op neither grows nor shrinks the stack.
///
/// \tparam ForthOp  Forth dialect op being converted.
/// \tparam MathOp   Single-operand f64 op (math or arith) that replaces it.
template <typename ForthOp, typename MathOp>
struct UnaryFloatOpConversion : public OpConversionPattern<ForthOp> {
  using OneToNOpAdaptor =
      typename OpConversionPattern<ForthOp>::OneToNOpAdaptor;

  UnaryFloatOpConversion(const TypeConverter &typeConverter,
                         MLIRContext *context)
      : OpConversionPattern<ForthOp>(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // The converted stack operand arrives as the pair {buffer, stack pointer}.
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value top = stackParts[1];

    // Read the raw i64 cell at the top of the stack and view it as f64.
    Value rawBits = rewriter.create<memref::LoadOp>(loc, buffer, top);
    Value operand =
        rewriter.create<arith::BitcastOp>(loc, rewriter.getF64Type(), rawBits);

    // Run the actual float operation.
    Value computed = rewriter.create<MathOp>(loc, operand);

    // Repack as an i64 bit pattern and overwrite the same slot.
    Value packed =
        rewriter.create<arith::BitcastOp>(loc, rewriter.getI64Type(), computed);
    rewriter.create<memref::StoreOp>(loc, packed, buffer, top);

    // The output stack is the same buffer with the same stack pointer.
    rewriter.replaceOpWithMultiple(op, {{buffer, top}});
    return success();
  }
};

// Float unary intrinsics
using ExpFOpConversion = UnaryFloatOpConversion<forth::ExpFOp, math::ExpOp>;
using SqrtFOpConversion = UnaryFloatOpConversion<forth::SqrtFOp, math::SqrtOp>;
using LogFOpConversion = UnaryFloatOpConversion<forth::LogFOp, math::LogOp>;
using AbsFOpConversion = UnaryFloatOpConversion<forth::AbsFOp, math::AbsFOp>;
using NegFOpConversion = UnaryFloatOpConversion<forth::NegFOp, arith::NegFOp>;

/// Base template for binary comparison operations.
/// Pops two values, compares, pushes -1 (true) or 0 (false): (a b -- flag)
/// When IsFloat=true, bitcasts i64->f64 before comparing.
Expand Down Expand Up @@ -1153,7 +1208,8 @@ struct ConvertForthToMemRefPass

// Mark MemRef, Arith, LLVM, CF, and Math dialects as legal
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
LLVM::LLVMDialect, cf::ControlFlowDialect>();
LLVM::LLVMDialect, cf::ControlFlowDialect,
math::MathDialect>();

// Mark IntrinsicOp and BarrierOp as legal (to be lowered later)
target.addLegalOp<forth::IntrinsicOp>();
Expand Down Expand Up @@ -1205,6 +1261,9 @@ struct ConvertForthToMemRefPass
ModOpConversion,
// Float arithmetic
AddFOpConversion, SubFOpConversion, MulFOpConversion, DivFOpConversion,
// Float math intrinsics
ExpFOpConversion, SqrtFOpConversion, LogFOpConversion, AbsFOpConversion,
NegFOpConversion, MaxFOpConversion, MinFOpConversion,
// Bitwise
AndOpConversion, OrOpConversion, XorOpConversion, NotOpConversion,
LshiftOpConversion, RshiftOpConversion,
Expand Down
10 changes: 7 additions & 3 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "warpforth/Conversion/Passes.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -40,13 +41,16 @@ void buildWarpForthPipeline(OpPassManager &pm) {
pm.addNestedPass<gpu::GPUModuleOp>(
createConvertGpuOpsToNVVMOps(gpuToNVVMOptions));

// Stage 6: Lower NVVM to LLVM
// Stage 6: Lower math ops to LLVM intrinsics inside GPU module
pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToLLVMPass());

// Stage 7: Lower NVVM to LLVM
pm.addPass(createConvertNVVMToLLVMPass());

// Stage 7: Reconcile type conversions
// Stage 8: Reconcile type conversions
pm.addPass(createReconcileUnrealizedCastsPass());

// Stage 8: Compile GPU module to PTX binary
// Stage 9: Compile GPU module to PTX binary
GpuModuleToBinaryPassOptions binaryOptions;
binaryOptions.compilationTarget = "isa"; // Output PTX assembly
pm.addPass(createGpuModuleToBinaryPass(binaryOptions));
Expand Down
21 changes: 21 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,27 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
} else if (word == "F/") {
return builder.create<forth::DivFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FEXP") {
return builder.create<forth::ExpFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FSQRT") {
return builder.create<forth::SqrtFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FLOG") {
return builder.create<forth::LogFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FABS") {
return builder.create<forth::AbsFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FNEG") {
return builder.create<forth::NegFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FMAX") {
return builder.create<forth::MaxFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FMIN") {
return builder.create<forth::MinFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "MOD") {
return builder.create<forth::ModOp>(loc, stackType, inputStack).getResult();
} else if (word == "AND") {
Expand Down
70 changes: 70 additions & 0 deletions test/Conversion/ForthToMemRef/float-math-intrinsics.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s

// Lowering test for the float math intrinsics. Each unary op must rewrite the
// top stack slot in place (load, bitcast, op, bitcast, store); each binary op
// must pop two cells and store a single result. Both binary ops are checked
// with the same full load/subi/load ... store sequence.

// CHECK-LABEL: func.func private @main

// expf: load, bitcast i64->f64, math.exp, bitcast f64->i64, store (SP unchanged)
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.exp %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// sqrtf: load, bitcast, math.sqrt, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.sqrt %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// logf: load, bitcast, math.log, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.log %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// absf: load, bitcast, math.absf, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.absf %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// negf: load, bitcast, arith.negf, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.negf %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// maxf: binary — pop two, bitcast, arith.maximumf, bitcast, store
// CHECK: memref.load
// CHECK: arith.subi
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.maximumf %{{.*}}, %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// minf: binary — pop two, bitcast, arith.minimumf, bitcast, store
// CHECK: memref.load
// CHECK: arith.subi
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.minimumf %{{.*}}, %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

module {
  func.func private @main() {
    %0 = forth.stack !forth.stack
    %1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack
    %2 = forth.expf %1 : !forth.stack -> !forth.stack
    %3 = forth.sqrtf %2 : !forth.stack -> !forth.stack
    %4 = forth.logf %3 : !forth.stack -> !forth.stack
    %5 = forth.absf %4 : !forth.stack -> !forth.stack
    %6 = forth.negf %5 : !forth.stack -> !forth.stack
    %7 = forth.constant %6(2.000000e+00 : f64) : !forth.stack -> !forth.stack
    %8 = forth.maxf %7 : !forth.stack -> !forth.stack
    %9 = forth.minf %8 : !forth.stack -> !forth.stack
    return
  }
}
12 changes: 12 additions & 0 deletions test/Pipeline/float-math-intrinsics.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s

\ Verify float math intrinsics lower through the full pipeline to gpu.binary
\ CHECK: gpu.binary @warpforth_module

\! kernel main
\! param data f64[256]
GLOBAL-ID CELLS data + F@
FABS FEXP FSQRT FLOG FNEG
GLOBAL-ID CELLS data + F@
FMAX FMIN
GLOBAL-ID CELLS data + F!
21 changes: 21 additions & 0 deletions test/Translation/Forth/float-math-intrinsics.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s

\ Verify float math intrinsic ops parse correctly

\ Unary ops
\ CHECK: %[[S0:.*]] = forth.stack
\ CHECK: %[[S1:.*]] = forth.constant %[[S0]]
\ CHECK: %[[S2:.*]] = forth.expf %[[S1]]
\ CHECK: %[[S3:.*]] = forth.sqrtf %[[S2]]
\ CHECK: %[[S4:.*]] = forth.logf %[[S3]]
\ CHECK: %[[S5:.*]] = forth.absf %[[S4]]
\ CHECK: %[[S6:.*]] = forth.negf %[[S5]]

\ Binary ops
\ CHECK: %[[S7:.*]] = forth.constant %[[S6]]
\ CHECK: %[[S8:.*]] = forth.maxf %[[S7]]
\ CHECK: %[[S9:.*]] = forth.minf %[[S8]]

\! kernel main
1.0 FEXP FSQRT FLOG FABS FNEG
2.0 FMAX FMIN