Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !` (global memory), `S@ S!` (shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution (`memref<Nxi64, #gpu.address_space<workgroup>>`). Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for shared accesses. Cannot be referenced inside word definitions.
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations.
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. `\! param <name> f64[<N>]` becomes a `memref<Nxf64>` argument; `\! param <name> f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` or `\! shared <name> f64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions.
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`
Expand Down
43 changes: 25 additions & 18 deletions gpu_test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ParamDecl:
name: str
is_array: bool
size: int # 0 for scalars
base_type: str = "i64" # "i64" or "f64"


class CompileError(Exception):
Expand Down Expand Up @@ -283,15 +284,17 @@ def scp_upload(self, local_path: str | Path, remote_path: str) -> None:
)


def _parse_array_type(type_spec: str) -> int:
def _parse_array_type(type_spec: str) -> tuple[str, int]:
"""Parse 'i64[256]' or 'f64[256]' into (base_type, size)."""
if not type_spec.endswith("]"):
msg = f"Invalid array type spec: {type_spec}"
raise ValueError(msg)
base, size_str = type_spec[:-1].split("[", 1)
if base.lower() != "i64":
base_lower = base.lower()
if base_lower not in ("i64", "f64"):
msg = f"Unsupported base type: {base}"
raise ValueError(msg)
return int(size_str)
return base_lower, int(size_str)


def _iter_header_directives(forth_source: str) -> Generator[tuple[str, list[str]]]:
Expand Down Expand Up @@ -341,13 +344,14 @@ def _parse_param_declarations(forth_source: str) -> list[ParamDecl]:
name = parts[1]
type_spec = parts[2]
if "[" in type_spec:
size = _parse_array_type(type_spec)
decls.append(ParamDecl(name=name, is_array=True, size=size))
base_type, size = _parse_array_type(type_spec)
decls.append(ParamDecl(name=name, is_array=True, size=size, base_type=base_type))
else:
if type_spec.lower() != "i64":
base_type = type_spec.lower()
if base_type not in ("i64", "f64"):
msg = f"Unsupported scalar type: {type_spec}"
raise ValueError(msg)
decls.append(ParamDecl(name=name, is_array=False, size=0))
decls.append(ParamDecl(name=name, is_array=False, size=0, base_type=base_type))
return decls


Expand All @@ -361,18 +365,18 @@ def __init__(self, session: VastSession, compiler: Compiler) -> None:
def run(
self,
forth_source: str,
params: dict[str, list[int] | int] | None = None,
params: dict[str, list[int] | list[float] | int | float] | None = None,
grid: tuple[int, int, int] = (1, 1, 1),
block: tuple[int, int, int] = (1, 1, 1),
output_param: int = 0,
output_count: int | None = None,
) -> list[int]:
) -> list[int] | list[float]:
"""Compile Forth source locally, execute on remote GPU, return output values.

Param buffer sizes are derived from the Forth source's 'param' declarations.
The params dict maps param names to initial values:
- Array params: list[int] (padded with zeros to declared size)
- Scalar params: int
- Array params: list of int or float (padded with zeros to declared size)
- Scalar params: int or float
Params not in the dict are zero-initialized.
"""
# Parse kernel name and param declarations
Expand Down Expand Up @@ -420,16 +424,17 @@ def run(
if not isinstance(values, list):
msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}"
raise TypeError(msg)
buf = [0] * decl.size
zero = 0.0 if decl.base_type == "f64" else 0
buf = [zero] * decl.size
for i, v in enumerate(values):
buf[i] = v
cmd_parts.extend(["--param", f"i64[]:{','.join(str(v) for v in buf)}"])
cmd_parts.extend(["--param", f"{decl.base_type}[]:{','.join(str(v) for v in buf)}"])
else:
value = params.get(decl.name, 0)
value = params.get(decl.name, 0.0 if decl.base_type == "f64" else 0)
if isinstance(value, list):
msg = f"Scalar param '{decl.name}' expects an int, got list"
msg = f"Scalar param '{decl.name}' expects a scalar, got list"
raise TypeError(msg)
cmd_parts.extend(["--param", f"i64:{value}"])
cmd_parts.extend(["--param", f"{decl.base_type}:{value}"])

