Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ uv run ruff format gpu_test/
- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared with `PARAM <name> <size>`, each becomes a `memref<Nxi64>` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref`
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`
Expand Down
75 changes: 66 additions & 9 deletions gpu_test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,67 @@ def scp_upload(self, local_path: str | Path, remote_path: str) -> None:
)


def _parse_array_type(type_spec: str) -> int:
if not type_spec.endswith("]"):
msg = f"Invalid array type spec: {type_spec}"
raise ValueError(msg)
base, size_str = type_spec[:-1].split("[", 1)
if base.lower() != "i64":
msg = f"Unsupported base type: {base}"
raise ValueError(msg)
return int(size_str)


def _iter_header_directives(forth_source: str) -> Generator[tuple[str, list[str]]]:
"""Yield (keyword, parts) for each \\! directive in the Forth header.

Strips comments (-- ...) and splits on whitespace. keyword is lowercased.
"""
for line in forth_source.splitlines():
stripped = line.strip()
if not stripped.startswith("\\!"):
continue
directive = stripped[2:].strip()
if "--" in directive:
directive = directive.split("--", 1)[0].strip()
if not directive:
continue
parts = directive.split()
if parts:
yield parts[0].lower(), parts


def _parse_kernel_name(forth_source: str) -> str:
"""Parse '\\! kernel <name>' from Forth source header."""
for keyword, parts in _iter_header_directives(forth_source):
if keyword == "kernel":
if len(parts) < 2:
msg = "Invalid header line: expected '\\! kernel <name>'"
raise ValueError(msg)
return parts[1]
msg = "Forth source has no '\\! kernel' declaration"
raise ValueError(msg)


def _parse_param_declarations(forth_source: str) -> list[tuple[str, int]]:
"""Parse 'param <name> <size>' declarations from Forth source.
"""Parse '\\! param <name> <type>' declarations from Forth source.

Returns list of (name, size) in declaration order.
Returns list of (name, size) for array params in declaration order.
Scalar params are not supported by the GPU runner.
"""
decls = []
for line in forth_source.splitlines():
parts = line.split()
if len(parts) >= 3 and parts[0].upper() == "PARAM":
decls.append((parts[1], int(parts[2])))
for keyword, parts in _iter_header_directives(forth_source):
if keyword != "param":
continue
if len(parts) < 3:
msg = "Invalid header line: expected '\\! param <name> <type>'"
raise ValueError(msg)
name = parts[1]
type_spec = parts[2]
if "[" not in type_spec:
msg = "Scalar params are not supported by the GPU runner yet"
raise ValueError(msg)
decls.append((name, _parse_array_type(type_spec)))
return decls


Expand All @@ -308,10 +359,11 @@ def run(
The params dict maps param names to initial values (padded with zeros to the
declared size). Params not in the dict are zero-initialized.
"""
# Parse param declarations to determine buffer sizes
# Parse kernel name and param declarations
kernel_name = _parse_kernel_name(forth_source)
decls = _parse_param_declarations(forth_source)
if not decls:
msg = "Forth source has no 'param' declarations"
msg = "Forth source has no '\\! param' declarations"
raise ValueError(msg)

