Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Requires MLIR/LLVM with `MLIR_DIR` and `LLVM_DIR` configured in CMake.
./build/bin/warpforth-translate --mlir-to-ptx > kernel.ptx

# Execute PTX on GPU
./warpforth-runner kernel.ptx --param 1,2,3 --param 0,0,0,0 --output-param 1 --output-count 5
./warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 --output-param 0 --output-count 3
```

## Adding New Operations
Expand Down
69 changes: 52 additions & 17 deletions gpu_test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import subprocess
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -31,6 +32,15 @@
REMOTE_TMP = "/tmp" # noqa: S108


@dataclass
class ParamDecl:
    """A parsed kernel parameter declaration.

    One instance is produced per '\\! param <name> <type>' header directive,
    in declaration order. Array declarations (type spec containing "[") carry
    their declared size; scalar declarations carry a size of 0.
    """

    name: str  # parameter name as written in the directive
    is_array: bool  # True when the type spec contains "[", False for a scalar
    size: int  # declared array length (used to size the zero-padded buffer); 0 for scalars


class CompileError(Exception):
"""Raised when warpforthc fails to compile Forth source."""

Expand Down Expand Up @@ -315,13 +325,13 @@ def _parse_kernel_name(forth_source: str) -> str:
raise ValueError(msg)


def _parse_param_declarations(forth_source: str) -> list[tuple[str, int]]:
def _parse_param_declarations(forth_source: str) -> list[ParamDecl]:
"""Parse '\\! param <name> <type>' declarations from Forth source.

Returns list of (name, size) for array params in declaration order.
Scalar params are not supported by the GPU runner.
Returns list of ParamDecl in declaration order. Supports both array
params (e.g. i64[256]) and scalar params (e.g. i64).
"""
decls = []
decls: list[ParamDecl] = []
for keyword, parts in _iter_header_directives(forth_source):
if keyword != "param":
continue
Expand All @@ -330,10 +340,14 @@ def _parse_param_declarations(forth_source: str) -> list[tuple[str, int]]:
raise ValueError(msg)
name = parts[1]
type_spec = parts[2]
if "[" not in type_spec:
msg = "Scalar params are not supported by the GPU runner yet"
raise ValueError(msg)
decls.append((name, _parse_array_type(type_spec)))
if "[" in type_spec:
size = _parse_array_type(type_spec)
decls.append(ParamDecl(name=name, is_array=True, size=size))
else:
if type_spec.lower() != "i64":
msg = f"Unsupported scalar type: {type_spec}"
raise ValueError(msg)
decls.append(ParamDecl(name=name, is_array=False, size=0))
return decls


Expand All @@ -347,7 +361,7 @@ def __init__(self, session: VastSession, compiler: Compiler) -> None:
def run(
self,
forth_source: str,
params: dict[str, list[int]] | None = None,
params: dict[str, list[int] | int] | None = None,
grid: tuple[int, int, int] = (1, 1, 1),
block: tuple[int, int, int] = (1, 1, 1),
output_param: int = 0,
Expand All @@ -356,8 +370,10 @@ def run(
"""Compile Forth source locally, execute on remote GPU, return output values.

Param buffer sizes are derived from the Forth source's 'param' declarations.
The params dict maps param names to initial values (padded with zeros to the
declared size). Params not in the dict are zero-initialized.
The params dict maps param names to initial values:
- Array params: list[int] (padded with zeros to declared size)
- Scalar params: int
Params not in the dict are zero-initialized.
"""
# Parse kernel name and param declarations
kernel_name = _parse_kernel_name(forth_source)
Expand All @@ -368,6 +384,15 @@ def run(

params = params or {}

# Validate output_param
if output_param < 0 or output_param >= len(decls):
msg = f"output_param {output_param} out of range (have {len(decls)} params)"
raise ValueError(msg)
if not decls[output_param].is_array:
name = decls[output_param].name
msg = f"output_param {output_param} ('{name}') is a scalar and cannot be read back"
raise ValueError(msg)

# Compile locally
ptx = self.compiler.compile_source(forth_source)

Expand All @@ -389,12 +414,22 @@ def run(
kernel_name,
]

for name, size in decls:
values = params.get(name, [])
buf = [0] * size
for i, v in enumerate(values):
buf[i] = v
cmd_parts.extend(["--param", ",".join(str(v) for v in buf)])
for decl in decls:
if decl.is_array:
values = params.get(decl.name, [])
if not isinstance(values, list):
msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}"
raise TypeError(msg)
buf = [0] * decl.size
for i, v in enumerate(values):
buf[i] = v
cmd_parts.extend(["--param", f"i64[]:{','.join(str(v) for v in buf)}"])
else:
value = params.get(decl.name, 0)
if isinstance(value, list):
msg = f"Scalar param '{decl.name}' expects an int, got list"
raise TypeError(msg)
cmd_parts.extend(["--param", f"i64:{value}"])

cmd_parts.extend(
[
Expand Down
20 changes: 20 additions & 0 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,26 @@ def test_multi_param(kernel_runner: KernelRunner) -> None:
assert result == [20, 40, 60, 80]


def test_scalar_param(kernel_runner: KernelRunner) -> None:
    """Mix a scalar param with array params: OUTPUT[i] = INPUT[i] * SCALE per thread."""
    source_lines = [
        "\\! kernel main",
        "\\! param SCALE i64",
        "\\! param INPUT i64[4]",
        "\\! param OUTPUT i64[4]",
        "GLOBAL-ID",
        "DUP CELLS INPUT + @",
        "SCALE *",
        "SWAP CELLS OUTPUT + !",
    ]
    initial_values = {"SCALE": 3, "INPUT": [10, 20, 30, 40]}

    # OUTPUT is the third declared param, hence output_param=2; it is left out
    # of the params dict, so the runner zero-initializes it.
    scaled = kernel_runner.run(
        forth_source="\n".join(source_lines),
        params=initial_values,
        block=(4, 1, 1),
        output_param=2,
        output_count=4,
    )

    assert scaled == [30, 60, 90, 120]


# --- Matmul ---


Expand Down
20 changes: 20 additions & 0 deletions test/Pipeline/scalar-param.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID

\ Verify mixed scalar + array params survive the full pipeline
\ CHECK: gpu.binary @warpforth_module

\ Verify scalar becomes i64 arg, array becomes memref
\ MID: gpu.func @main(
\ MID-SAME: i64 {forth.param_name = "SCALE"}
\ MID-SAME: memref<256xi64> {forth.param_name = "DATA"}
\ MID-SAME: kernel
\ MID: gpu.return

\! kernel main
\! param SCALE i64
\! param DATA i64[256]
\ Kernel body: scale one element of DATA in place by the scalar SCALE.
\ GLOBAL-ID presumably pushes this thread's global index i (confirm against the dialect docs).
GLOBAL-ID
\ Keep a copy of i for the store below, then fetch the element at DATA + i cells.
DUP CELLS DATA + @
\ Multiply the fetched element by the scalar parameter.
SCALE *
\ Bring i back to the top of the stack and store the product at DATA + i cells.
SWAP CELLS DATA + !
Loading