cmd_parts.extend(
[
Expand All @@ -448,8 +453,10 @@ def run(
cmd = " ".join(cmd_parts)
stdout = self.session.ssh_run(cmd, timeout=120)

# Parse CSV output
return [int(v) for v in stdout.strip().split(",")]
# Parse CSV output — type depends on the output param
out_type = decls[output_param].base_type
parse = float if out_type == "f64" else int
return [parse(v) for v in stdout.strip().split(",")]


# --- Fixtures ---
Expand Down
173 changes: 173 additions & 0 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,68 @@ def test_tiled_matmul_i64(kernel_runner: KernelRunner) -> None:
assert result == expected


def test_tiled_matmul_f64(kernel_runner: KernelRunner) -> None:
    """Tiled f64 matmul with shared memory: C = A(4x4) * B(4x4) -> C(4x4).

    Uses 2x2 tiles, float shared memory for A/B tiles, and BARRIER for sync.
    Grid: (2,2,1), Block: (2,2,1) — 4 blocks of 4 threads each.
    """
    result = kernel_runner.run(
        forth_source=(
            "\\! kernel main\n"
            "\\! param A f64[16]\n"
            "\\! param B f64[16]\n"
            "\\! param C f64[16]\n"
            "\\! shared SA f64[4]\n"
            "\\! shared SB f64[4]\n"
            # Stack after the next three lines: row, col, accumulator.
            "BID-Y 2 * TID-Y +\n"
            "BID-X 2 * TID-X +\n"
            "0.0\n"
            # Outer loop walks the two 2x2 tiles along the shared dimension.
            "2 0 DO\n"
            # Each thread stages one element of the A and B tiles into
            # shared memory, then BARRIER before any thread reads them.
            "  2 PICK 4 * I 2 * + TID-X + CELLS A + F@\n"
            "  TID-Y 2 * TID-X + CELLS SA + SF!\n"
            "  I 2 * TID-Y + 4 * 2 PICK + CELLS B + F@\n"
            "  TID-Y 2 * TID-X + CELLS SB + SF!\n"
            "  BARRIER\n"
            # Inner loop accumulates the 2-element partial dot product.
            "  2 0 DO\n"
            "    TID-Y 2 * I + CELLS SA + SF@\n"
            "    I 2 * TID-X + CELLS SB + SF@\n"
            "    F* F+\n"
            "  LOOP\n"
            # Second BARRIER keeps the tile from being overwritten early.
            "  BARRIER\n"
            "LOOP\n"
            "ROT 4 * ROT + CELLS C + F!"
        ),
        params={
            "A": [1.5, 2.0, 0.5, 3.0, 4.0, 1.5, 2.5, 0.5, 0.5, 3.0, 1.0, 2.0, 2.0, 0.5, 3.5, 1.5],
            "B": [1.0, 0.5, 2.0, 1.5, 3.0, 1.0, 0.5, 2.0, 0.5, 2.5, 1.0, 0.5, 2.0, 1.5, 3.0, 1.0],
        },
        grid=(2, 2, 1),
        block=(2, 2, 1),
        output_param=2,
        output_count=16,
    )
    expected = [
        13.75,
        8.5,
        13.5,
        9.5,
        10.75,
        10.5,
        12.75,
        10.75,
        14.0,
        8.75,
        9.5,
        9.25,
        8.25,
        12.5,
        12.25,
        7.25,
    ]
    # pytest.approx compares whole numeric sequences elementwise.
    assert result == pytest.approx(expected)


# --- User-Defined Words ---


Expand All @@ -410,3 +472,114 @@ def test_user_defined_word(kernel_runner: KernelRunner) -> None:
),
)
assert result[0] == 10


# --- Float Arithmetic ---


def test_float_addition(kernel_runner: KernelRunner) -> None:
    """F+ pushes the f64 sum: 1.5 + 2.5 = 4.0."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n1.5 2.5 F+\n0 CELLS DATA + F!"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == pytest.approx(4.0)


def test_float_subtraction(kernel_runner: KernelRunner) -> None:
    """F- pushes the f64 difference: 10.0 - 3.5 = 6.5."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n10.0 3.5 F-\n0 CELLS DATA + F!"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == pytest.approx(6.5)


def test_float_multiplication(kernel_runner: KernelRunner) -> None:
    """F* pushes the f64 product: 6.0 * 7.5 = 45.0."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n6.0 7.5 F*\n0 CELLS DATA + F!"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == pytest.approx(45.0)


def test_float_division(kernel_runner: KernelRunner) -> None:
    """F/ pushes the f64 quotient: 42.0 / 6.0 = 7.0."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n42.0 6.0 F/\n0 CELLS DATA + F!"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == pytest.approx(7.0)


# --- Float Memory ---


def test_float_load_store(kernel_runner: KernelRunner) -> None:
    """F@ / F!: load DATA[0], double it, store the result into DATA[1]."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n0 CELLS DATA + F@\n2.0 F*\n1 CELLS DATA + F!"
    out = kernel_runner.run(
        forth_source=source,
        params={"DATA": [3.14]},
        output_count=2,
    )
    # DATA[0] stays 3.14; DATA[1] receives 3.14 * 2.0.
    assert out[1] == pytest.approx(6.28)


# --- Float Scalar Params ---


def test_float_scalar_param(kernel_runner: KernelRunner) -> None:
    """Scalar f64 param: every thread multiplies its DATA[i] by SCALE."""
    # Last line deliberately carries no trailing newline.
    source = "\n".join(
        [
            "\\! kernel main",
            "\\! param DATA f64[256]",
            "\\! param SCALE f64",
            "GLOBAL-ID",
            "DUP CELLS DATA + F@",
            "SCALE F*",
            "SWAP CELLS DATA + F!",
        ]
    )
    out = kernel_runner.run(
        forth_source=source,
        params={"DATA": [1.0, 2.0, 3.0, 4.0], "SCALE": 2.5},
        block=(4, 1, 1),
        output_count=4,
    )
    expected = [2.5, 5.0, 7.5, 10.0]
    assert out == [pytest.approx(v) for v in expected]


# --- Float Comparisons ---


def test_float_comparisons(kernel_runner: KernelRunner) -> None:
    """F=, F<, F>: true is -1 and false is 0, pushed as i64 on the stack."""
    source = (
        "\\! kernel main\n\\! param DATA i64[256]\n"
        "3.14 3.14 F= 0 CELLS DATA + !\n"
        "1.0 2.0 F< 1 CELLS DATA + !\n"
        "5.0 3.0 F> 2 CELLS DATA + !"
    )
    flags = kernel_runner.run(forth_source=source, output_count=3)
    # All three comparisons are true, so all three slots hold -1.
    assert flags == [-1, -1, -1]


# --- Float Conversion ---


def test_int_to_float_conversion(kernel_runner: KernelRunner) -> None:
    """S>F: int 7 converted to float, times 1.5, stored as f64 = 10.5."""
    source = "\\! kernel main\n\\! param DATA f64[256]\n7 S>F 1.5 F*\n0 CELLS DATA + F!"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == pytest.approx(10.5)


def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None:
    """F>S truncates toward zero: float 7.9 becomes i64 7."""
    source = "\\! kernel main\n\\! param DATA i64[256]\n7.9 F>S\n0 CELLS DATA + !"
    out = kernel_runner.run(forth_source=source)
    assert out[0] == 7
Loading