params = params or {}
Expand All @@ -330,7 +382,12 @@ def run(
ptx_path.unlink()

# Build remote command
cmd_parts = [f"{REMOTE_TMP}/warpforth-runner", f"{REMOTE_TMP}/kernel.ptx"]
cmd_parts = [
f"{REMOTE_TMP}/warpforth-runner",
f"{REMOTE_TMP}/kernel.ptx",
"--kernel",
kernel_name,
]

for name, size in decls:
values = params.get(name, [])
Expand Down
98 changes: 73 additions & 25 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,39 @@
def test_addition(kernel_runner: KernelRunner) -> None:
"""3 + 4 = 7."""
result = kernel_runner.run(
forth_source="PARAM DATA 256\n3 4 +\n0 CELLS DATA + !",
forth_source="\\! kernel main\n\\! param DATA i64[256]\n3 4 +\n0 CELLS DATA + !",
)
assert result[0] == 7


def test_subtraction(kernel_runner: KernelRunner) -> None:
"""10 - 3 = 7."""
result = kernel_runner.run(
forth_source="PARAM DATA 256\n10 3 -\n0 CELLS DATA + !",
forth_source="\\! kernel main\n\\! param DATA i64[256]\n10 3 -\n0 CELLS DATA + !",
)
assert result[0] == 7


def test_multiplication(kernel_runner: KernelRunner) -> None:
"""6 * 7 = 42."""
result = kernel_runner.run(
forth_source="PARAM DATA 256\n6 7 *\n0 CELLS DATA + !",
forth_source="\\! kernel main\n\\! param DATA i64[256]\n6 7 *\n0 CELLS DATA + !",
)
assert result[0] == 42


def test_division(kernel_runner: KernelRunner) -> None:
"""42 / 6 = 7."""
result = kernel_runner.run(
forth_source="PARAM DATA 256\n42 6 /\n0 CELLS DATA + !",
forth_source="\\! kernel main\n\\! param DATA i64[256]\n42 6 /\n0 CELLS DATA + !",
)
assert result[0] == 7


def test_modulo(kernel_runner: KernelRunner) -> None:
"""17 MOD 5 = 2."""
result = kernel_runner.run(
forth_source="PARAM DATA 256\n17 5 MOD\n0 CELLS DATA + !",
forth_source="\\! kernel main\n\\! param DATA i64[256]\n17 5 MOD\n0 CELLS DATA + !",
)
assert result[0] == 2

Expand All @@ -61,7 +61,9 @@ def test_modulo(kernel_runner: KernelRunner) -> None:
def test_dup(kernel_runner: KernelRunner) -> None:
"""DUP duplicates top of stack: 5 DUP → [5, 5]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n5 DUP\n1 CELLS DATA + !\n0 CELLS DATA + !"),
forth_source=(
"\\! kernel main\n\\! param DATA i64[256]\n5 DUP\n1 CELLS DATA + !\n0 CELLS DATA + !"
),
output_count=2,
)
assert result == [5, 5]
Expand All @@ -70,7 +72,9 @@ def test_dup(kernel_runner: KernelRunner) -> None:
def test_swap(kernel_runner: KernelRunner) -> None:
"""SWAP exchanges top two: 1 2 SWAP → [2, 1]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n1 2 SWAP\n1 CELLS DATA + !\n0 CELLS DATA + !"),
forth_source=(
"\\! kernel main\n\\! param DATA i64[256]\n1 2 SWAP\n1 CELLS DATA + !\n0 CELLS DATA + !"
),
output_count=2,
)
assert result == [2, 1]
Expand All @@ -80,7 +84,12 @@ def test_over(kernel_runner: KernelRunner) -> None:
"""OVER copies second element: 1 2 OVER → [1, 2, 1]."""
result = kernel_runner.run(
forth_source=(
"PARAM DATA 256\n1 2 OVER\n2 CELLS DATA + !\n1 CELLS DATA + !\n0 CELLS DATA + !"
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"1 2 OVER\n"
"2 CELLS DATA + !\n"
"1 CELLS DATA + !\n"
"0 CELLS DATA + !"
),
output_count=3,
)
Expand All @@ -91,7 +100,12 @@ def test_rot(kernel_runner: KernelRunner) -> None:
"""ROT rotates top three: 1 2 3 ROT → [2, 3, 1]."""
result = kernel_runner.run(
forth_source=(
"PARAM DATA 256\n1 2 3 ROT\n2 CELLS DATA + !\n1 CELLS DATA + !\n0 CELLS DATA + !"
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"1 2 3 ROT\n"
"2 CELLS DATA + !\n"
"1 CELLS DATA + !\n"
"0 CELLS DATA + !"
),
output_count=3,
)
Expand All @@ -101,7 +115,7 @@ def test_rot(kernel_runner: KernelRunner) -> None:
def test_drop(kernel_runner: KernelRunner) -> None:
"""DROP removes top: 1 2 DROP → [1]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n1 2 DROP\n0 CELLS DATA + !"),
forth_source=("\\! kernel main\n\\! param DATA i64[256]\n1 2 DROP\n0 CELLS DATA + !"),
)
assert result[0] == 1

Expand All @@ -113,7 +127,7 @@ def test_comparisons(kernel_runner: KernelRunner) -> None:
"""Test =, <, >, 0= in a single kernel. True = -1, False = 0."""
result = kernel_runner.run(
forth_source=(
"PARAM DATA 256\n"
"\\! kernel main\n\\! param DATA i64[256]\n"
"5 5 = 0 CELLS DATA + !\n"
"3 5 < 1 CELLS DATA + !\n"
"5 3 > 2 CELLS DATA + !\n"
Expand All @@ -130,7 +144,14 @@ def test_comparisons(kernel_runner: KernelRunner) -> None:
def test_if_else_then(kernel_runner: KernelRunner) -> None:
"""IF/ELSE/THEN: if DATA[0] > 0, write 1 to DATA[1], else write 2."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n0 CELLS DATA + @\n0 >\nIF 1 ELSE 2 THEN\n1 CELLS DATA + !"),
forth_source=(
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"0 CELLS DATA + @\n"
"0 >\n"
"IF 1 ELSE 2 THEN\n"
"1 CELLS DATA + !"
),
params={"DATA": [5]},
output_count=2,
)
Expand All @@ -140,15 +161,19 @@ def test_if_else_then(kernel_runner: KernelRunner) -> None:
def test_begin_until(kernel_runner: KernelRunner) -> None:
"""BEGIN/UNTIL countdown: 10 BEGIN 1- DUP 0= UNTIL → final value is 0."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n10 BEGIN 1 - DUP 0= UNTIL\n0 CELLS DATA + !"),
forth_source=(
"\\! kernel main\n\\! param DATA i64[256]\n10 BEGIN 1 - DUP 0= UNTIL\n0 CELLS DATA + !"
),
)
assert result[0] == 0


def test_do_loop(kernel_runner: KernelRunner) -> None:
"""DO/LOOP: write I values 0..4 to DATA[0..4]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n5 0 DO\n I I CELLS DATA + !\nLOOP"),
forth_source=(
"\\! kernel main\n\\! param DATA i64[256]\n5 0 DO\n I I CELLS DATA + !\nLOOP"
),
output_count=5,
)
assert result == [0, 1, 2, 3, 4]
Expand All @@ -157,7 +182,16 @@ def test_do_loop(kernel_runner: KernelRunner) -> None:
def test_do_plus_loop(kernel_runner: KernelRunner) -> None:
"""DO/+LOOP: write I values 0, 2, 4, 6, 8 to DATA[0..4]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n0\n10 0 DO\n I OVER CELLS DATA + !\n 1 +\n2 +LOOP\nDROP"),
forth_source=(
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"0\n"
"10 0 DO\n"
" I OVER CELLS DATA + !\n"
" 1 +\n"
"2 +LOOP\n"
"DROP"
),
output_count=5,
)
assert result == [0, 2, 4, 6, 8]
Expand All @@ -166,7 +200,16 @@ def test_do_plus_loop(kernel_runner: KernelRunner) -> None:
def test_do_plus_loop_negative(kernel_runner: KernelRunner) -> None:
"""DO/+LOOP with negative step: count down from 10 to 1."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n0\n0 10 DO\n I OVER CELLS DATA + !\n 1 +\n-1 +LOOP\nDROP"),
forth_source=(
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"0\n"
"0 10 DO\n"
" I OVER CELLS DATA + !\n"
" 1 +\n"
"-1 +LOOP\n"
"DROP"
),
output_count=10,
)
assert result == [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
Expand All @@ -180,7 +223,7 @@ def test_multi_while(kernel_runner: KernelRunner) -> None:
"""
result = kernel_runner.run(
forth_source=(
"PARAM DATA 256\n"
"\\! kernel main\n\\! param DATA i64[256]\n"
"20 BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT THEN\n"
"0 CELLS DATA + !"
),
Expand All @@ -196,7 +239,10 @@ def test_while_until(kernel_runner: KernelRunner) -> None:
"""
result = kernel_runner.run(
forth_source=(
"PARAM DATA 256\n10 BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN\n0 CELLS DATA + !"
"\\! kernel main\n"
"\\! param DATA i64[256]\n"
"10 BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN\n"
"0 CELLS DATA + !"
),
)
assert result[0] == 5
Expand All @@ -208,7 +254,7 @@ def test_while_until(kernel_runner: KernelRunner) -> None:
def test_global_id(kernel_runner: KernelRunner) -> None:
"""4 threads each write GLOBAL-ID to DATA[GLOBAL-ID]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\nGLOBAL-ID\nDUP CELLS DATA + !"),
forth_source=("\\! kernel main\n\\! param DATA i64[256]\nGLOBAL-ID\nDUP CELLS DATA + !"),
block=(4, 1, 1),
output_count=4,
)
Expand All @@ -219,8 +265,8 @@ def test_multi_param(kernel_runner: KernelRunner) -> None:
"""Two params: each thread reads INPUT[i], doubles it, writes OUTPUT[i]."""
result = kernel_runner.run(
forth_source=(
"PARAM INPUT 4\n"
"PARAM OUTPUT 4\n"
"\\! kernel main\n\\! param INPUT i64[4]\n"
"\\! param OUTPUT i64[4]\n"
"GLOBAL-ID\n"
"DUP CELLS INPUT + @\n"
"DUP +\n"
Expand All @@ -243,9 +289,9 @@ def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None:
# GLOBAL-ID maps to (row, col) with row = gid / N, col = gid MOD N.
result = kernel_runner.run(
forth_source=(
"PARAM A 8\n"
"PARAM B 12\n"
"PARAM C 6\n"
"\\! kernel main\n\\! param A i64[8]\n"
"\\! param B i64[12]\n"
"\\! param C i64[6]\n"
"GLOBAL-ID\n"
"DUP 3 /\n"
"SWAP 3 MOD\n"
Expand Down Expand Up @@ -277,6 +323,8 @@ def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None:
def test_user_defined_word(kernel_runner: KernelRunner) -> None:
""": DOUBLE DUP + ; then 5 DOUBLE → 10."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n: DOUBLE DUP + ;\n5 DOUBLE\n0 CELLS DATA + !"),
forth_source=(
"\\! kernel main\n\\! param DATA i64[256]\n: DOUBLE DUP + ;\n5 DOUBLE\n0 CELLS DATA + !"
),
)
assert result[0] == 10
2 changes: 1 addition & 1 deletion lib/Conversion/ForthToGPU/ForthToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ struct ConvertForthToGPUPass

void convertFuncToGPU(func::FuncOp funcOp, gpu::GPUModuleOp gpuModule,
IRRewriter &rewriter) {
bool isKernel = funcOp.getName() == "main";
bool isKernel = funcOp->hasAttr("forth.kernel");

if (isKernel) {
auto gpuFunc = createGPUFunc(funcOp, gpuModule, rewriter);
Expand Down
Loading