From 9ac5a665d4abad4ffddcbd2501dd4221aaf20c1b Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 02:02:53 +0900 Subject: [PATCH 1/4] feat: add f64 floating-point support and rename ops to MLIR conventions --- CLAUDE.md | 7 +- gpu_test/conftest.py | 43 ++- include/warpforth/Dialect/Forth/ForthOps.td | 237 ++++++++++--- .../ForthToMemRef/ForthToMemRef.cpp | 326 ++++++++++++------ lib/Translation/ForthToMLIR/ForthToMLIR.cpp | 154 +++++++-- lib/Translation/ForthToMLIR/ForthToMLIR.h | 10 +- test/Conversion/ForthToMemRef/arithmetic.mlir | 20 +- .../Conversion/ForthToMemRef/begin-until.mlir | 6 +- .../ForthToMemRef/begin-while-repeat.mlir | 10 +- test/Conversion/ForthToMemRef/bitwise.mlir | 22 +- test/Conversion/ForthToMemRef/comparison.mlir | 38 +- .../ForthToMemRef/control-flow.mlir | 10 +- test/Conversion/ForthToMemRef/do-loop.mlir | 4 +- .../ForthToMemRef/float-arithmetic.mlir | 44 +++ .../ForthToMemRef/float-memory.mlir | 29 ++ test/Conversion/ForthToMemRef/leave.mlir | 4 +- test/Conversion/ForthToMemRef/literal.mlir | 2 +- test/Conversion/ForthToMemRef/memory-ops.mlir | 20 +- .../ForthToMemRef/nested-control-flow.mlir | 30 +- .../ForthToMemRef/stack-manipulation.mlir | 10 +- .../ForthToMemRef/user-defined-words.mlir | 4 +- test/Pipeline/float-pipeline.forth | 20 ++ test/Translation/Forth/arithmetic-ops.forth | 12 +- test/Translation/Forth/basic-literals.forth | 6 +- test/Translation/Forth/begin-until.forth | 6 +- .../Forth/begin-while-repeat.forth | 10 +- test/Translation/Forth/bitwise-ops.forth | 22 +- test/Translation/Forth/case-insensitive.forth | 2 +- test/Translation/Forth/comparison-ops.forth | 38 +- test/Translation/Forth/control-flow.forth | 10 +- test/Translation/Forth/do-loop.forth | 4 +- test/Translation/Forth/float-arithmetic.forth | 12 + test/Translation/Forth/float-comparison.forth | 24 ++ test/Translation/Forth/float-conversion.forth | 9 + test/Translation/Forth/float-literals.forth | 9 + test/Translation/Forth/float-memory.forth | 15 + test/Translation/Forth/float-params.forth | 12 + .../Forth/interleaved-control-flow.forth | 22 +- test/Translation/Forth/leave.forth | 4 +- test/Translation/Forth/memory-ops.forth | 20 +- .../Forth/nested-control-flow.forth | 46 +-- .../Forth/plus-loop-negative.forth | 6 +- test/Translation/Forth/plus-loop.forth | 6 +- test/Translation/Forth/unloop-exit.forth | 2 +- test/Translation/Forth/word-definitions.forth | 2 +- tools/warpforth-runner/warpforth-runner.cpp | 160 ++++++--- 46 files changed, 1055 insertions(+), 454 deletions(-) create mode 100644 test/Conversion/ForthToMemRef/float-arithmetic.mlir create mode 100644 test/Conversion/ForthToMemRef/float-memory.mlir create mode 100644 test/Pipeline/float-pipeline.forth create mode 100644 test/Translation/Forth/float-arithmetic.forth create mode 100644 test/Translation/Forth/float-comparison.forth create mode 100644 test/Translation/Forth/float-conversion.forth create mode 100644 test/Translation/Forth/float-literals.forth create mode 100644 test/Translation/Forth/float-memory.forth create mode 100644 test/Translation/Forth/float-params.forth diff --git a/CLAUDE.md b/CLAUDE.md index a763c1b..373d1c9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,9 +87,10 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !` (global memory), `S@ S!` (shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). -- **Kernel Parameters**: Declared in the `\!` header. `\! kernel ` is required and must appear first. `\! param i64[]` becomes a `memref` argument; `\! param i64` becomes an `i64` argument. Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). -- **Shared Memory**: `\! shared i64[]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref>`). Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for shared accesses. Cannot be referenced inside word definitions. +- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations. +- **Kernel Parameters**: Declared in the `\!` header. `\! kernel ` is required and must appear first. `\! param i64[]` becomes a `memref` argument; `\! param i64` becomes an `i64` argument. `\! param f64[]` becomes a `memref` argument; `\! param f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value). +- **Shared Memory**: `\! shared i64[]` or `\! shared f64[]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions. - **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion - **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call` diff --git a/gpu_test/conftest.py b/gpu_test/conftest.py index de0b273..9b20e14 100644 --- a/gpu_test/conftest.py +++ b/gpu_test/conftest.py @@ -39,6 +39,7 @@ class ParamDecl: name: str is_array: bool size: int # 0 for scalars + base_type: str = "i64" # "i64" or "f64" class CompileError(Exception): @@ -283,15 +284,17 @@ def scp_upload(self, local_path: str | Path, remote_path: str) -> None: ) -def _parse_array_type(type_spec: str) -> int: +def _parse_array_type(type_spec: str) -> tuple[str, int]: + """Parse 'i64[256]' or 'f64[256]' into (base_type, size).""" if not type_spec.endswith("]"): msg = f"Invalid array type spec: {type_spec}" raise ValueError(msg) base, size_str = type_spec[:-1].split("[", 1) - if base.lower() != "i64": + base_lower = base.lower() + if base_lower not in ("i64", "f64"): msg = f"Unsupported base type: {base}" raise ValueError(msg) - return int(size_str) + return base_lower, int(size_str) def _iter_header_directives(forth_source: str) -> Generator[tuple[str, list[str]]]: @@ -341,13 +344,14 @@ def _parse_param_declarations(forth_source: str) -> list[ParamDecl]: name = parts[1] type_spec = parts[2] if "[" in type_spec: - size = _parse_array_type(type_spec) - decls.append(ParamDecl(name=name, is_array=True, size=size)) + base_type, size = _parse_array_type(type_spec) + decls.append(ParamDecl(name=name, is_array=True, size=size, base_type=base_type)) else: - if type_spec.lower() != "i64": + base_type = type_spec.lower() + if base_type not in ("i64", "f64"): msg = f"Unsupported scalar type: {type_spec}" raise ValueError(msg) - decls.append(ParamDecl(name=name, is_array=False, size=0)) + decls.append(ParamDecl(name=name, is_array=False, size=0, base_type=base_type)) return decls @@ -361,18 +365,18 @@ def __init__(self, session: VastSession, compiler: Compiler) -> None: def run( self, forth_source: str, - params: dict[str, list[int] | int] | None = None, + params: dict[str, list[int] | list[float] | int | float] | None = None, grid: tuple[int, int, int] = (1, 1, 1), block: tuple[int, int, int] = (1, 1, 1), output_param: int = 0, output_count: int | None = None, - ) -> list[int]: + ) -> list[int] | list[float]: """Compile Forth source locally, execute on remote GPU, return output values. Param buffer sizes are derived from the Forth source's 'param' declarations. The params dict maps param names to initial values: - - Array params: list[int] (padded with zeros to declared size) - - Scalar params: int + - Array params: list of int or float (padded with zeros to declared size) + - Scalar params: int or float Params not in the dict are zero-initialized. """ # Parse kernel name and param declarations @@ -420,16 +424,17 @@ def run( if not isinstance(values, list): msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}" raise TypeError(msg) - buf = [0] * decl.size + zero = 0.0 if decl.base_type == "f64" else 0 + buf = [zero] * decl.size for i, v in enumerate(values): buf[i] = v - cmd_parts.extend(["--param", f"i64[]:{','.join(str(v) for v in buf)}"]) + cmd_parts.extend(["--param", f"{decl.base_type}[]:{','.join(str(v) for v in buf)}"]) else: - value = params.get(decl.name, 0) + value = params.get(decl.name, 0.0 if decl.base_type == "f64" else 0) if isinstance(value, list): - msg = f"Scalar param '{decl.name}' expects an int, got list" + msg = f"Scalar param '{decl.name}' expects a scalar, got list" raise TypeError(msg) - cmd_parts.extend(["--param", f"i64:{value}"]) + cmd_parts.extend(["--param", f"{decl.base_type}:{value}"]) cmd_parts.extend( [ @@ -448,8 +453,10 @@ def run( cmd = " ".join(cmd_parts) stdout = self.session.ssh_run(cmd, timeout=120) - # Parse CSV output - return [int(v) for v in stdout.strip().split(",")] + # Parse CSV output — type depends on the output param + out_type = decls[output_param].base_type + parse = float if out_type == "f64" else int + return [parse(v) for v in stdout.strip().split(",")] # --- Fixtures --- diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index 82df7b3..4d76144 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -114,54 +114,54 @@ def Forth_RollOp : Forth_StackOpBase<"roll"> { } //===----------------------------------------------------------------------===// -// Literal operations. +// Constant operations. //===----------------------------------------------------------------------===// -def Forth_LiteralOp : Forth_Op<"literal", [Pure]> { - let summary = "Push a literal value onto the stack"; +def Forth_ConstantOp : Forth_Op<"constant", [Pure]> { + let summary = "Push a constant value onto the stack"; let description = [{ - Pushes a literal integer value onto the stack. + Pushes a constant value onto the stack. Forth semantics: ( -- n ) }]; - let arguments = (ins Forth_StackType:$input_stack, I64Attr:$value); + let arguments = (ins Forth_StackType:$input_stack, AnyAttr:$value); let results = (outs Forth_StackType:$output_stack); let assemblyFormat = [{ - $input_stack $value attr-dict `:` type($input_stack) `->` type($output_stack) + $input_stack `(` $value `)` attr-dict `:` type($input_stack) `->` type($output_stack) }]; } //===----------------------------------------------------------------------===// -// Arithmetic operations. +// Integer arithmetic operations. //===----------------------------------------------------------------------===// -def Forth_AddOp : Forth_StackOpBase<"add"> { - let summary = "Add top two stack elements"; +def Forth_AddIOp : Forth_StackOpBase<"addi"> { + let summary = "Add top two stack elements (integer)"; let description = [{ Pops the top two elements, adds them, and pushes the result. Forth semantics: ( a b -- a+b ) }]; } -def Forth_SubOp : Forth_StackOpBase<"sub"> { - let summary = "Subtract top two stack elements"; +def Forth_SubIOp : Forth_StackOpBase<"subi"> { + let summary = "Subtract top two stack elements (integer)"; let description = [{ Pops the top two elements, subtracts them (a - b), and pushes the result. Forth semantics: ( a b -- a-b ) }]; } -def Forth_MulOp : Forth_StackOpBase<"mul"> { - let summary = "Multiply top two stack elements"; +def Forth_MulIOp : Forth_StackOpBase<"muli"> { + let summary = "Multiply top two stack elements (integer)"; let description = [{ Pops the top two elements, multiplies them, and pushes the result. Forth semantics: ( a b -- a*b ) }]; } -def Forth_DivOp : Forth_StackOpBase<"div"> { - let summary = "Divide top two stack elements"; +def Forth_DivIOp : Forth_StackOpBase<"divi"> { + let summary = "Divide top two stack elements (integer)"; let description = [{ Pops the top two elements, divides them (a / b), and pushes the result. Forth semantics: ( a b -- a/b ) @@ -176,6 +176,42 @@ def Forth_ModOp : Forth_StackOpBase<"mod"> { }]; } +//===----------------------------------------------------------------------===// +// Float arithmetic operations. +//===----------------------------------------------------------------------===// + +def Forth_AddFOp : Forth_StackOpBase<"addf"> { + let summary = "Add top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, adds, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- f1+f2 ) + }]; +} + +def Forth_SubFOp : Forth_StackOpBase<"subf"> { + let summary = "Subtract top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, subtracts, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- f1-f2 ) + }]; +} + +def Forth_MulFOp : Forth_StackOpBase<"mulf"> { + let summary = "Multiply top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, multiplies, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- f1*f2 ) + }]; +} + +def Forth_DivFOp : Forth_StackOpBase<"divf"> { + let summary = "Divide top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, divides, bitcasts result back to i64. + Forth semantics: ( f1 f2 -- f1/f2 ) + }]; +} + //===----------------------------------------------------------------------===// // Bitwise operations. //===----------------------------------------------------------------------===// @@ -232,38 +268,73 @@ def Forth_RshiftOp : Forth_StackOpBase<"rshift"> { // Memory operations. //===----------------------------------------------------------------------===// -def Forth_LoadOp : Forth_StackOpBase<"load"> { - let summary = "Load value from memory buffer"; +def Forth_LoadIOp : Forth_StackOpBase<"loadi"> { + let summary = "Load i64 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads an i64 value from memory, + and pushes the loaded value onto the stack. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreIOp : Forth_StackOpBase<"storei"> { + let summary = "Store i64 value to memory buffer"; + let description = [{ + Pops an address and value from the stack, stores the i64 value to memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_LoadFOp : Forth_StackOpBase<"loadf"> { + let summary = "Load f64 value from memory buffer"; + let description = [{ + Pops an address from the stack, loads an f64 value from memory, + bitcasts to i64, and pushes onto the stack. + Forth semantics: ( addr -- value ) + }]; +} + +def Forth_StoreFOp : Forth_StackOpBase<"storef"> { + let summary = "Store f64 value to memory buffer"; + let description = [{ + Pops an address and value (i64 bit pattern of f64) from the stack, + bitcasts to f64, stores to memory. + Forth semantics: ( x addr -- ) + }]; +} + +def Forth_SharedLoadIOp : Forth_StackOpBase<"shared_loadi"> { + let summary = "Load i64 value from shared memory buffer"; let description = [{ - Pops an address from the stack, loads a value from the memory buffer at that address, + Pops an address from the stack, loads an i64 value from shared/workgroup memory, and pushes the loaded value onto the stack. Forth semantics: ( addr -- value ) }]; } -def Forth_StoreOp : Forth_StackOpBase<"store"> { - let summary = "Store value to memory buffer"; +def Forth_SharedStoreIOp : Forth_StackOpBase<"shared_storei"> { + let summary = "Store i64 value to shared memory buffer"; let description = [{ - Pops an address and value from the stack, stores the value to the memory buffer - at the specified address. + Pops an address and value from the stack, stores the i64 value to + shared/workgroup memory. Forth semantics: ( x addr -- ) }]; } -def Forth_SharedLoadOp : Forth_StackOpBase<"shared_load"> { - let summary = "Load value from shared memory buffer"; +def Forth_SharedLoadFOp : Forth_StackOpBase<"shared_loadf"> { + let summary = "Load f64 value from shared memory buffer"; let description = [{ - Pops an address from the stack, loads a value from shared/workgroup memory at - that address, and pushes the loaded value onto the stack. + Pops an address from the stack, loads an f64 value from shared/workgroup memory, + bitcasts to i64, and pushes onto the stack. Forth semantics: ( addr -- value ) }]; } -def Forth_SharedStoreOp : Forth_StackOpBase<"shared_store"> { - let summary = "Store value to shared memory buffer"; +def Forth_SharedStoreFOp : Forth_StackOpBase<"shared_storef"> { + let summary = "Store f64 value to shared memory buffer"; let description = [{ - Pops an address and value from the stack, stores the value to shared/workgroup - memory at the specified address. + Pops an address and value (i64 bit pattern of f64) from the stack, + bitcasts to f64, stores to shared/workgroup memory. Forth semantics: ( x addr -- ) }]; } @@ -417,57 +488,115 @@ def Forth_BarrierOp : Forth_Op<"barrier", []> { } //===----------------------------------------------------------------------===// -// Comparison operations. +// Integer comparison operations. //===----------------------------------------------------------------------===// -def Forth_EqOp : Forth_StackOpBase<"eq"> { - let summary = "Test equality of top two stack elements"; +def Forth_EqIOp : Forth_StackOpBase<"eqi"> { + let summary = "Test equality of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if equal, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } -def Forth_LtOp : Forth_StackOpBase<"lt"> { - let summary = "Test less-than of top two stack elements"; +def Forth_LtIOp : Forth_StackOpBase<"lti"> { + let summary = "Test less-than of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if a < b, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } -def Forth_GtOp : Forth_StackOpBase<"gt"> { - let summary = "Test greater-than of top two stack elements"; +def Forth_GtIOp : Forth_StackOpBase<"gti"> { + let summary = "Test greater-than of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if a > b, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } -def Forth_NeOp : Forth_StackOpBase<"ne"> { - let summary = "Test inequality of top two stack elements"; +def Forth_NeIOp : Forth_StackOpBase<"nei"> { + let summary = "Test inequality of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if not equal, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } -def Forth_LeOp : Forth_StackOpBase<"le"> { - let summary = "Test less-than-or-equal of top two stack elements"; +def Forth_LeIOp : Forth_StackOpBase<"lei"> { + let summary = "Test less-than-or-equal of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if a <= b, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } -def Forth_GeOp : Forth_StackOpBase<"ge"> { - let summary = "Test greater-than-or-equal of top two stack elements"; +def Forth_GeIOp : Forth_StackOpBase<"gei"> { + let summary = "Test greater-than-or-equal of top two stack elements (integer)"; let description = [{ Pops two values, pushes -1 (true) if a >= b, 0 (false) otherwise. Forth semantics: ( a b -- flag ) }]; } +//===----------------------------------------------------------------------===// +// Float comparison operations. +//===----------------------------------------------------------------------===// + +def Forth_EqFOp : Forth_StackOpBase<"eqf"> { + let summary = "Test equality of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares for equality (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + +def Forth_LtFOp : Forth_StackOpBase<"ltf"> { + let summary = "Test less-than of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares f1 < f2 (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + +def Forth_GtFOp : Forth_StackOpBase<"gtf"> { + let summary = "Test greater-than of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares f1 > f2 (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + +def Forth_NeFOp : Forth_StackOpBase<"nef"> { + let summary = "Test inequality of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares for inequality (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + +def Forth_LeFOp : Forth_StackOpBase<"lef"> { + let summary = "Test less-than-or-equal of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares f1 <= f2 (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + +def Forth_GeFOp : Forth_StackOpBase<"gef"> { + let summary = "Test greater-than-or-equal of top two stack elements (float)"; + let description = [{ + Pops two i64 values, bitcasts to f64, compares f1 >= f2 (ordered). + Pushes -1 (true) or 0 (false). + Forth semantics: ( f1 f2 -- flag ) + }]; +} + def Forth_ZeroEqOp : Forth_StackOpBase<"zero_eq"> { let summary = "Test if top of stack is zero"; let description = [{ @@ -476,6 +605,28 @@ def Forth_ZeroEqOp : Forth_StackOpBase<"zero_eq"> { }]; } +//===----------------------------------------------------------------------===// +// Type conversion operations. +//===----------------------------------------------------------------------===// + +def Forth_IToFOp : Forth_StackOpBase<"itof"> { + let summary = "Convert integer to float"; + let description = [{ + Pops an i64 integer, converts to f64 via sitofp, bitcasts result + to i64 bit pattern, pushes onto stack. + Forth semantics: ( n -- f ) + }]; +} + +def Forth_FToIOp : Forth_StackOpBase<"ftoi"> { + let summary = "Convert float to integer"; + let description = [{ + Pops an i64 (f64 bit pattern), bitcasts to f64, converts to i64 + via fptosi, pushes onto stack. + Forth semantics: ( f -- n ) + }]; +} + //===----------------------------------------------------------------------===// // Control flow support operations. //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index 7932c87..135d035 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -86,24 +86,35 @@ struct StackOpConversion : public OpConversionPattern { } }; -/// Conversion pattern for forth.literal operation. -/// Increments SP and stores the literal value at the new SP position. -struct LiteralOpConversion : public OpConversionPattern { - LiteralOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for forth.constant operation. +/// Handles both integer and float constants. +struct ConstantOpConversion : public OpConversionPattern { + ConstantOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::LiteralOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::ConstantOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - Value literalValue = rewriter.create( - loc, rewriter.getI64Type(), op.getValueAttr()); - Value newSP = pushValue(loc, rewriter, memref, stackPtr, literalValue); + Value valueToPush; + auto typedValue = cast(op.getValueAttr()); + if (isa(typedValue)) { + // Float: create f64 constant, bitcast to i64 + Value f64Value = rewriter.create( + loc, rewriter.getF64Type(), typedValue); + valueToPush = rewriter.create( + loc, rewriter.getI64Type(), f64Value); + } else { + // Integer: create i64 constant directly + valueToPush = rewriter.create( + loc, rewriter.getI64Type(), typedValue); + } + Value newSP = pushValue(loc, rewriter, memref, stackPtr, valueToPush); rewriter.replaceOpWithMultiple(op, {{memref, newSP}}); return success(); @@ -430,7 +441,8 @@ struct RollOpConversion : public OpConversionPattern { /// Base template for binary arithmetic operations. /// Pops two values, applies operation, pushes result: (a b -- result) -template +/// When IsFloat=true, bitcasts i64->f64 before the op and f64->i64 after. +template struct BinaryArithOpConversion : public OpConversionPattern { BinaryArithOpConversion(const TypeConverter &typeConverter, MLIRContext *context) @@ -453,8 +465,19 @@ struct BinaryArithOpConversion : public OpConversionPattern { Value spMinus1 = rewriter.create(loc, stackPtr, one); Value a = rewriter.create(loc, memref, spMinus1); - // Perform arithmetic operation - Value result = rewriter.create(loc, a, b); + Value result; + if constexpr (IsFloat) { + // Bitcast i64 -> f64 + auto f64Type = rewriter.getF64Type(); + Value aF = rewriter.create(loc, f64Type, a); + Value bF = rewriter.create(loc, f64Type, b); + Value resF = rewriter.create(loc, aF, bF); + // Bitcast f64 -> i64 + result = + rewriter.create(loc, rewriter.getI64Type(), resF); + } else { + result = rewriter.create(loc, a, b); + } // Store result at SP-1 (effectively popping both and pushing result) rewriter.create(loc, result, memref, spMinus1); @@ -465,11 +488,11 @@ struct BinaryArithOpConversion : public OpConversionPattern { } }; -// Instantiate arithmetic operation conversions -using AddOpConversion = BinaryArithOpConversion; -using SubOpConversion = BinaryArithOpConversion; -using MulOpConversion = BinaryArithOpConversion; -using DivOpConversion = BinaryArithOpConversion; +// Integer arithmetic +using AddIOpConversion = BinaryArithOpConversion; +using SubIOpConversion = BinaryArithOpConversion; +using MulIOpConversion = BinaryArithOpConversion; +using DivIOpConversion = BinaryArithOpConversion; using ModOpConversion = BinaryArithOpConversion; using AndOpConversion = BinaryArithOpConversion; using OrOpConversion = BinaryArithOpConversion; @@ -479,9 +502,21 @@ using LshiftOpConversion = using RshiftOpConversion = BinaryArithOpConversion; +// Float arithmetic +using AddFOpConversion = + BinaryArithOpConversion; +using SubFOpConversion = + BinaryArithOpConversion; +using MulFOpConversion = + BinaryArithOpConversion; +using DivFOpConversion = + BinaryArithOpConversion; + /// Base template for binary comparison operations. /// Pops two values, compares, pushes -1 (true) or 0 (false): (a b -- flag) -template +/// When IsFloat=true, bitcasts i64->f64 before comparing. +template struct BinaryCmpOpConversion : public OpConversionPattern { BinaryCmpOpConversion(const TypeConverter &typeConverter, MLIRContext *context) @@ -504,8 +539,15 @@ struct BinaryCmpOpConversion : public OpConversionPattern { Value spMinus1 = rewriter.create(loc, stackPtr, one); Value a = rewriter.create(loc, memref, spMinus1); - // Compare - Value cmp = rewriter.create(loc, predicate, a, b); + Value cmp; + if constexpr (IsFloat) { + auto f64Type = rewriter.getF64Type(); + Value aF = rewriter.create(loc, f64Type, a); + Value bF = rewriter.create(loc, f64Type, b); + cmp = rewriter.create(loc, Predicate, aF, bF); + } else { + cmp = rewriter.create(loc, Predicate, a, b); + } // Extend i1 to i64: true = -1 (all bits set), false = 0 Value result = @@ -520,19 +562,33 @@ struct BinaryCmpOpConversion : public OpConversionPattern { } }; -// Instantiate comparison operation conversions -using EqOpConversion = - BinaryCmpOpConversion; -using LtOpConversion = - BinaryCmpOpConversion; -using GtOpConversion = - BinaryCmpOpConversion; -using NeOpConversion = - BinaryCmpOpConversion; -using LeOpConversion = - BinaryCmpOpConversion; -using GeOpConversion = - BinaryCmpOpConversion; +// Integer comparisons +using EqIOpConversion = BinaryCmpOpConversion; +using LtIOpConversion = BinaryCmpOpConversion; +using GtIOpConversion = BinaryCmpOpConversion; +using NeIOpConversion = BinaryCmpOpConversion; +using LeIOpConversion = BinaryCmpOpConversion; +using GeIOpConversion = BinaryCmpOpConversion; + +// Float comparisons (ordered predicates) +using EqFOpConversion = BinaryCmpOpConversion; +using LtFOpConversion = BinaryCmpOpConversion; +using GtFOpConversion = BinaryCmpOpConversion; +using NeFOpConversion = BinaryCmpOpConversion; +using LeFOpConversion = BinaryCmpOpConversion; +using GeFOpConversion = BinaryCmpOpConversion; /// Conversion pattern for forth.not operation (bitwise NOT). /// Unary: pops one value, XORs with -1 (all bits set), pushes result: (a -- ~a) @@ -643,8 +699,12 @@ struct ParamRefOpConversion : public OpConversionPattern { valueToPush = rewriter.create( loc, rewriter.getI64Type(), ptrIndex); } else if (memrefArg.getType().isInteger(64)) { - // Scalar param: push value directly. + // Scalar i64 param: push value directly. valueToPush = memrefArg; + } else if (memrefArg.getType().isF64()) { + // Scalar f64 param: bitcast to i64 for stack storage. + valueToPush = rewriter.create( + loc, rewriter.getI64Type(), memrefArg); } else { return rewriter.notifyMatchFailure( op, "unsupported param argument type for param_ref"); @@ -658,56 +718,84 @@ struct ParamRefOpConversion : public OpConversionPattern { } }; -/// Conversion pattern for forth.load operation (@). -/// Pops address from stack, loads value via pointer, pushes value: ( addr -- -/// value ) -struct LoadOpConversion : public OpConversionPattern { - LoadOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} - using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; +/// Generalized memory load template. +/// Pops address from stack, loads value via pointer, pushes value. +/// When IsFloat=true, loads f64 from memory and bitcasts to i64 for stack. +/// AddressSpace selects global (0) or workgroup memory. +template +struct MemoryLoadOpConversion : public OpConversionPattern { + MemoryLoadOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::LoadOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto i64Type = rewriter.getI64Type(); - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto ptrType = + LLVM::LLVMPointerType::get(rewriter.getContext(), AddressSpace); // Load address from stack Value addrValue = rewriter.create(loc, memref, stackPtr); // Load value from memory via pointer Value ptr = rewriter.create(loc, ptrType, addrValue); - Value loadedValue = rewriter.create(loc, i64Type, ptr); + + Value valueToPush; + if constexpr (IsFloat) { + // Load f64 from memory, then bitcast to i64 for stack storage + Value loadedF64 = + rewriter.create(loc, rewriter.getF64Type(), ptr); + valueToPush = rewriter.create( + loc, rewriter.getI64Type(), loadedF64); + } else { + valueToPush = + rewriter.create(loc, rewriter.getI64Type(), ptr); + } // Store loaded value back at same position (replaces address) - rewriter.create(loc, loadedValue, memref, stackPtr); + rewriter.create(loc, valueToPush, memref, stackPtr); rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); return success(); } }; -/// Conversion pattern for forth.store operation (!). -/// Pops address and value from stack, stores value to memory: ( x addr -- ) -struct StoreOpConversion : public OpConversionPattern { - StoreOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} - using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; +// Memory load instantiations +using LoadIOpConversion = MemoryLoadOpConversion; +using LoadFOpConversion = MemoryLoadOpConversion; +using SharedLoadIOpConversion = + MemoryLoadOpConversion; +using SharedLoadFOpConversion = + MemoryLoadOpConversion; + +/// Generalized memory store template. +/// Pops address and value from stack, stores value to memory. +/// When IsFloat=true, bitcasts i64->f64 before storing. +template +struct MemoryStoreOpConversion : public OpConversionPattern { + MemoryStoreOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::StoreOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto ptrType = + LLVM::LLVMPointerType::get(rewriter.getContext(), AddressSpace); // Pop address from stack Value addrValue = rewriter.create(loc, memref, stackPtr); @@ -719,7 +807,15 @@ struct StoreOpConversion : public OpConversionPattern { // Store value to memory via pointer Value ptr = rewriter.create(loc, ptrType, addrValue); - rewriter.create(loc, value, ptr); + + if constexpr (IsFloat) { + // Bitcast i64 -> f64 before storing + Value f64Value = + rewriter.create(loc, rewriter.getF64Type(), value); + rewriter.create(loc, f64Value, ptr); + } else { + rewriter.create(loc, value, ptr); + } // New stack pointer is SP-2 (popped both address and value) Value spMinus2 = rewriter.create(loc, spMinus1, one); @@ -728,79 +824,80 @@ struct StoreOpConversion : public OpConversionPattern { } }; -/// Conversion pattern for forth.shared_load operation (S@). -/// Pops address from stack, loads value via shared/workgroup pointer, pushes -/// value: ( addr -- value ) -struct SharedLoadOpConversion - : public OpConversionPattern { - SharedLoadOpConversion(const TypeConverter &typeConverter, - MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +// Memory store instantiations +using StoreIOpConversion = MemoryStoreOpConversion; +using StoreFOpConversion = MemoryStoreOpConversion; +using SharedStoreIOpConversion = + MemoryStoreOpConversion; +using SharedStoreFOpConversion = + MemoryStoreOpConversion; + +/// Conversion pattern for forth.itof (S>F). +/// Pops i64, converts to f64 via sitofp, bitcasts back to i64, pushes. +struct IToFOpConversion : public OpConversionPattern { + IToFOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::SharedLoadOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::IToFOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto i64Type = rewriter.getI64Type(); - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), - kWorkgroupAddressSpace); + // Load i64 value from top of stack + Value i64Value = rewriter.create(loc, memref, stackPtr); - // Load address from stack. - Value addrValue = rewriter.create(loc, memref, stackPtr); + // Convert i64 -> f64 via SIToFPOp + Value f64Value = + rewriter.create(loc, rewriter.getF64Type(), i64Value); - // Load value from shared memory via address-space-qualified pointer. - Value ptr = rewriter.create(loc, ptrType, addrValue); - Value loadedValue = rewriter.create(loc, i64Type, ptr); + // Bitcast f64 -> i64 for stack storage + Value result = + rewriter.create(loc, rewriter.getI64Type(), f64Value); - // Store loaded value back at same position (replaces address). - rewriter.create(loc, loadedValue, memref, stackPtr); + // Store result (SP unchanged — unary op) + rewriter.create(loc, result, memref, stackPtr); rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); return success(); } }; -/// Conversion pattern for forth.shared_store operation (S!). -/// Pops address and value from stack, stores value via shared/workgroup -/// pointer: ( x addr -- ) -struct SharedStoreOpConversion - : public OpConversionPattern { - SharedStoreOpConversion(const TypeConverter &typeConverter, - MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for forth.ftoi (F>S). +/// Pops i64 (f64 bits), bitcasts to f64, converts to i64 via fptosi, pushes. +struct FToIOpConversion : public OpConversionPattern { + FToIOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::SharedStoreOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::FToIOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), - kWorkgroupAddressSpace); + // Load i64 (f64 bit pattern) from top of stack + Value i64Bits = rewriter.create(loc, memref, stackPtr); - // Pop address from stack. - Value addrValue = rewriter.create(loc, memref, stackPtr); + // Bitcast i64 -> f64 + Value f64Value = + rewriter.create(loc, rewriter.getF64Type(), i64Bits); - // Pop value from stack. - Value one = rewriter.create(loc, 1); - Value spMinus1 = rewriter.create(loc, stackPtr, one); - Value value = rewriter.create(loc, memref, spMinus1); + // Convert f64 -> i64 via FPToSIOp + Value result = + rewriter.create(loc, rewriter.getI64Type(), f64Value); - // Store value to shared memory via address-space-qualified pointer. - Value ptr = rewriter.create(loc, ptrType, addrValue); - rewriter.create(loc, value, ptr); + // Store result (SP unchanged — unary op) + rewriter.create(loc, result, memref, stackPtr); - // New stack pointer is SP-2 (popped both address and value). - Value spMinus2 = rewriter.create(loc, spMinus1, one); - rewriter.replaceOpWithMultiple(op, {{memref, spMinus2}}); + rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); return success(); } }; @@ -1098,17 +1195,34 @@ struct ConvertForthToMemRefPass // Add Forth operation conversion patterns patterns.add< - StackOpConversion, LiteralOpConversion, DupOpConversion, + StackOpConversion, ConstantOpConversion, DupOpConversion, DropOpConversion, SwapOpConversion, OverOpConversion, RotOpConversion, NipOpConversion, TuckOpConversion, PickOpConversion, RollOpConversion, - AddOpConversion, SubOpConversion, MulOpConversion, DivOpConversion, - ModOpConversion, AndOpConversion, OrOpConversion, XorOpConversion, - NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion, - LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion, - GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion, - LoadOpConversion, StoreOpConversion, SharedLoadOpConversion, - SharedStoreOpConversion, PopFlagOpConversion, PopOpConversion, - PushValueOpConversion>(typeConverter, context); + // Integer arithmetic + AddIOpConversion, SubIOpConversion, MulIOpConversion, DivIOpConversion, + ModOpConversion, + // Float arithmetic + AddFOpConversion, SubFOpConversion, MulFOpConversion, DivFOpConversion, + // Bitwise + AndOpConversion, OrOpConversion, XorOpConversion, NotOpConversion, + LshiftOpConversion, RshiftOpConversion, + // Integer comparisons + EqIOpConversion, LtIOpConversion, GtIOpConversion, NeIOpConversion, + LeIOpConversion, GeIOpConversion, + // Float comparisons + EqFOpConversion, LtFOpConversion, GtFOpConversion, NeFOpConversion, + LeFOpConversion, GeFOpConversion, + // Other + ZeroEqOpConversion, ParamRefOpConversion, + // Memory ops (int + float, global + shared) + LoadIOpConversion, StoreIOpConversion, LoadFOpConversion, + StoreFOpConversion, SharedLoadIOpConversion, SharedStoreIOpConversion, + SharedLoadFOpConversion, SharedStoreFOpConversion, + // Type conversions + IToFOpConversion, FToIOpConversion, + // Control flow + PopFlagOpConversion, PopOpConversion, PushValueOpConversion>( + typeConverter, context); // Add GPU indexing op conversion patterns patterns.add>(typeConverter, diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index fe4923c..aebb7e8 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -133,6 +133,26 @@ bool ForthLexer::isNumber(const std::string &str) const { return true; } +bool ForthLexer::isFloat(const std::string &str) const { + if (str.empty()) + return false; + + // Try to parse as a double. A valid float must contain a '.' or 'e'/'E'. + bool hasDotOrExp = false; + for (char c : str) { + if (c == '.' || c == 'e' || c == 'E') { + hasDotOrExp = true; + break; + } + } + if (!hasDotOrExp) + return false; + + char *end = nullptr; + std::strtod(str.c_str(), &end); + return end == str.c_str() + str.size(); +} + Token ForthLexer::nextToken() { skipWhitespace(); @@ -158,9 +178,16 @@ Token ForthLexer::nextToken() { } std::string text(tokenStart, curPtr - tokenStart); - Token::Kind kind = isNumber(text) ? Token::Kind::Number : Token::Kind::Word; - if (kind == Token::Kind::Word) + Token::Kind kind; + if (isFloat(text)) { + kind = Token::Kind::Float; + // Don't uppercase float tokens (preserve original text for strtod) + } else if (isNumber(text)) { + kind = Token::Kind::Number; + } else { + kind = Token::Kind::Word; text = toUpperCase(text); + } return Token(kind, text, loc); } @@ -293,6 +320,7 @@ LogicalResult ForthParser::parseHeader() { llvm::StringRef typeToken = tokens[2]; bool isArray = false; int64_t size = 0; + BaseType baseType = BaseType::I64; size_t lbracket = typeToken.find('['); if (lbracket != llvm::StringRef::npos) { size_t rbracket = typeToken.find(']'); @@ -304,7 +332,12 @@ LogicalResult ForthParser::parseHeader() { llvm::StringRef base = typeToken.substr(0, lbracket); llvm::StringRef sizeStr = typeToken.substr(lbracket + 1, rbracket - lbracket - 1); - if (toUpperCase(base) != "I64") { + std::string baseUpper = toUpperCase(base); + if (baseUpper == "I64") { + baseType = BaseType::I64; + } else if (baseUpper == "F64") { + baseType = BaseType::F64; + } else { return emitErrorAt(lineLoc, "unsupported base type: " + base.str()); } if (sizeStr.empty()) @@ -314,7 +347,12 @@ LogicalResult ForthParser::parseHeader() { "array size must be a positive integer"); isArray = true; } else { - if (toUpperCase(typeToken) != "I64") { + std::string typeUpper = toUpperCase(typeToken); + if (typeUpper == "I64") { + baseType = BaseType::I64; + } else if (typeUpper == "F64") { + baseType = BaseType::F64; + } else { return emitErrorAt(lineLoc, "unsupported scalar type: " + typeToken.str()); } @@ -325,12 +363,14 @@ LogicalResult ForthParser::parseHeader() { decl.name = nameUpper; decl.isArray = isArray; decl.size = size; + decl.baseType = baseType; paramDecls.push_back(decl); } else { SharedDecl decl; decl.name = nameUpper; decl.isArray = isArray; decl.size = size; + decl.baseType = baseType; sharedDecls.push_back(decl); } } else { @@ -423,13 +463,13 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, .getResult(0); } - // CELLS: multiply by 8 (sizeof i64) for byte addressing + // CELLS: multiply by 8 (sizeof i64 = sizeof f64) for byte addressing if (word == "CELLS") { Value lit8 = builder - .create(loc, stackType, inputStack, - builder.getI64IntegerAttr(8)) + .create(loc, stackType, inputStack, + builder.getI64IntegerAttr(8)) .getResult(); - return builder.create(loc, stackType, lit8).getResult(); + return builder.create(loc, stackType, lit8).getResult(); } // Built-in operations @@ -458,13 +498,29 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "+" || word == "ADD") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "-" || word == "SUB") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "*" || word == "MUL") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "/" || word == "DIV") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F+") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F-") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F*") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F/") { + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "MOD") { return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "AND") { @@ -482,16 +538,28 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "@") { - return builder.create(loc, stackType, inputStack) + return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "!") { - return builder.create(loc, stackType, inputStack) + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F!") { + return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "S@") { - return builder.create(loc, stackType, inputStack) + return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "S!") { - return builder.create(loc, stackType, inputStack) + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SF@") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "SF!") { + return builder.create(loc, stackType, inputStack) .getResult(); } else if (word == "TID-X") { return builder.create(loc, stackType, inputStack) @@ -536,17 +604,35 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, builder.create(loc); return inputStack; } else if (word == "=") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "<") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == ">") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "<>") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "<=") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == ">=") { - return builder.create(loc, stackType, inputStack).getResult(); + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F=") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F<") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F>") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F<>") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F<=") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "F>=") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "S>F") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "F>S") { + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "0=") { return builder.create(loc, stackType, inputStack) .getResult(); @@ -629,8 +715,16 @@ LogicalResult ForthParser::parseBody(Value &stack) { Location tokenLoc = getLoc(); int64_t value = std::stoll(currentToken.text); stack = builder - .create(tokenLoc, stackType, stack, - builder.getI64IntegerAttr(value)) + .create(tokenLoc, stackType, stack, + builder.getI64IntegerAttr(value)) + .getResult(); + consume(); + } else if (currentToken.kind == Token::Kind::Float) { + Location tokenLoc = getLoc(); + double value = std::stod(currentToken.text); + stack = builder + .create(tokenLoc, stackType, stack, + builder.getF64FloatAttr(value)) .getResult(); consume(); } else if (currentToken.kind == Token::Kind::Word) { @@ -1031,10 +1125,13 @@ OwningOpRef ForthParser::parseModule() { // Build function argument types from param declarations SmallVector argTypes; for (const auto ¶m : paramDecls) { + Type elemType = param.baseType == BaseType::F64 + ? Type(builder.getF64Type()) + : Type(builder.getI64Type()); if (param.isArray) { - argTypes.push_back(MemRefType::get({param.size}, builder.getI64Type())); + argTypes.push_back(MemRefType::get({param.size}, elemType)); } else { - argTypes.push_back(builder.getI64Type()); + argTypes.push_back(elemType); } } @@ -1056,7 +1153,10 @@ OwningOpRef ForthParser::parseModule() { // Emit shared memory allocations at kernel entry for (const auto &shared : sharedDecls) { int64_t size = shared.isArray ? shared.size : 1; - auto memrefType = MemRefType::get({size}, builder.getI64Type()); + Type elemType = shared.baseType == BaseType::F64 + ? Type(builder.getF64Type()) + : Type(builder.getI64Type()); + auto memrefType = MemRefType::get({size}, elemType); Value alloca = builder.create(loc, memrefType); alloca.getDefiningOp()->setAttr("forth.shared_name", builder.getStringAttr(shared.name)); diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.h b/lib/Translation/ForthToMLIR/ForthToMLIR.h index 025c8a4..6aa4809 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.h +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.h @@ -18,11 +18,15 @@ namespace mlir { namespace forth { +/// Base element type for param/shared declarations. +enum class BaseType { I64, F64 }; + /// A declared kernel parameter: `param `. struct ParamDecl { std::string name; bool isArray = false; int64_t size = 0; + BaseType baseType = BaseType::I64; }; /// A declared shared memory region: `shared `. @@ -30,11 +34,12 @@ struct SharedDecl { std::string name; bool isArray = false; int64_t size = 0; + BaseType baseType = BaseType::I64; }; /// Simple token representing a Forth word or literal. struct Token { - enum class Kind { Number, Word, Colon, Semicolon, EndOfFile }; + enum class Kind { Number, Float, Word, Colon, Semicolon, EndOfFile }; Kind kind; std::string text; @@ -73,6 +78,9 @@ class ForthLexer { /// Check if a string is a number. bool isNumber(const std::string &str) const; + + /// Check if a string is a floating-point number. + bool isFloat(const std::string &str) const; }; /// Parser and translator for Forth source code to MLIR. diff --git a/test/Conversion/ForthToMemRef/arithmetic.mlir b/test/Conversion/ForthToMemRef/arithmetic.mlir index 9d3b218..a0bbf4b 100644 --- a/test/Conversion/ForthToMemRef/arithmetic.mlir +++ b/test/Conversion/ForthToMemRef/arithmetic.mlir @@ -36,16 +36,16 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.literal %1 20 : !forth.stack -> !forth.stack - %3 = forth.add %2 : !forth.stack -> !forth.stack - %4 = forth.literal %3 3 : !forth.stack -> !forth.stack - %5 = forth.sub %4 : !forth.stack -> !forth.stack - %6 = forth.literal %5 4 : !forth.stack -> !forth.stack - %7 = forth.mul %6 : !forth.stack -> !forth.stack - %8 = forth.literal %7 2 : !forth.stack -> !forth.stack - %9 = forth.div %8 : !forth.stack -> !forth.stack - %10 = forth.literal %9 5 : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(20 : i64) : !forth.stack -> !forth.stack + %3 = forth.addi %2 : !forth.stack -> !forth.stack + %4 = forth.constant %3(3 : i64) : !forth.stack -> !forth.stack + %5 = forth.subi %4 : !forth.stack -> !forth.stack + %6 = forth.constant %5(4 : i64) : !forth.stack -> !forth.stack + %7 = forth.muli %6 : !forth.stack -> !forth.stack + %8 = forth.constant %7(2 : i64) : !forth.stack -> !forth.stack + %9 = forth.divi %8 : !forth.stack -> !forth.stack + %10 = forth.constant %9(5 : i64) : !forth.stack -> !forth.stack %11 = forth.mod %10 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/begin-until.mlir b/test/Conversion/ForthToMemRef/begin-until.mlir index 1df7aeb..e9b5b28 100644 --- a/test/Conversion/ForthToMemRef/begin-until.mlir +++ b/test/Conversion/ForthToMemRef/begin-until.mlir @@ -31,11 +31,11 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack cf.br ^bb1(%1 : !forth.stack) ^bb1(%2: !forth.stack): - %3 = forth.literal %2 1 : !forth.stack -> !forth.stack - %4 = forth.sub %3 : !forth.stack -> !forth.stack + %3 = forth.constant %2(1 : i64) : !forth.stack -> !forth.stack + %4 = forth.subi %3 : !forth.stack -> !forth.stack %5 = forth.dup %4 : !forth.stack -> !forth.stack %6 = forth.zero_eq %5 : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %6 : !forth.stack -> !forth.stack, i1 diff --git a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir index ba23318..99c5c2e 100644 --- a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir +++ b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir @@ -38,17 +38,17 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack cf.br ^bb1(%1 : !forth.stack) ^bb1(%2: !forth.stack): %3 = forth.dup %2 : !forth.stack -> !forth.stack - %4 = forth.literal %3 0 : !forth.stack -> !forth.stack - %5 = forth.gt %4 : !forth.stack -> !forth.stack + %4 = forth.constant %3(0 : i64) : !forth.stack -> !forth.stack + %5 = forth.gti %4 : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %5 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb2(%output_stack : !forth.stack), ^bb3(%output_stack : !forth.stack) ^bb2(%6: !forth.stack): - %7 = forth.literal %6 1 : !forth.stack -> !forth.stack - %8 = forth.sub %7 : !forth.stack -> !forth.stack + %7 = forth.constant %6(1 : i64) : !forth.stack -> !forth.stack + %8 = forth.subi %7 : !forth.stack -> !forth.stack cf.br ^bb1(%8 : !forth.stack) ^bb3(%9: !forth.stack): return diff --git a/test/Conversion/ForthToMemRef/bitwise.mlir b/test/Conversion/ForthToMemRef/bitwise.mlir index be8ea2c..243546a 100644 --- a/test/Conversion/ForthToMemRef/bitwise.mlir +++ b/test/Conversion/ForthToMemRef/bitwise.mlir @@ -42,22 +42,22 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 3 : !forth.stack -> !forth.stack - %2 = forth.literal %1 5 : !forth.stack -> !forth.stack + %1 = forth.constant %0(3 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(5 : i64) : !forth.stack -> !forth.stack %3 = forth.and %2 : !forth.stack -> !forth.stack - %4 = forth.literal %3 7 : !forth.stack -> !forth.stack - %5 = forth.literal %4 8 : !forth.stack -> !forth.stack + %4 = forth.constant %3(7 : i64) : !forth.stack -> !forth.stack + %5 = forth.constant %4(8 : i64) : !forth.stack -> !forth.stack %6 = forth.or %5 : !forth.stack -> !forth.stack - %7 = forth.literal %6 15 : !forth.stack -> !forth.stack - %8 = forth.literal %7 3 : !forth.stack -> !forth.stack + %7 = forth.constant %6(15 : i64) : !forth.stack -> !forth.stack + %8 = forth.constant %7(3 : i64) : !forth.stack -> !forth.stack %9 = forth.xor %8 : !forth.stack -> !forth.stack - %10 = forth.literal %9 42 : !forth.stack -> !forth.stack + %10 = forth.constant %9(42 : i64) : !forth.stack -> !forth.stack %11 = forth.not %10 : !forth.stack -> !forth.stack - %12 = forth.literal %11 1 : !forth.stack -> !forth.stack - %13 = forth.literal %12 4 : !forth.stack -> !forth.stack + %12 = forth.constant %11(1 : i64) : !forth.stack -> !forth.stack + %13 = forth.constant %12(4 : i64) : !forth.stack -> !forth.stack %14 = forth.lshift %13 : !forth.stack -> !forth.stack - %15 = forth.literal %14 256 : !forth.stack -> !forth.stack - %16 = forth.literal %15 2 : !forth.stack -> !forth.stack + %15 = forth.constant %14(256 : i64) : !forth.stack -> !forth.stack + %16 = forth.constant %15(2 : i64) : !forth.stack -> !forth.stack %17 = forth.rshift %16 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/comparison.mlir b/test/Conversion/ForthToMemRef/comparison.mlir index 2a03b13..4a546b2 100644 --- a/test/Conversion/ForthToMemRef/comparison.mlir +++ b/test/Conversion/ForthToMemRef/comparison.mlir @@ -55,26 +55,26 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.literal %1 2 : !forth.stack -> !forth.stack - %3 = forth.eq %2 : !forth.stack -> !forth.stack - %4 = forth.literal %3 3 : !forth.stack -> !forth.stack - %5 = forth.literal %4 4 : !forth.stack -> !forth.stack - %6 = forth.lt %5 : !forth.stack -> !forth.stack - %7 = forth.literal %6 5 : !forth.stack -> !forth.stack - %8 = forth.literal %7 6 : !forth.stack -> !forth.stack - %9 = forth.gt %8 : !forth.stack -> !forth.stack - %10 = forth.literal %9 0 : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack + %3 = forth.eqi %2 : !forth.stack -> !forth.stack + %4 = forth.constant %3(3 : i64) : !forth.stack -> !forth.stack + %5 = forth.constant %4(4 : i64) : !forth.stack -> !forth.stack + %6 = forth.lti %5 : !forth.stack -> !forth.stack + %7 = forth.constant %6(5 : i64) : !forth.stack -> !forth.stack + %8 = forth.constant %7(6 : i64) : !forth.stack -> !forth.stack + %9 = forth.gti %8 : !forth.stack -> !forth.stack + %10 = forth.constant %9(0 : i64) : !forth.stack -> !forth.stack %11 = forth.zero_eq %10 : !forth.stack -> !forth.stack - %12 = forth.literal %11 7 : !forth.stack -> !forth.stack - %13 = forth.literal %12 8 : !forth.stack -> !forth.stack - %14 = forth.ne %13 : !forth.stack -> !forth.stack - %15 = forth.literal %14 9 : !forth.stack -> !forth.stack - %16 = forth.literal %15 10 : !forth.stack -> !forth.stack - %17 = forth.le %16 : !forth.stack -> !forth.stack - %18 = forth.literal %17 11 : !forth.stack -> !forth.stack - %19 = forth.literal %18 12 : !forth.stack -> !forth.stack - %20 = forth.ge %19 : !forth.stack -> !forth.stack + %12 = forth.constant %11(7 : i64) : !forth.stack -> !forth.stack + %13 = forth.constant %12(8 : i64) : !forth.stack -> !forth.stack + %14 = forth.nei %13 : !forth.stack -> !forth.stack + %15 = forth.constant %14(9 : i64) : !forth.stack -> !forth.stack + %16 = forth.constant %15(10 : i64) : !forth.stack -> !forth.stack + %17 = forth.lei %16 : !forth.stack -> !forth.stack + %18 = forth.constant %17(11 : i64) : !forth.stack -> !forth.stack + %19 = forth.constant %18(12 : i64) : !forth.stack -> !forth.stack + %20 = forth.gei %19 : !forth.stack -> !forth.stack return } } diff --git a/test/Conversion/ForthToMemRef/control-flow.mlir b/test/Conversion/ForthToMemRef/control-flow.mlir index 121c0a8..a43493a 100644 --- a/test/Conversion/ForthToMemRef/control-flow.mlir +++ b/test/Conversion/ForthToMemRef/control-flow.mlir @@ -47,21 +47,21 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %1 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%2: !forth.stack): - %3 = forth.literal %2 42 : !forth.stack -> !forth.stack + %3 = forth.constant %2(42 : i64) : !forth.stack -> !forth.stack cf.br ^bb3(%3 : !forth.stack) ^bb2(%4: !forth.stack): - %5 = forth.literal %4 99 : !forth.stack -> !forth.stack + %5 = forth.constant %4(99 : i64) : !forth.stack -> !forth.stack cf.br ^bb3(%5 : !forth.stack) ^bb3(%6: !forth.stack): - %7 = forth.literal %6 0 : !forth.stack -> !forth.stack + %7 = forth.constant %6(0 : i64) : !forth.stack -> !forth.stack %output_stack_0, %flag_1 = forth.pop_flag %7 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) ^bb4(%8: !forth.stack): - %9 = forth.literal %8 7 : !forth.stack -> !forth.stack + %9 = forth.constant %8(7 : i64) : !forth.stack -> !forth.stack cf.br ^bb5(%9 : !forth.stack) ^bb5(%10: !forth.stack): return diff --git a/test/Conversion/ForthToMemRef/do-loop.mlir b/test/Conversion/ForthToMemRef/do-loop.mlir index 034807a..b196aff 100644 --- a/test/Conversion/ForthToMemRef/do-loop.mlir +++ b/test/Conversion/ForthToMemRef/do-loop.mlir @@ -42,8 +42,8 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.literal %1 0 : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(0 : i64) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %2 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> diff --git a/test/Conversion/ForthToMemRef/float-arithmetic.mlir b/test/Conversion/ForthToMemRef/float-arithmetic.mlir new file mode 100644 index 0000000..638b9b6 --- /dev/null +++ b/test/Conversion/ForthToMemRef/float-arithmetic.mlir @@ -0,0 +1,44 @@ +// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s + +// CHECK-LABEL: func.func private @main + +// addf: pop two, bitcast to f64, arith.addf, bitcast back, store +// CHECK: memref.load +// CHECK: arith.subi +// CHECK: memref.load +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// subf: bitcast, subf, bitcast +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.subf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 + +// mulf: bitcast, mulf, bitcast +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.mulf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 + +// divf: bitcast, divf, bitcast +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: arith.divf %{{.*}}, %{{.*}} : f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 + +module { + func.func private @main() { + %0 = forth.stack !forth.stack + %1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2.000000e+00 : f64) : !forth.stack -> !forth.stack + %3 = forth.addf %2 : !forth.stack -> !forth.stack + %4 = forth.subf %3 : !forth.stack -> !forth.stack + %5 = forth.mulf %4 : !forth.stack -> !forth.stack + %6 = forth.divf %5 : !forth.stack -> !forth.stack + return + } +} diff --git a/test/Conversion/ForthToMemRef/float-memory.mlir b/test/Conversion/ForthToMemRef/float-memory.mlir new file mode 100644 index 0000000..76c3462 --- /dev/null +++ b/test/Conversion/ForthToMemRef/float-memory.mlir @@ -0,0 +1,29 @@ +// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s + +// CHECK-LABEL: func.func private @main + +// loadf: load addr, inttoptr, llvm.load f64, bitcast f64->i64, store +// CHECK: memref.load +// CHECK: llvm.inttoptr +// CHECK: llvm.load %{{.*}} : !llvm.ptr -> f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: memref.store + +// storef: load addr, load value, inttoptr, bitcast i64->f64, llvm.store +// CHECK: memref.load +// CHECK: memref.load +// CHECK: llvm.inttoptr +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +// CHECK: llvm.store %{{.*}}, %{{.*}} : f64 + +module { + func.func private @main() { + %0 = forth.stack !forth.stack + %1 = forth.constant %0(1000 : i64) : !forth.stack -> !forth.stack + %2 = forth.loadf %1 : !forth.stack -> !forth.stack + %3 = forth.constant %2(3.140000e+00 : f64) : !forth.stack -> !forth.stack + %4 = forth.constant %3(2000 : i64) : !forth.stack -> !forth.stack + %5 = forth.storef %4 : !forth.stack -> !forth.stack + return + } +} diff --git a/test/Conversion/ForthToMemRef/leave.mlir b/test/Conversion/ForthToMemRef/leave.mlir index 2c93ddc..d010ed0 100644 --- a/test/Conversion/ForthToMemRef/leave.mlir +++ b/test/Conversion/ForthToMemRef/leave.mlir @@ -14,8 +14,8 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.literal %1 0 : !forth.stack -> !forth.stack + %1 = forth.constant %0(10 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(0 : i64) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %2 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> diff --git a/test/Conversion/ForthToMemRef/literal.mlir b/test/Conversion/ForthToMemRef/literal.mlir index d6dc84f..fce38e4 100644 --- a/test/Conversion/ForthToMemRef/literal.mlir +++ b/test/Conversion/ForthToMemRef/literal.mlir @@ -11,7 +11,7 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 42 : !forth.stack -> !forth.stack + %1 = forth.constant %0(42 : i64) : !forth.stack -> !forth.stack return } } diff --git a/test/Conversion/ForthToMemRef/memory-ops.mlir b/test/Conversion/ForthToMemRef/memory-ops.mlir index d4965ef..976f39e 100644 --- a/test/Conversion/ForthToMemRef/memory-ops.mlir +++ b/test/Conversion/ForthToMemRef/memory-ops.mlir @@ -26,16 +26,16 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.load %1 : !forth.stack -> !forth.stack - %3 = forth.literal %2 42 : !forth.stack -> !forth.stack - %4 = forth.literal %3 100 : !forth.stack -> !forth.stack - %5 = forth.store %4 : !forth.stack -> !forth.stack - %6 = forth.literal %5 2 : !forth.stack -> !forth.stack - %7 = forth.shared_load %6 : !forth.stack -> !forth.stack - %8 = forth.literal %7 9 : !forth.stack -> !forth.stack - %9 = forth.literal %8 3 : !forth.stack -> !forth.stack - %10 = forth.shared_store %9 : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack + %2 = forth.loadi %1 : !forth.stack -> !forth.stack + %3 = forth.constant %2(42 : i64) : !forth.stack -> !forth.stack + %4 = forth.constant %3(100 : i64) : !forth.stack -> !forth.stack + %5 = forth.storei %4 : !forth.stack -> !forth.stack + %6 = forth.constant %5(2 : i64) : !forth.stack -> !forth.stack + %7 = forth.shared_loadi %6 : !forth.stack -> !forth.stack + %8 = forth.constant %7(9 : i64) : !forth.stack -> !forth.stack + %9 = forth.constant %8(3 : i64) : !forth.stack -> !forth.stack + %10 = forth.shared_storei %9 : !forth.stack -> !forth.stack return } } diff --git a/test/Conversion/ForthToMemRef/nested-control-flow.mlir b/test/Conversion/ForthToMemRef/nested-control-flow.mlir index 4ac0dab..8c09f33 100644 --- a/test/Conversion/ForthToMemRef/nested-control-flow.mlir +++ b/test/Conversion/ForthToMemRef/nested-control-flow.mlir @@ -160,24 +160,24 @@ module { func.func private @TEST__NESTED__IF(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.literal %arg0 1 : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(1 : i64) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%1: !forth.stack): - %2 = forth.literal %1 2 : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack %output_stack_0, %flag_1 = forth.pop_flag %2 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb3(%output_stack_0 : !forth.stack), ^bb4(%output_stack_0 : !forth.stack) ^bb2(%3: !forth.stack): return %3 : !forth.stack ^bb3(%4: !forth.stack): - %5 = forth.literal %4 3 : !forth.stack -> !forth.stack + %5 = forth.constant %4(3 : i64) : !forth.stack -> !forth.stack cf.br ^bb4(%5 : !forth.stack) ^bb4(%6: !forth.stack): cf.br ^bb2(%6 : !forth.stack) } func.func private @TEST__IF__INSIDE__DO(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.literal %arg0 10 : !forth.stack -> !forth.stack - %1 = forth.literal %0 0 : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(10 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(0 : i64) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> @@ -193,8 +193,8 @@ module { %c0_3 = arith.constant 0 : index %6 = memref.load %alloca[%c0_3] : memref<1xi64> %7 = forth.push_value %5, %6 : !forth.stack, i64 -> !forth.stack - %8 = forth.literal %7 5 : !forth.stack -> !forth.stack - %9 = forth.gt %8 : !forth.stack -> !forth.stack + %8 = forth.constant %7(5 : i64) : !forth.stack -> !forth.stack + %9 = forth.gti %8 : !forth.stack -> !forth.stack %output_stack_4, %flag = forth.pop_flag %9 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb4(%output_stack_4 : !forth.stack), ^bb5(%output_stack_4 : !forth.stack) ^bb3(%10: !forth.stack): @@ -213,8 +213,8 @@ module { cf.br ^bb1(%14 : !forth.stack) } func.func private @TEST__NESTED__DO__J(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.literal %arg0 3 : !forth.stack -> !forth.stack - %1 = forth.literal %0 0 : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(3 : i64) : !forth.stack -> !forth.stack + %1 = forth.constant %0(0 : i64) : !forth.stack -> !forth.stack %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 %alloca = memref.alloca() : memref<1xi64> @@ -227,8 +227,8 @@ module { %4 = arith.cmpi slt, %3, %value_1 : i64 cf.cond_br %4, ^bb2(%2 : !forth.stack), ^bb3(%2 : !forth.stack) ^bb2(%5: !forth.stack): - %6 = forth.literal %5 4 : !forth.stack -> !forth.stack - %7 = forth.literal %6 0 : !forth.stack -> !forth.stack + %6 = forth.constant %5(4 : i64) : !forth.stack -> !forth.stack + %7 = forth.constant %6(0 : i64) : !forth.stack -> !forth.stack %output_stack_3, %value_4 = forth.pop %7 : !forth.stack -> !forth.stack, i64 %output_stack_5, %value_6 = forth.pop %output_stack_3 : !forth.stack -> !forth.stack, i64 %alloca_7 = memref.alloca() : memref<1xi64> @@ -249,7 +249,7 @@ module { %c0_11 = arith.constant 0 : index %15 = memref.load %alloca_7[%c0_11] : memref<1xi64> %16 = forth.push_value %14, %15 : !forth.stack, i64 -> !forth.stack - %17 = forth.add %16 : !forth.stack -> !forth.stack + %17 = forth.addi %16 : !forth.stack -> !forth.stack %c0_12 = arith.constant 0 : index %18 = memref.load %alloca_7[%c0_12] : memref<1xi64> %c1_i64 = arith.constant 1 : i64 @@ -265,7 +265,7 @@ module { cf.br ^bb1(%20 : !forth.stack) } func.func private @TEST__WHILE__INSIDE__IF(%arg0: !forth.stack) -> !forth.stack { - %0 = forth.literal %arg0 5 : !forth.stack -> !forth.stack + %0 = forth.constant %arg0(5 : i64) : !forth.stack -> !forth.stack %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) ^bb1(%1: !forth.stack): @@ -277,8 +277,8 @@ module { %output_stack_0, %flag_1 = forth.pop_flag %4 : !forth.stack -> !forth.stack, i1 cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) ^bb4(%5: !forth.stack): - %6 = forth.literal %5 1 : !forth.stack -> !forth.stack - %7 = forth.sub %6 : !forth.stack -> !forth.stack + %6 = forth.constant %5(1 : i64) : !forth.stack -> !forth.stack + %7 = forth.subi %6 : !forth.stack -> !forth.stack cf.br ^bb3(%7 : !forth.stack) ^bb5(%8: !forth.stack): cf.br ^bb2(%8 : !forth.stack) diff --git a/test/Conversion/ForthToMemRef/stack-manipulation.mlir b/test/Conversion/ForthToMemRef/stack-manipulation.mlir index e8a1721..4da0597 100644 --- a/test/Conversion/ForthToMemRef/stack-manipulation.mlir +++ b/test/Conversion/ForthToMemRef/stack-manipulation.mlir @@ -76,9 +76,9 @@ module { func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.literal %1 2 : !forth.stack -> !forth.stack - %3 = forth.literal %2 3 : !forth.stack -> !forth.stack + %1 = forth.constant %0(1 : i64) : !forth.stack -> !forth.stack + %2 = forth.constant %1(2 : i64) : !forth.stack -> !forth.stack + %3 = forth.constant %2(3 : i64) : !forth.stack -> !forth.stack %4 = forth.dup %3 : !forth.stack -> !forth.stack %5 = forth.drop %4 : !forth.stack -> !forth.stack %6 = forth.swap %5 : !forth.stack -> !forth.stack @@ -86,9 +86,9 @@ module { %8 = forth.rot %7 : !forth.stack -> !forth.stack %9 = forth.nip %8 : !forth.stack -> !forth.stack %10 = forth.tuck %9 : !forth.stack -> !forth.stack - %11 = forth.literal %10 2 : !forth.stack -> !forth.stack + %11 = forth.constant %10(2 : i64) : !forth.stack -> !forth.stack %12 = forth.pick %11 : !forth.stack -> !forth.stack - %13 = forth.literal %12 2 : !forth.stack -> !forth.stack + %13 = forth.constant %12(2 : i64) : !forth.stack -> !forth.stack %14 = forth.roll %13 : !forth.stack -> !forth.stack return } diff --git a/test/Conversion/ForthToMemRef/user-defined-words.mlir b/test/Conversion/ForthToMemRef/user-defined-words.mlir index c7bf11e..6634934 100644 --- a/test/Conversion/ForthToMemRef/user-defined-words.mlir +++ b/test/Conversion/ForthToMemRef/user-defined-words.mlir @@ -14,12 +14,12 @@ module { func.func private @double(%arg0: !forth.stack) -> !forth.stack { %0 = forth.dup %arg0 : !forth.stack -> !forth.stack - %1 = forth.add %0 : !forth.stack -> !forth.stack + %1 = forth.addi %0 : !forth.stack -> !forth.stack return %1 : !forth.stack } func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 5 : !forth.stack -> !forth.stack + %1 = forth.constant %0(5 : i64) : !forth.stack -> !forth.stack %2 = call @double(%1) : (!forth.stack) -> !forth.stack return } diff --git a/test/Pipeline/float-pipeline.forth b/test/Pipeline/float-pipeline.forth new file mode 100644 index 0000000..2036b97 --- /dev/null +++ b/test/Pipeline/float-pipeline.forth @@ -0,0 +1,20 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID + +\ Verify that Forth with f64 params through the full pipeline produces a gpu.binary +\ CHECK: gpu.binary @warpforth_module + +\ Verify intermediate MLIR structure at the memref+gpu stage +\ MID: gpu.module @warpforth_module +\ MID: gpu.func @main(%arg0: memref<256xf64> {forth.param_name = "DATA"}, %arg1: f64 {forth.param_name = "SCALE"}) kernel +\ MID: memref.alloca() : memref<256xi64> +\ MID: memref.extract_aligned_pointer_as_index %arg0 +\ MID: arith.bitcast %{{.*}} : f64 to i64 +\ MID: gpu.return + +\! kernel main +\! param DATA f64[256] +\! param SCALE f64 +GLOBAL-ID CELLS DATA + F@ +SCALE F* +GLOBAL-ID CELLS DATA + F! diff --git a/test/Translation/Forth/arithmetic-ops.forth b/test/Translation/Forth/arithmetic-ops.forth index 4f767e2..7bbfef6 100644 --- a/test/Translation/Forth/arithmetic-ops.forth +++ b/test/Translation/Forth/arithmetic-ops.forth @@ -2,12 +2,12 @@ \ Verify SSA chaining: each op consumes the previous stack value \ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] -\ CHECK: %[[S2:.*]] = forth.literal %[[S1]] -\ CHECK: %[[S3:.*]] = forth.add %[[S2]] -\ CHECK: %[[S4:.*]] = forth.sub %[[S3]] -\ CHECK: %[[S5:.*]] = forth.mul %[[S4]] -\ CHECK: %[[S6:.*]] = forth.div %[[S5]] +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.constant %[[S1]] +\ CHECK: %[[S3:.*]] = forth.addi %[[S2]] +\ CHECK: %[[S4:.*]] = forth.subi %[[S3]] +\ CHECK: %[[S5:.*]] = forth.muli %[[S4]] +\ CHECK: %[[S6:.*]] = forth.divi %[[S5]] \ CHECK: %{{.*}} = forth.mod %[[S6]] \! kernel main 1 2 + - * / MOD diff --git a/test/Translation/Forth/basic-literals.forth b/test/Translation/Forth/basic-literals.forth index 9cff6b5..e739e93 100644 --- a/test/Translation/Forth/basic-literals.forth +++ b/test/Translation/Forth/basic-literals.forth @@ -1,8 +1,8 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ CHECK: forth.stack -\ CHECK-NEXT: forth.literal %{{.*}} 42 -\ CHECK-NEXT: forth.literal %{{.*}} -7 -\ CHECK-NEXT: forth.literal %{{.*}} 0 +\ CHECK-NEXT: forth.constant %{{.*}}(42 : i64) +\ CHECK-NEXT: forth.constant %{{.*}}(-7 : i64) +\ CHECK-NEXT: forth.constant %{{.*}}(0 : i64) \! kernel main 42 -7 0 diff --git a/test/Translation/Forth/begin-until.forth b/test/Translation/Forth/begin-until.forth index fec8b73..ceca20d 100644 --- a/test/Translation/Forth/begin-until.forth +++ b/test/Translation/Forth/begin-until.forth @@ -3,11 +3,11 @@ \ Verify BEGIN/UNTIL generates loop with pop_flag + cond_br \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L1:.*]] = forth.literal %[[B1]] 1 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[SUB:.*]] = forth.sub %[[L1]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B1]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[SUB:.*]] = forth.subi %[[L1]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[SUB]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[ZEQ:.*]] = forth.zero_eq %[[DUP]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF:.*]], %[[FLAG:.*]] = forth.pop_flag %[[ZEQ]] : !forth.stack -> !forth.stack, i1 diff --git a/test/Translation/Forth/begin-while-repeat.forth b/test/Translation/Forth/begin-while-repeat.forth index 90a2a96..3a9e40d 100644 --- a/test/Translation/Forth/begin-while-repeat.forth +++ b/test/Translation/Forth/begin-while-repeat.forth @@ -3,17 +3,17 @@ \ Verify BEGIN/WHILE/REPEAT generates condition check + body loop with cond_br \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): \ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[B1]] : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[L0:.*]] = forth.literal %[[DUP]] 0 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[GT:.*]] = forth.gt %[[L0]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0:.*]] = forth.constant %[[DUP]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[GT:.*]] = forth.gti %[[L0]] : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF:.*]], %[[FLAG:.*]] = forth.pop_flag %[[GT]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG]], ^bb2(%[[PF]] : !forth.stack), ^bb3(%[[PF]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L1:.*]] = forth.literal %[[B2]] 1 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[SUB:.*]] = forth.sub %[[L1]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L1:.*]] = forth.constant %[[B2]](1 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[SUB:.*]] = forth.subi %[[L1]] : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb1(%[[SUB]] : !forth.stack) \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/bitwise-ops.forth b/test/Translation/Forth/bitwise-ops.forth index cdf7c9e..7749a20 100644 --- a/test/Translation/Forth/bitwise-ops.forth +++ b/test/Translation/Forth/bitwise-ops.forth @@ -2,22 +2,22 @@ \ Verify bitwise operations parse correctly with SSA chaining \ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] -\ CHECK: %[[S2:.*]] = forth.literal %[[S1]] +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.constant %[[S1]] \ CHECK: %[[S3:.*]] = forth.and %[[S2]] -\ CHECK: %[[S4:.*]] = forth.literal %[[S3]] -\ CHECK: %[[S5:.*]] = forth.literal %[[S4]] +\ CHECK: %[[S4:.*]] = forth.constant %[[S3]] +\ CHECK: %[[S5:.*]] = forth.constant %[[S4]] \ CHECK: %[[S6:.*]] = forth.or %[[S5]] -\ CHECK: %[[S7:.*]] = forth.literal %[[S6]] -\ CHECK: %[[S8:.*]] = forth.literal %[[S7]] +\ CHECK: %[[S7:.*]] = forth.constant %[[S6]] +\ CHECK: %[[S8:.*]] = forth.constant %[[S7]] \ CHECK: %[[S9:.*]] = forth.xor %[[S8]] -\ CHECK: %[[S10:.*]] = forth.literal %[[S9]] +\ CHECK: %[[S10:.*]] = forth.constant %[[S9]] \ CHECK: %[[S11:.*]] = forth.not %[[S10]] -\ CHECK: %[[S12:.*]] = forth.literal %[[S11]] -\ CHECK: %[[S13:.*]] = forth.literal %[[S12]] +\ CHECK: %[[S12:.*]] = forth.constant %[[S11]] +\ CHECK: %[[S13:.*]] = forth.constant %[[S12]] \ CHECK: %[[S14:.*]] = forth.lshift %[[S13]] -\ CHECK: %[[S15:.*]] = forth.literal %[[S14]] -\ CHECK: %[[S16:.*]] = forth.literal %[[S15]] +\ CHECK: %[[S15:.*]] = forth.constant %[[S14]] +\ CHECK: %[[S16:.*]] = forth.constant %[[S15]] \ CHECK: %{{.*}} = forth.rshift %[[S16]] \! kernel main 3 5 AND 7 8 OR 15 3 XOR 42 NOT 1 4 LSHIFT 256 2 RSHIFT diff --git a/test/Translation/Forth/case-insensitive.forth b/test/Translation/Forth/case-insensitive.forth index c545f15..2021721 100644 --- a/test/Translation/Forth/case-insensitive.forth +++ b/test/Translation/Forth/case-insensitive.forth @@ -4,6 +4,6 @@ \ CHECK: forth.dup \ CHECK: forth.drop \ CHECK: forth.swap -\ CHECK: forth.add +\ CHECK: forth.addi \! kernel main 1 Dup DROP swap duP + diff --git a/test/Translation/Forth/comparison-ops.forth b/test/Translation/Forth/comparison-ops.forth index 804b5bb..bf1d09e 100644 --- a/test/Translation/Forth/comparison-ops.forth +++ b/test/Translation/Forth/comparison-ops.forth @@ -2,25 +2,25 @@ \ Verify comparison operations parse correctly \ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] -\ CHECK: %[[S2:.*]] = forth.literal %[[S1]] -\ CHECK: %[[S3:.*]] = forth.eq %[[S2]] -\ CHECK: %[[S4:.*]] = forth.literal %[[S3]] -\ CHECK: %[[S5:.*]] = forth.literal %[[S4]] -\ CHECK: %[[S6:.*]] = forth.lt %[[S5]] -\ CHECK: %[[S7:.*]] = forth.literal %[[S6]] -\ CHECK: %[[S8:.*]] = forth.literal %[[S7]] -\ CHECK: %[[S9:.*]] = forth.gt %[[S8]] -\ CHECK: %[[S10:.*]] = forth.literal %[[S9]] +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.constant %[[S1]] +\ CHECK: %[[S3:.*]] = forth.eqi %[[S2]] +\ CHECK: %[[S4:.*]] = forth.constant %[[S3]] +\ CHECK: %[[S5:.*]] = forth.constant %[[S4]] +\ CHECK: %[[S6:.*]] = forth.lti %[[S5]] +\ CHECK: %[[S7:.*]] = forth.constant %[[S6]] +\ CHECK: %[[S8:.*]] = forth.constant %[[S7]] +\ CHECK: %[[S9:.*]] = forth.gti %[[S8]] +\ CHECK: %[[S10:.*]] = forth.constant %[[S9]] \ CHECK: %[[S11:.*]] = forth.zero_eq %[[S10]] -\ CHECK: %[[S12:.*]] = forth.literal %[[S11]] -\ CHECK: %[[S13:.*]] = forth.literal %[[S12]] -\ CHECK: %[[S14:.*]] = forth.ne %[[S13]] -\ CHECK: %[[S15:.*]] = forth.literal %[[S14]] -\ CHECK: %[[S16:.*]] = forth.literal %[[S15]] -\ CHECK: %[[S17:.*]] = forth.le %[[S16]] -\ CHECK: %[[S18:.*]] = forth.literal %[[S17]] -\ CHECK: %[[S19:.*]] = forth.literal %[[S18]] -\ CHECK: %{{.*}} = forth.ge %[[S19]] +\ CHECK: %[[S12:.*]] = forth.constant %[[S11]] +\ CHECK: %[[S13:.*]] = forth.constant %[[S12]] +\ CHECK: %[[S14:.*]] = forth.nei %[[S13]] +\ CHECK: %[[S15:.*]] = forth.constant %[[S14]] +\ CHECK: %[[S16:.*]] = forth.constant %[[S15]] +\ CHECK: %[[S17:.*]] = forth.lei %[[S16]] +\ CHECK: %[[S18:.*]] = forth.constant %[[S17]] +\ CHECK: %[[S19:.*]] = forth.constant %[[S18]] +\ CHECK: %{{.*}} = forth.gei %[[S19]] \! kernel main 1 2 = 3 4 < 5 6 > 0 0= 7 8 <> 9 10 <= 11 12 >= diff --git a/test/Translation/Forth/control-flow.forth b/test/Translation/Forth/control-flow.forth index f8a39a2..460fbfe 100644 --- a/test/Translation/Forth/control-flow.forth +++ b/test/Translation/Forth/control-flow.forth @@ -4,25 +4,25 @@ \ Basic IF/ELSE/THEN \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF1:.*]], %[[FLAG1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L42:.*]] = forth.literal %[[B1]] 42 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L42:.*]] = forth.constant %[[B1]](42 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb3(%[[L42]] : !forth.stack) \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L99:.*]] = forth.literal %[[B2]] 99 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L99:.*]] = forth.constant %[[B2]](99 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb3(%[[L99]] : !forth.stack) \! kernel main 1 IF 42 ELSE 99 THEN \ Basic IF/THEN (no ELSE - fallthrough on false) \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): -\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[B3]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[B3]](0 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF2:.*]], %[[FLAG2:.*]] = forth.pop_flag %[[S2]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FLAG2]], ^bb4(%[[PF2]] : !forth.stack), ^bb5(%[[PF2]] : !forth.stack) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L7:.*]] = forth.literal %[[B4]] 7 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L7:.*]] = forth.constant %[[B4]](7 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb5(%[[L7]] : !forth.stack) \ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): \ CHECK-NEXT: return diff --git a/test/Translation/Forth/do-loop.forth b/test/Translation/Forth/do-loop.forth index 39e8e0f..b922258 100644 --- a/test/Translation/Forth/do-loop.forth +++ b/test/Translation/Forth/do-loop.forth @@ -3,8 +3,8 @@ \ Verify DO/LOOP generates post-test loop with crossing test \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[S1]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi64> diff --git a/test/Translation/Forth/float-arithmetic.forth b/test/Translation/Forth/float-arithmetic.forth new file mode 100644 index 0000000..1a79e7a --- /dev/null +++ b/test/Translation/Forth/float-arithmetic.forth @@ -0,0 +1,12 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Verify float arithmetic ops parse correctly +\ CHECK: %[[S0:.*]] = forth.stack +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.constant %[[S1]] +\ CHECK: %[[S3:.*]] = forth.addf %[[S2]] +\ CHECK: %[[S4:.*]] = forth.subf %[[S3]] +\ CHECK: %[[S5:.*]] = forth.mulf %[[S4]] +\ CHECK: %{{.*}} = forth.divf %[[S5]] +\! kernel main +1.0 2.0 F+ F- F* F/ diff --git a/test/Translation/Forth/float-comparison.forth b/test/Translation/Forth/float-comparison.forth new file mode 100644 index 0000000..b206a0a --- /dev/null +++ b/test/Translation/Forth/float-comparison.forth @@ -0,0 +1,24 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Verify float comparison operations parse correctly +\ CHECK: %[[S0:.*]] = forth.stack +\ CHECK: %[[S1:.*]] = forth.constant %[[S0]] +\ CHECK: %[[S2:.*]] = forth.constant %[[S1]] +\ CHECK: %[[S3:.*]] = forth.eqf %[[S2]] +\ CHECK: %[[S4:.*]] = forth.constant %[[S3]] +\ CHECK: %[[S5:.*]] = forth.constant %[[S4]] +\ CHECK: %[[S6:.*]] = forth.ltf %[[S5]] +\ CHECK: %[[S7:.*]] = forth.constant %[[S6]] +\ CHECK: %[[S8:.*]] = forth.constant %[[S7]] +\ CHECK: %[[S9:.*]] = forth.gtf %[[S8]] +\ CHECK: %[[S10:.*]] = forth.constant %[[S9]] +\ CHECK: %[[S11:.*]] = forth.constant %[[S10]] +\ CHECK: %[[S12:.*]] = forth.nef %[[S11]] +\ CHECK: %[[S13:.*]] = forth.constant %[[S12]] +\ CHECK: %[[S14:.*]] = forth.constant %[[S13]] +\ CHECK: %[[S15:.*]] = forth.lef %[[S14]] +\ CHECK: %[[S16:.*]] = forth.constant %[[S15]] +\ CHECK: %[[S17:.*]] = forth.constant %[[S16]] +\ CHECK: %{{.*}} = forth.gef %[[S17]] +\! kernel main +1.0 2.0 F= 3.0 4.0 F< 5.0 6.0 F> 7.0 8.0 F<> 9.0 10.0 F<= 11.0 12.0 F>= diff --git a/test/Translation/Forth/float-conversion.forth b/test/Translation/Forth/float-conversion.forth new file mode 100644 index 0000000..667bb2a --- /dev/null +++ b/test/Translation/Forth/float-conversion.forth @@ -0,0 +1,9 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Test S>F produces forth.itof +\ CHECK: forth.itof %{{.*}} : !forth.stack -> !forth.stack + +\ Test F>S produces forth.ftoi +\ CHECK: forth.ftoi %{{.*}} : !forth.stack -> !forth.stack +\! kernel main +42 S>F F>S diff --git a/test/Translation/Forth/float-literals.forth b/test/Translation/Forth/float-literals.forth new file mode 100644 index 0000000..31d8f16 --- /dev/null +++ b/test/Translation/Forth/float-literals.forth @@ -0,0 +1,9 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ CHECK: forth.stack +\ CHECK-NEXT: forth.constant %{{.*}}(3.140000e+00 : f64) +\ CHECK-NEXT: forth.constant %{{.*}}(-2.000000e+00 : f64) +\ CHECK-NEXT: forth.constant %{{.*}}(1.000000e-05 : f64) +\ CHECK-NEXT: forth.constant %{{.*}}(1.000000e+03 : f64) +\! kernel main +3.14 -2.0 1.0e-5 1e3 diff --git a/test/Translation/Forth/float-memory.forth b/test/Translation/Forth/float-memory.forth new file mode 100644 index 0000000..e93b1dc --- /dev/null +++ b/test/Translation/Forth/float-memory.forth @@ -0,0 +1,15 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Test F@ produces forth.loadf +\ CHECK: forth.loadf %{{.*}} : !forth.stack -> !forth.stack + +\ Test F! produces forth.storef +\ CHECK: forth.storef %{{.*}} : !forth.stack -> !forth.stack + +\ Test SF@ produces forth.shared_loadf +\ CHECK: forth.shared_loadf %{{.*}} : !forth.stack -> !forth.stack + +\ Test SF! produces forth.shared_storef +\ CHECK: forth.shared_storef %{{.*}} : !forth.stack -> !forth.stack +\! kernel main +1 F@ 2.0 3 F! 4 SF@ 5.0 6 SF! diff --git a/test/Translation/Forth/float-params.forth b/test/Translation/Forth/float-params.forth new file mode 100644 index 0000000..aed8e9e --- /dev/null +++ b/test/Translation/Forth/float-params.forth @@ -0,0 +1,12 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Check f64 scalar param becomes f64 function argument +\ CHECK: func.func private @main(%arg0: memref<256xf64> {forth.param_name = "DATA"}, %arg1: f64 {forth.param_name = "SCALE"}) + +\ Check param refs work +\ CHECK: forth.param_ref %{{.*}} "DATA" +\ CHECK: forth.param_ref %{{.*}} "SCALE" +\! kernel main +\! param DATA f64[256] +\! param SCALE f64 +DATA SCALE diff --git a/test/Translation/Forth/interleaved-control-flow.forth b/test/Translation/Forth/interleaved-control-flow.forth index a6fe597..6865d05 100644 --- a/test/Translation/Forth/interleaved-control-flow.forth +++ b/test/Translation/Forth/interleaved-control-flow.forth @@ -15,15 +15,15 @@ \ Loop header: DUP 10 > → WHILE(1) \ CHECK: ^bb1(%[[H:.*]]: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 10 -\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.constant %{{.*}}(10 : i64) +\ CHECK-NEXT: %{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) \ WHILE(1) body: DUP 2 MOD 0= → WHILE(2) \ CHECK: ^bb2(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 2 +\ CHECK: forth.constant %{{.*}}(2 : i64) \ CHECK-NEXT: %{{.*}} = forth.mod \ CHECK-NEXT: %{{.*}} = forth.zero_eq \ CHECK: forth.pop_flag @@ -35,8 +35,8 @@ \ WHILE(2) body: 1 - → REPEAT (branch back to loop header) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.literal %[[B4]] 1 -\ CHECK-NEXT: %{{.*}} = forth.sub +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B4]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.subi \ CHECK-NEXT: cf.br ^bb1 \ WHILE(2) exit: DROP → THEN (branch to WHILE(1) exit) @@ -59,19 +59,19 @@ \ Loop header: DUP 0 > → WHILE \ CHECK: ^bb1(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 0 -\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.constant %{{.*}}(0 : i64) +\ CHECK-NEXT: %{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) \ WHILE body + UNTIL: 1 - DUP 5 = UNTIL \ UNTIL true exits to ^bb4, UNTIL false loops back to ^bb1 \ CHECK: ^bb2(%[[W:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.literal %[[W]] 1 -\ CHECK-NEXT: %{{.*}} = forth.sub +\ CHECK-NEXT: %{{.*}} = forth.constant %[[W]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.subi \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 5 -\ CHECK-NEXT: %{{.*}} = forth.eq +\ CHECK: forth.constant %{{.*}}(5 : i64) +\ CHECK-NEXT: %{{.*}} = forth.eqi \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4(%{{.*}} : !forth.stack), ^bb1(%{{.*}} : !forth.stack) diff --git a/test/Translation/Forth/leave.forth b/test/Translation/Forth/leave.forth index 528f6fc..1da159e 100644 --- a/test/Translation/Forth/leave.forth +++ b/test/Translation/Forth/leave.forth @@ -3,8 +3,8 @@ \ Verify LEAVE branches to the loop exit block. \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[S1]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack \ CHECK: cf.br ^bb1(%{{.*}} : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): \ CHECK-NEXT: %[[TRUE:.*]] = arith.constant true diff --git a/test/Translation/Forth/memory-ops.forth b/test/Translation/Forth/memory-ops.forth index 449cc42..0d6afe3 100644 --- a/test/Translation/Forth/memory-ops.forth +++ b/test/Translation/Forth/memory-ops.forth @@ -1,20 +1,20 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Test @ produces forth.load -\ CHECK: forth.load %{{.*}} : !forth.stack -> !forth.stack +\ Test @ produces forth.loadi +\ CHECK: forth.loadi %{{.*}} : !forth.stack -> !forth.stack -\ Test ! produces forth.store -\ CHECK: forth.store %{{.*}} : !forth.stack -> !forth.stack +\ Test ! produces forth.storei +\ CHECK: forth.storei %{{.*}} : !forth.stack -> !forth.stack -\ Test S@ produces forth.shared_load -\ CHECK: forth.shared_load %{{.*}} : !forth.stack -> !forth.stack +\ Test S@ produces forth.shared_loadi +\ CHECK: forth.shared_loadi %{{.*}} : !forth.stack -> !forth.stack -\ Test S! produces forth.shared_store -\ CHECK: forth.shared_store %{{.*}} : !forth.stack -> !forth.stack +\ Test S! produces forth.shared_storei +\ CHECK: forth.shared_storei %{{.*}} : !forth.stack -> !forth.stack \ Test CELLS produces literal 8 + mul -\ CHECK: forth.literal %{{.*}} 8 -\ CHECK-NEXT: forth.mul +\ CHECK: forth.constant %{{.*}}(8 : i64) +\ CHECK-NEXT: forth.muli \! kernel main 1 @ 2 3 ! 4 S@ 5 6 S! 4 CELLS diff --git a/test/Translation/Forth/nested-control-flow.forth b/test/Translation/Forth/nested-control-flow.forth index 730b372..26cbca1 100644 --- a/test/Translation/Forth/nested-control-flow.forth +++ b/test/Translation/Forth/nested-control-flow.forth @@ -2,11 +2,11 @@ \ === Nested IF === \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](1 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF1:.*]], %[[FL1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FL1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L2:.*]] = forth.literal %[[B1]] 2 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L2:.*]] = forth.constant %[[B1]](2 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[PF2:.*]], %[[FL2:.*]] = forth.pop_flag %[[L2]] : !forth.stack -> !forth.stack, i1 \ CHECK-NEXT: cf.cond_br %[[FL2]], ^bb3(%[[PF2]] : !forth.stack), ^bb4(%[[PF2]] : !forth.stack) \! kernel main @@ -15,8 +15,8 @@ \ === IF inside DO === \ After IF/THEN merge, set up DO loop: 10 0 DO \ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L10:.*]] = forth.literal %[[B2]] 10 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[L0A:.*]] = forth.literal %[[L10]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L10:.*]] = forth.constant %[[B2]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0A:.*]] = forth.constant %[[L10]](0 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[POP1:.*]], %[[V1:.*]] = forth.pop %[[L0A]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[POP2:.*]], %[[V2:.*]] = forth.pop %[[POP1]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloca() : memref<1xi64> @@ -27,7 +27,7 @@ \ Nested IF: true branch pushes 3, then merges \ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): -\ CHECK-NEXT: %[[L3:.*]] = forth.literal %[[B3]] 3 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L3:.*]] = forth.constant %[[B3]](3 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: cf.br ^bb4(%[[L3]] : !forth.stack) \ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): \ CHECK-NEXT: cf.br ^bb2(%[[B4]] : !forth.stack) @@ -35,15 +35,15 @@ \ DO loop body (post-test: no check block): I 5 > IF I THEN \ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): \ CHECK: forth.push_value %[[B5]] -\ CHECK: forth.literal %{{.*}} 5 -\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.constant %{{.*}}(5 : i64) +\ CHECK-NEXT: %{{.*}} = forth.gti \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br %{{[^,]*}}, ^bb7(%{{[^)]*}} : !forth.stack), ^bb8(%{{[^)]*}} : !forth.stack) \ === Nested DO with J === \ After first DO loop exits: sets up nested DO (3 0 DO) \ CHECK: ^bb6(%[[B6:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.literal %[[B6]] 3 +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B6]](3 : i64) 3 0 DO 4 0 DO J I + LOOP LOOP \ IF I true branch: push loop index @@ -61,8 +61,8 @@ \ Outer DO body (3 0 DO) with inner DO setup (4 0 DO) \ CHECK: ^bb9(%{{.*}}: !forth.stack): -\ CHECK: forth.literal %{{.*}} 4 -\ CHECK: forth.literal %{{.*}} 0 +\ CHECK: forth.constant %{{.*}}(4 : i64) +\ CHECK: forth.constant %{{.*}}(0 : i64) \ CHECK: forth.pop \ CHECK: forth.pop \ CHECK: memref.alloca() @@ -70,14 +70,14 @@ \ === Triple-nested DO with K === \ After nested DO exits: sets up triple-nested DO (2 0 DO) \ CHECK: ^bb10(%{{.*}}: !forth.stack): -\ CHECK: forth.literal %{{.*}} 2 +\ CHECK: forth.constant %{{.*}}(2 : i64) 2 0 DO 2 0 DO 2 0 DO K J I + + LOOP LOOP LOOP \ Inner loop of J I + (bb11 body) \ CHECK: ^bb11(%{{.*}}: !forth.stack): \ CHECK: forth.push_value \ CHECK: forth.push_value -\ CHECK: forth.add +\ CHECK: forth.addi \ Inner loop crossing test \ CHECK: arith.xori @@ -93,13 +93,13 @@ \ Triple-nested outer loop body (bb13) \ CHECK: ^bb13(%{{.*}}: !forth.stack): -\ CHECK: forth.literal %{{.*}} 2 -\ CHECK: forth.literal %{{.*}} 0 +\ CHECK: forth.constant %{{.*}}(2 : i64) +\ CHECK: forth.constant %{{.*}}(0 : i64) \ === BEGIN/WHILE inside IF === \ After triple-nested exits: 5 IF BEGIN DUP WHILE 1 - REPEAT THEN \ CHECK: ^bb14(%{{.*}}: !forth.stack): -\ CHECK: forth.literal %{{.*}} 5 +\ CHECK: forth.constant %{{.*}}(5 : i64) \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br 5 IF BEGIN DUP WHILE 1 - REPEAT THEN @@ -120,26 +120,26 @@ \ WHILE body: 1 - \ CHECK: ^bb22(%[[B22:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.literal %[[B22]] 1 -\ CHECK-NEXT: %{{.*}} = forth.sub +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B22]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.subi \ === IF inside BEGIN/UNTIL === \ BEGIN/UNTIL header: DUP 10 < \ CHECK: ^bb24(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 10 -\ CHECK-NEXT: %{{.*}} = forth.lt +\ CHECK: forth.constant %{{.*}}(10 : i64) +\ CHECK-NEXT: %{{.*}} = forth.lti BEGIN DUP 10 < IF 1 + THEN DUP 20 = UNTIL \ IF true branch: 1 + \ CHECK: ^bb25(%[[B25:.*]]: !forth.stack): -\ CHECK-NEXT: %{{.*}} = forth.literal %[[B25]] 1 -\ CHECK-NEXT: %{{.*}} = forth.add +\ CHECK-NEXT: %{{.*}} = forth.constant %[[B25]](1 : i64) +\ CHECK-NEXT: %{{.*}} = forth.addi \ UNTIL condition: DUP 20 = \ CHECK: ^bb26(%{{.*}}: !forth.stack): \ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 20 -\ CHECK-NEXT: %{{.*}} = forth.eq +\ CHECK: forth.constant %{{.*}}(20 : i64) +\ CHECK-NEXT: %{{.*}} = forth.eqi \ CHECK: forth.pop_flag \ CHECK-NEXT: cf.cond_br diff --git a/test/Translation/Forth/plus-loop-negative.forth b/test/Translation/Forth/plus-loop-negative.forth index a087ba8..0ef5e94 100644 --- a/test/Translation/Forth/plus-loop-negative.forth +++ b/test/Translation/Forth/plus-loop-negative.forth @@ -3,13 +3,13 @@ \ Verify +LOOP with negative step uses crossing test (handles negative direction) \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 0 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[S1]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](0 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](10 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 \ CHECK: cf.br ^bb1(%[[OS2]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK: %[[STEP_S:.*]] = forth.literal %[[B1]] -1 : !forth.stack -> !forth.stack +\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](-1 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[POP_S:.*]], %[[STEP:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 \ CHECK: %[[OLD:.*]] = memref.load \ CHECK: %[[NEW:.*]] = arith.addi %[[OLD]], %[[STEP]] : i64 diff --git a/test/Translation/Forth/plus-loop.forth b/test/Translation/Forth/plus-loop.forth index c31ac49..ed2f791 100644 --- a/test/Translation/Forth/plus-loop.forth +++ b/test/Translation/Forth/plus-loop.forth @@ -3,8 +3,8 @@ \ Verify +LOOP pops step from data stack and uses it as increment \ CHECK: %[[S0:.*]] = forth.stack !forth.stack -\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack -\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[S1]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.constant %[[S0]](10 : i64) : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.constant %[[S1]](0 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi64> @@ -12,7 +12,7 @@ \ CHECK-NEXT: memref.store %[[VAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi64> \ CHECK-NEXT: cf.br ^bb1(%[[OS2]] : !forth.stack) \ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): -\ CHECK: %[[STEP_S:.*]] = forth.literal %[[B1]] 2 : !forth.stack -> !forth.stack +\ CHECK: %[[STEP_S:.*]] = forth.constant %[[B1]](2 : i64) : !forth.stack -> !forth.stack \ CHECK-NEXT: %[[POP_S:.*]], %[[STEP:.*]] = forth.pop %[[STEP_S]] : !forth.stack -> !forth.stack, i64 \ CHECK-NEXT: %[[C0_2:.*]] = arith.constant 0 : index \ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi64> diff --git a/test/Translation/Forth/unloop-exit.forth b/test/Translation/Forth/unloop-exit.forth index c492fa6..268bc8e 100644 --- a/test/Translation/Forth/unloop-exit.forth +++ b/test/Translation/Forth/unloop-exit.forth @@ -6,7 +6,7 @@ \ CHECK: memref.alloca \ CHECK: cf.br ^bb[[#BODY:]] \ CHECK: ^bb[[#BODY]](%{{.*}}: !forth.stack): -\ CHECK: forth.eq +\ CHECK: forth.eqi \ CHECK: cf.cond_br %{{.*}}, ^bb[[#THEN:]](%{{.*}}), ^bb[[#ENDIF:]](%{{.*}}) \ CHECK: ^bb[[#EXIT:]](%{{.*}}: !forth.stack): \ CHECK: return diff --git a/test/Translation/Forth/word-definitions.forth b/test/Translation/Forth/word-definitions.forth index 590940c..939c162 100644 --- a/test/Translation/Forth/word-definitions.forth +++ b/test/Translation/Forth/word-definitions.forth @@ -2,7 +2,7 @@ \ CHECK: func.func private @DOUBLE(%arg0: !forth.stack) -> !forth.stack { \ CHECK: forth.dup -\ CHECK: forth.add +\ CHECK: forth.addi \ CHECK: return %{{.*}} : !forth.stack \ CHECK: } \ CHECK: func.func private @main() diff --git a/tools/warpforth-runner/warpforth-runner.cpp b/tools/warpforth-runner/warpforth-runner.cpp index 9f4f98c..8a46a41 100644 --- a/tools/warpforth-runner/warpforth-runner.cpp +++ b/tools/warpforth-runner/warpforth-runner.cpp @@ -5,7 +5,7 @@ /// -std=c++17`. /// /// Usage: -/// warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 \ +/// warpforth-runner kernel.ptx --param i64[]:1,2,3 --param f64:3.14 \ /// --grid 4,1,1 --block 64,1,1 --kernel main \ /// --output-param 0 --output-count 3 @@ -15,11 +15,14 @@ #include #include #include +#include #include #include #include #include #include +#include +#include #include #define CHECK_CU(call) \ @@ -34,15 +37,55 @@ } \ } while (0) -enum class ParamKind { Array, Scalar }; - -struct Param { - ParamKind kind = ParamKind::Array; - std::vector values; +template struct ArrayParam { + std::vector values; CUdeviceptr devicePtr = 0; - int64_t scalarValue = 0; }; +template struct ScalarParam { + T value; +}; + +using Param = std::variant, ArrayParam, + ScalarParam, ScalarParam>; + +template static void allocDevice(ArrayParam &arr) { + size_t bytes = arr.values.size() * sizeof(T); + CHECK_CU(cuMemAlloc(&arr.devicePtr, bytes)); + CHECK_CU(cuMemcpyHtoD(arr.devicePtr, arr.values.data(), bytes)); +} + +template +static void printOutput(ArrayParam &arr, size_t count) { + std::vector output(arr.values.size()); + CHECK_CU(cuMemcpyDtoH(output.data(), arr.devicePtr, + arr.values.size() * sizeof(T))); + for (size_t i = 0; i < count; ++i) { + if (i > 0) + std::cout << ","; + if constexpr (std::is_floating_point_v) + std::cout << std::setprecision(17) << output[i]; + else + std::cout << output[i]; + } + std::cout << "\n"; +} + +static void *kernelArgPtr(Param &p) { + if (auto *a = std::get_if>(&p)) + return &a->devicePtr; + if (auto *a = std::get_if>(&p)) + return &a->devicePtr; + if (auto *s = std::get_if>(&p)) + return &s->value; + return &std::get>(p).value; +} + +static bool isScalar(const Param &p) { + return std::holds_alternative>(p) || + std::holds_alternative>(p); +} + struct Dims { unsigned x = 1, y = 1, z = 1; }; @@ -84,14 +127,12 @@ static Dims parseDims(std::string_view s) { } static Param parseParam(std::string_view s) { - Param p; std::string input(s); - // Find the type prefix (everything before the first ':') auto colonPos = input.find(':'); if (colonPos == std::string::npos) { std::cerr << "Error: --param requires type prefix (e.g. i64:42 or " - "i64[]:1,2,3), got: " + "f64[]:1.0,2.0), got: " << s << "\n"; exit(1); } @@ -105,30 +146,40 @@ static Param parseParam(std::string_view s) { exit(1); } - // Determine kind from type prefix - if (typePrefix == "i64[]") { - p.kind = ParamKind::Array; - } else if (typePrefix == "i64") { - p.kind = ParamKind::Scalar; - } else { - std::cerr << "Error: unsupported param type '" << typePrefix - << "' (expected i64 or i64[]), got: " << s << "\n"; - exit(1); - } + // Parse comma-separated values into a typed vector + auto parseValues = [&](auto convert) { + using T = decltype(convert(std::string{})); + std::vector vals; + std::istringstream iss(valueStr); + std::string token; + while (std::getline(iss, token, ',')) + vals.push_back(convert(token)); + return vals; + }; + + auto toI64 = [](const std::string &s) -> int64_t { return std::stoll(s); }; + auto toF64 = [](const std::string &s) -> double { return std::stod(s); }; - // Parse values - std::istringstream iss(valueStr); - std::string token; - while (std::getline(iss, token, ',')) - p.values.push_back(std::stoll(token)); + if (typePrefix == "i64[]") + return Param{ArrayParam{parseValues(toI64)}}; + if (typePrefix == "f64[]") + return Param{ArrayParam{parseValues(toF64)}}; - if (p.kind == ParamKind::Scalar && p.values.size() != 1) { + // Scalars — must be exactly one value + if (valueStr.find(',') != std::string::npos) { std::cerr << "Error: scalar param expects exactly one value, got: " << s << "\n"; exit(1); } - return p; + if (typePrefix == "i64") + return Param{ScalarParam{std::stoll(valueStr)}}; + if (typePrefix == "f64") + return Param{ScalarParam{std::stod(valueStr)}}; + + std::cerr << "Error: unsupported param type '" << typePrefix + << "' (expected i64, i64[], f64, or f64[]), got: " << s << "\n"; + exit(1); } static std::string readFile(std::string_view path) { @@ -187,7 +238,8 @@ int main(int argc, char **argv) { if (!ptxFile) { std::cerr << "Usage: warpforth-runner kernel.ptx --kernel NAME " - "[--param i64[]:V,...] [--param i64:V] [--grid X,Y,Z] " + "[--param i64[]:V,...] [--param f64[]:V,...] " + "[--param i64:V] [--param f64:V] [--grid X,Y,Z] " "[--block X,Y,Z] [--output-param N] [--output-count N]\n"; return 1; } @@ -208,7 +260,7 @@ int main(int argc, char **argv) { return 1; } - if (params[outputParam].kind == ParamKind::Scalar) { + if (isScalar(params[outputParam])) { std::cerr << "Error: output-param " << outputParam << " is a scalar (cannot read back)\n"; return 1; @@ -233,24 +285,18 @@ int main(int argc, char **argv) { CUfunction func; CHECK_CU(cuModuleGetFunction(&func, module, kernelName)); - // Allocate device buffers (arrays) or store scalar values + // Allocate device buffers for array params for (auto &p : params) { - if (p.kind == ParamKind::Array) { - size_t bytes = p.values.size() * sizeof(int64_t); - CHECK_CU(cuMemAlloc(&p.devicePtr, bytes)); - CHECK_CU(cuMemcpyHtoD(p.devicePtr, p.values.data(), bytes)); - } else { - p.scalarValue = p.values[0]; - } + if (auto *a = std::get_if>(&p)) + allocDevice(*a); + else if (auto *a = std::get_if>(&p)) + allocDevice(*a); } // Set up kernel parameters — Driver API expects array of pointers to args std::vector kernelArgs(params.size()); - for (size_t i = 0; i < params.size(); ++i) { - kernelArgs[i] = (params[i].kind == ParamKind::Array) - ? static_cast(¶ms[i].devicePtr) - : static_cast(¶ms[i].scalarValue); - } + for (size_t i = 0; i < params.size(); ++i) + kernelArgs[i] = kernelArgPtr(params[i]); // Launch kernel CHECK_CU(cuLaunchKernel(func, grid.x, grid.y, grid.z, block.x, block.y, @@ -258,25 +304,25 @@ int main(int argc, char **argv) { CHECK_CU(cuCtxSynchronize()); - // Copy back output param - size_t outSize = params[outputParam].values.size(); - std::vector output(outSize); - CHECK_CU(cuMemcpyDtoH(output.data(), params[outputParam].devicePtr, - outSize * sizeof(int64_t))); - - // Print CSV to stdout - size_t count = outputCount >= 0 ? static_cast(outputCount) : outSize; - for (size_t i = 0; i < count; ++i) { - if (i > 0) - std::cout << ","; - std::cout << output[i]; + // Copy back and print output param + size_t count = outputCount >= 0 ? static_cast(outputCount) : 0; + if (auto *iArr = std::get_if>(¶ms[outputParam])) { + if (outputCount < 0) + count = iArr->values.size(); + printOutput(*iArr, count); + } else { + auto &fArr = std::get>(params[outputParam]); + if (outputCount < 0) + count = fArr.values.size(); + printOutput(fArr, count); } - std::cout << "\n"; // Cleanup — only free device memory for array params for (auto &p : params) { - if (p.kind == ParamKind::Array) - cuMemFree(p.devicePtr); + if (auto *a = std::get_if>(&p)) + cuMemFree(a->devicePtr); + else if (auto *a = std::get_if>(&p)) + cuMemFree(a->devicePtr); } cuModuleUnload(module); cuCtxDestroy(ctx); From 36b2c0ec28db56af3ed8424135812df81453a25d Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 02:02:58 +0900 Subject: [PATCH 2/4] test: add tiled f64 matmul GPU test with shared memory --- gpu_test/test_kernels.py | 173 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index 564dc82..f7b64e6 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -399,6 +399,68 @@ def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None: assert result == expected +def test_tiled_matmul_f64(kernel_runner: KernelRunner) -> None: + """Tiled f64 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4). + + Uses 2x2 tiles, float shared memory for A/B tiles, and BARRIER for sync. + Grid: (2,2,1), Block: (2,2,1) — 4 blocks of 4 threads each. + """ + result = kernel_runner.run( + forth_source=( + "\\! kernel main\n" + "\\! param A f64[16]\n" + "\\! param B f64[16]\n" + "\\! param C f64[16]\n" + "\\! shared SA f64[4]\n" + "\\! shared SB f64[4]\n" + "BID-Y 2 * TID-Y +\n" + "BID-X 2 * TID-X +\n" + "0.0\n" + "2 0 DO\n" + " 2 PICK 4 * I 2 * + TID-X + CELLS A + F@\n" + " TID-Y 2 * TID-X + CELLS SA + SF!\n" + " I 2 * TID-Y + 4 * 2 PICK + CELLS B + F@\n" + " TID-Y 2 * TID-X + CELLS SB + SF!\n" + " BARRIER\n" + " 2 0 DO\n" + " TID-Y 2 * I + CELLS SA + SF@\n" + " I 2 * TID-X + CELLS SB + SF@\n" + " F* F+\n" + " LOOP\n" + " BARRIER\n" + "LOOP\n" + "ROT 4 * ROT + CELLS C + F!" + ), + params={ + "A": [1.5, 2.0, 0.5, 3.0, 4.0, 1.5, 2.5, 0.5, 0.5, 3.0, 1.0, 2.0, 2.0, 0.5, 3.5, 1.5], + "B": [1.0, 0.5, 2.0, 1.5, 3.0, 1.0, 0.5, 2.0, 0.5, 2.5, 1.0, 0.5, 2.0, 1.5, 3.0, 1.0], + }, + grid=(2, 2, 1), + block=(2, 2, 1), + output_param=2, + output_count=16, + ) + expected = [ + 13.75, + 8.5, + 13.5, + 9.5, + 10.75, + 10.5, + 12.75, + 10.75, + 14.0, + 8.75, + 9.5, + 9.25, + 8.25, + 12.5, + 12.25, + 7.25, + ] + assert result == [pytest.approx(v) for v in expected] + + # --- User-Defined Words --- @@ -410,3 +472,114 @@ def test_user_defined_word(kernel_runner: KernelRunner) -> None: ), ) assert result[0] == 10 + + +# --- Float Arithmetic --- + + +def test_float_addition(kernel_runner: KernelRunner) -> None: + """F+: 1.5 + 2.5 = 4.0.""" + result = kernel_runner.run( + forth_source="\\! kernel main\n\\! param DATA f64[256]\n1.5 2.5 F+\n0 CELLS DATA + F!", + ) + assert result[0] == pytest.approx(4.0) + + +def test_float_subtraction(kernel_runner: KernelRunner) -> None: + """F-: 10.0 - 3.5 = 6.5.""" + result = kernel_runner.run( + forth_source="\\! kernel main\n\\! param DATA f64[256]\n10.0 3.5 F-\n0 CELLS DATA + F!", + ) + assert result[0] == pytest.approx(6.5) + + +def test_float_multiplication(kernel_runner: KernelRunner) -> None: + """F*: 6.0 * 7.5 = 45.0.""" + result = kernel_runner.run( + forth_source="\\! kernel main\n\\! param DATA f64[256]\n6.0 7.5 F*\n0 CELLS DATA + F!", + ) + assert result[0] == pytest.approx(45.0) + + +def test_float_division(kernel_runner: KernelRunner) -> None: + """F/: 42.0 / 6.0 = 7.0.""" + result = kernel_runner.run( + forth_source="\\! kernel main\n\\! param DATA f64[256]\n42.0 6.0 F/\n0 CELLS DATA + F!", + ) + assert result[0] == pytest.approx(7.0) + + +# --- Float Memory --- + + +def test_float_load_store(kernel_runner: KernelRunner) -> None: + """F@ and F!: read from DATA[0], multiply by 2, write to DATA[1].""" + result = kernel_runner.run( + forth_source=( + "\\! kernel main\n\\! param DATA f64[256]\n0 CELLS DATA + F@\n2.0 F*\n1 CELLS DATA + F!" + ), + params={"DATA": [3.14]}, + output_count=2, + ) + assert result[1] == pytest.approx(6.28) + + +# --- Float Scalar Params --- + + +def test_float_scalar_param(kernel_runner: KernelRunner) -> None: + """Scalar f64 param: each thread scales DATA[i] by SCALE.""" + result = kernel_runner.run( + forth_source=( + "\\! kernel main\n\\! param DATA f64[256]\n\\! param SCALE f64\n" + "GLOBAL-ID\n" + "DUP CELLS DATA + F@\n" + "SCALE F*\n" + "SWAP CELLS DATA + F!" + ), + params={"DATA": [1.0, 2.0, 3.0, 4.0], "SCALE": 2.5}, + block=(4, 1, 1), + output_count=4, + ) + assert result == [ + pytest.approx(2.5), + pytest.approx(5.0), + pytest.approx(7.5), + pytest.approx(10.0), + ] + + +# --- Float Comparisons --- + + +def test_float_comparisons(kernel_runner: KernelRunner) -> None: + """F=, F<, F>: True = -1, False = 0 (pushed as i64 on the stack).""" + result = kernel_runner.run( + forth_source=( + "\\! kernel main\n\\! param DATA i64[256]\n" + "3.14 3.14 F= 0 CELLS DATA + !\n" + "1.0 2.0 F< 1 CELLS DATA + !\n" + "5.0 3.0 F> 2 CELLS DATA + !" + ), + output_count=3, + ) + assert result == [-1, -1, -1] + + +# --- Float Conversion --- + + +def test_int_to_float_conversion(kernel_runner: KernelRunner) -> None: + """S>F: convert int 7 to float, multiply by 1.5, store as f64.""" + result = kernel_runner.run( + forth_source=("\\! kernel main\n\\! param DATA f64[256]\n7 S>F 1.5 F*\n0 CELLS DATA + F!"), + ) + assert result[0] == pytest.approx(10.5) + + +def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None: + """F>S: convert float 7.9 to int (truncates to 7), store as i64.""" + result = kernel_runner.run( + forth_source=("\\! kernel main\n\\! param DATA i64[256]\n7.9 F>S\n0 CELLS DATA + !"), + ) + assert result[0] == 7 From 5312b89d196884e6d252056997d5c6bb879626ad Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 02:25:01 +0900 Subject: [PATCH 3/4] fix(dialect): restrict ConstantOp value attribute to I64Attr or F64Attr --- include/warpforth/Dialect/Forth/ForthOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index 4d76144..1b12ec4 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -124,7 +124,7 @@ def Forth_ConstantOp : Forth_Op<"constant", [Pure]> { Forth semantics: ( -- n ) }]; - let arguments = (ins Forth_StackType:$input_stack, AnyAttr:$value); + let arguments = (ins Forth_StackType:$input_stack, AnyAttrOf<[I64Attr, F64Attr]>:$value); let results = (outs Forth_StackType:$output_stack); let assemblyFormat = [{ From 00879678073cd668c89bc7cc48cba582f07f1113 Mon Sep 17 00:00:00 2001 From: Alex Cameron Date: Sat, 21 Feb 2026 02:26:08 +0900 Subject: [PATCH 4/4] fix(runner): add error handling for stoll/stod in parseParam --- tools/warpforth-runner/warpforth-runner.cpp | 24 +++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tools/warpforth-runner/warpforth-runner.cpp b/tools/warpforth-runner/warpforth-runner.cpp index 8a46a41..67e1d3f 100644 --- a/tools/warpforth-runner/warpforth-runner.cpp +++ b/tools/warpforth-runner/warpforth-runner.cpp @@ -157,8 +157,24 @@ static Param parseParam(std::string_view s) { return vals; }; - auto toI64 = [](const std::string &s) -> int64_t { return std::stoll(s); }; - auto toF64 = [](const std::string &s) -> double { return std::stod(s); }; + auto toI64 = [&](const std::string &tok) -> int64_t { + try { + return std::stoll(tok); + } catch (const std::exception &) { + std::cerr << "Error: invalid integer value '" << tok << "' in --param " + << s << "\n"; + exit(1); + } + }; + auto toF64 = [&](const std::string &tok) -> double { + try { + return std::stod(tok); + } catch (const std::exception &) { + std::cerr << "Error: invalid float value '" << tok << "' in --param " << s + << "\n"; + exit(1); + } + }; if (typePrefix == "i64[]") return Param{ArrayParam{parseValues(toI64)}}; @@ -173,9 +189,9 @@ static Param parseParam(std::string_view s) { } if (typePrefix == "i64") - return Param{ScalarParam{std::stoll(valueStr)}}; + return Param{ScalarParam{toI64(valueStr)}}; if (typePrefix == "f64") - return Param{ScalarParam{std::stod(valueStr)}}; + return Param{ScalarParam{toF64(valueStr)}}; std::cerr << "Error: unsupported param type '" << typePrefix << "' (expected i64, i64[], f64, or f64[]), got: " << s << "\n";