diff --git a/CLAUDE.md b/CLAUDE.md index 7feac8f..a97eae5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,7 +49,7 @@ Requires MLIR/LLVM with `MLIR_DIR` and `LLVM_DIR` configured in CMake. ./build/bin/warpforth-translate --mlir-to-ptx > kernel.ptx # Execute PTX on GPU -./warpforth-runner kernel.ptx --param 1,2,3 --param 0,0,0,0 --output-param 1 --output-count 5 +./warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 --output-param 0 --output-count 3 ``` ## Adding New Operations diff --git a/gpu_test/conftest.py b/gpu_test/conftest.py index 87fd95e..de0b273 100644 --- a/gpu_test/conftest.py +++ b/gpu_test/conftest.py @@ -8,6 +8,7 @@ import subprocess import tempfile import time +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -31,6 +32,15 @@ REMOTE_TMP = "/tmp" # noqa: S108 +@dataclass +class ParamDecl: + """A parsed kernel parameter declaration.""" + + name: str + is_array: bool + size: int # 0 for scalars + + class CompileError(Exception): """Raised when warpforthc fails to compile Forth source.""" @@ -315,13 +325,13 @@ def _parse_kernel_name(forth_source: str) -> str: raise ValueError(msg) -def _parse_param_declarations(forth_source: str) -> list[tuple[str, int]]: +def _parse_param_declarations(forth_source: str) -> list[ParamDecl]: """Parse '\\! param ' declarations from Forth source. - Returns list of (name, size) for array params in declaration order. - Scalar params are not supported by the GPU runner. + Returns list of ParamDecl in declaration order. Supports both array + params (e.g. i64[256]) and scalar params (e.g. i64). """ - decls = [] + decls: list[ParamDecl] = [] for keyword, parts in _iter_header_directives(forth_source): if keyword != "param": continue @@ -330,10 +340,14 @@ def _parse_param_declarations(forth_source: str) -> list[tuple[str, int]]: raise ValueError(msg) name = parts[1] type_spec = parts[2] - if "[" not in type_spec: - msg = "Scalar params are not supported by the GPU runner yet" - raise ValueError(msg) - decls.append((name, _parse_array_type(type_spec))) + if "[" in type_spec: + size = _parse_array_type(type_spec) + decls.append(ParamDecl(name=name, is_array=True, size=size)) + else: + if type_spec.lower() != "i64": + msg = f"Unsupported scalar type: {type_spec}" + raise ValueError(msg) + decls.append(ParamDecl(name=name, is_array=False, size=0)) return decls @@ -347,7 +361,7 @@ def __init__(self, session: VastSession, compiler: Compiler) -> None: def run( self, forth_source: str, - params: dict[str, list[int]] | None = None, + params: dict[str, list[int] | int] | None = None, grid: tuple[int, int, int] = (1, 1, 1), block: tuple[int, int, int] = (1, 1, 1), output_param: int = 0, @@ -356,8 +370,10 @@ def run( """Compile Forth source locally, execute on remote GPU, return output values. Param buffer sizes are derived from the Forth source's 'param' declarations. - The params dict maps param names to initial values (padded with zeros to the - declared size). Params not in the dict are zero-initialized. + The params dict maps param names to initial values: + - Array params: list[int] (padded with zeros to declared size) + - Scalar params: int + Params not in the dict are zero-initialized. """ # Parse kernel name and param declarations kernel_name = _parse_kernel_name(forth_source) @@ -368,6 +384,15 @@ def run( params = params or {} + # Validate output_param + if output_param < 0 or output_param >= len(decls): + msg = f"output_param {output_param} out of range (have {len(decls)} params)" + raise ValueError(msg) + if not decls[output_param].is_array: + name = decls[output_param].name + msg = f"output_param {output_param} ('{name}') is a scalar and cannot be read back" + raise ValueError(msg) + # Compile locally ptx = self.compiler.compile_source(forth_source) @@ -389,12 +414,22 @@ def run( kernel_name, ] - for name, size in decls: - values = params.get(name, []) - buf = [0] * size - for i, v in enumerate(values): - buf[i] = v - cmd_parts.extend(["--param", ",".join(str(v) for v in buf)]) + for decl in decls: + if decl.is_array: + values = params.get(decl.name, []) + if not isinstance(values, list): + msg = f"Array param '{decl.name}' expects a list, got {type(values).__name__}" + raise TypeError(msg) + buf = [0] * decl.size + for i, v in enumerate(values): + buf[i] = v + cmd_parts.extend(["--param", f"i64[]:{','.join(str(v) for v in buf)}"]) + else: + value = params.get(decl.name, 0) + if isinstance(value, list): + msg = f"Scalar param '{decl.name}' expects an int, got list" + raise TypeError(msg) + cmd_parts.extend(["--param", f"i64:{value}"]) cmd_parts.extend( [ diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index ad02a5e..4f161ab 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -280,6 +280,26 @@ def test_multi_param(kernel_runner: KernelRunner) -> None: assert result == [20, 40, 60, 80] +def test_scalar_param(kernel_runner: KernelRunner) -> None: + """Scalar + array params: each thread multiplies INPUT[i] by SCALE, writes OUTPUT[i].""" + result = kernel_runner.run( + forth_source=( + "\\! kernel main\n\\! param SCALE i64\n" + "\\! param INPUT i64[4]\n" + "\\! param OUTPUT i64[4]\n" + "GLOBAL-ID\n" + "DUP CELLS INPUT + @\n" + "SCALE *\n" + "SWAP CELLS OUTPUT + !" + ), + params={"SCALE": 3, "INPUT": [10, 20, 30, 40]}, + block=(4, 1, 1), + output_param=2, + output_count=4, + ) + assert result == [30, 60, 90, 120] + + # --- Matmul --- diff --git a/test/Pipeline/scalar-param.forth b/test/Pipeline/scalar-param.forth new file mode 100644 index 0000000..00771b0 --- /dev/null +++ b/test/Pipeline/scalar-param.forth @@ -0,0 +1,20 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID + +\ Verify mixed scalar + array params survive the full pipeline +\ CHECK: gpu.binary @warpforth_module + +\ Verify scalar becomes i64 arg, array becomes memref +\ MID: gpu.func @main( +\ MID-SAME: i64 {forth.param_name = "SCALE"} +\ MID-SAME: memref<256xi64> {forth.param_name = "DATA"} +\ MID-SAME: kernel +\ MID: gpu.return + +\! kernel main +\! param SCALE i64 +\! param DATA i64[256] +GLOBAL-ID +DUP CELLS DATA + @ +SCALE * +SWAP CELLS DATA + ! diff --git a/tools/warpforth-runner/warpforth-runner.cpp b/tools/warpforth-runner/warpforth-runner.cpp index 719175e..9f4f98c 100644 --- a/tools/warpforth-runner/warpforth-runner.cpp +++ b/tools/warpforth-runner/warpforth-runner.cpp @@ -5,19 +5,21 @@ /// -std=c++17`. /// /// Usage: -/// warpforth-runner kernel.ptx --param 1,2,3 --param 0,0,0,0 \ +/// warpforth-runner kernel.ptx --param i64[]:1,2,3 --param i64:42 \ /// --grid 4,1,1 --block 64,1,1 --kernel main \ /// --output-param 0 --output-count 3 #include +#include #include -#include #include -#include #include +#include #include #include +#include +#include #include #define CHECK_CU(call) \ @@ -26,43 +28,113 @@ if (err != CUDA_SUCCESS) { \ const char *errStr = nullptr; \ cuGetErrorString(err, &errStr); \ - fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ - errStr ? errStr : "unknown"); \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << ": " \ + << (errStr ? errStr : "unknown") << "\n"; \ exit(1); \ } \ } while (0) +enum class ParamKind { Array, Scalar }; + struct Param { + ParamKind kind = ParamKind::Array; std::vector values; + CUdeviceptr devicePtr = 0; + int64_t scalarValue = 0; }; struct Dims { unsigned x = 1, y = 1, z = 1; }; -static Dims parseDims(const char *s) { - Dims d; - if (sscanf(s, "%u,%u,%u", &d.x, &d.y, &d.z) != 3) { - fprintf(stderr, "Error: expected 3 comma-separated values, got: %s\n", s); +static int parseIntArg(std::string_view s, std::string_view optName) { + int value = 0; + auto [ptr, ec] = std::from_chars(s.data(), s.data() + s.size(), value); + if (ec != std::errc{} || ptr != s.data() + s.size()) { + std::cerr << "Error: " << optName << " expects an integer, got: " << s + << "\n"; exit(1); } + return value; +} + +static Dims parseDims(std::string_view s) { + Dims d; + const char *p = s.data(); + const char *end = s.data() + s.size(); + + auto dimsErr = [&]() { + std::cerr << "Error: expected 3 comma-separated values, got: " << s << "\n"; + exit(1); + }; + + auto [p1, ec1] = std::from_chars(p, end, d.x); + if (ec1 != std::errc{} || p1 == end || *p1 != ',') + dimsErr(); + + auto [p2, ec2] = std::from_chars(p1 + 1, end, d.y); + if (ec2 != std::errc{} || p2 == end || *p2 != ',') + dimsErr(); + + auto [p3, ec3] = std::from_chars(p2 + 1, end, d.z); + if (ec3 != std::errc{} || p3 != end) + dimsErr(); + return d; } -static Param parseParam(const char *s) { +static Param parseParam(std::string_view s) { Param p; - std::istringstream iss(s); + std::string input(s); + + // Find the type prefix (everything before the first ':') + auto colonPos = input.find(':'); + if (colonPos == std::string::npos) { + std::cerr << "Error: --param requires type prefix (e.g. i64:42 or " + "i64[]:1,2,3), got: " + << s << "\n"; + exit(1); + } + + std::string typePrefix = input.substr(0, colonPos); + std::string valueStr = input.substr(colonPos + 1); + + if (valueStr.empty()) { + std::cerr << "Error: --param requires at least one value, got: " << s + << "\n"; + exit(1); + } + + // Determine kind from type prefix + if (typePrefix == "i64[]") { + p.kind = ParamKind::Array; + } else if (typePrefix == "i64") { + p.kind = ParamKind::Scalar; + } else { + std::cerr << "Error: unsupported param type '" << typePrefix + << "' (expected i64 or i64[]), got: " << s << "\n"; + exit(1); + } + + // Parse values + std::istringstream iss(valueStr); std::string token; - while (std::getline(iss, token, ',')) { + while (std::getline(iss, token, ',')) p.values.push_back(std::stoll(token)); + + if (p.kind == ParamKind::Scalar && p.values.size() != 1) { + std::cerr << "Error: scalar param expects exactly one value, got: " << s + << "\n"; + exit(1); } + return p; } -static std::string readFile(const char *path) { - std::ifstream f(path, std::ios::binary); +static std::string readFile(std::string_view path) { + std::ifstream f(std::string(path), std::ios::binary); if (!f) { - fprintf(stderr, "Error: cannot open %s\n", path); + std::cerr << "Error: cannot open " << path << "\n"; exit(1); } std::ostringstream ss; @@ -80,70 +152,65 @@ int main(int argc, char **argv) { // Parse arguments for (int i = 1; i < argc; ++i) { - if (strcmp(argv[i], "--param") == 0) { + std::string_view arg = argv[i]; + auto needsValue = [&](std::string_view opt) { if (++i >= argc) { - fprintf(stderr, "Error: --param requires a value\n"); - return 1; + std::cerr << "Error: " << opt << " requires a value\n"; + exit(1); } + }; + if (arg == "--param") { + needsValue("--param"); params.push_back(parseParam(argv[i])); - } else if (strcmp(argv[i], "--grid") == 0) { - if (++i >= argc) { - fprintf(stderr, "Error: --grid requires a value\n"); - return 1; - } + } else if (arg == "--grid") { + needsValue("--grid"); grid = parseDims(argv[i]); - } else if (strcmp(argv[i], "--block") == 0) { - if (++i >= argc) { - fprintf(stderr, "Error: --block requires a value\n"); - return 1; - } + } else if (arg == "--block") { + needsValue("--block"); block = parseDims(argv[i]); - } else if (strcmp(argv[i], "--output-param") == 0) { - if (++i >= argc) { - fprintf(stderr, "Error: --output-param requires a value\n"); - return 1; - } - outputParam = atoi(argv[i]); - } else if (strcmp(argv[i], "--output-count") == 0) { - if (++i >= argc) { - fprintf(stderr, "Error: --output-count requires a value\n"); - return 1; - } - outputCount = atoi(argv[i]); - } else if (strcmp(argv[i], "--kernel") == 0) { - if (++i >= argc) { - fprintf(stderr, "Error: --kernel requires a value\n"); - return 1; - } + } else if (arg == "--output-param") { + needsValue("--output-param"); + outputParam = parseIntArg(argv[i], "--output-param"); + } else if (arg == "--output-count") { + needsValue("--output-count"); + outputCount = parseIntArg(argv[i], "--output-count"); + } else if (arg == "--kernel") { + needsValue("--kernel"); kernelName = argv[i]; - } else if (argv[i][0] == '-') { - fprintf(stderr, "Error: unknown option %s\n", argv[i]); - return 1; + } else if (arg[0] == '-') { + std::cerr << "Error: unknown option " << arg << "\n"; + exit(1); } else { ptxFile = argv[i]; } } if (!ptxFile) { - fprintf(stderr, "Usage: warpforth-runner kernel.ptx --kernel NAME " - "[--param V,...] [--grid X,Y,Z] [--block X,Y,Z] " - "[--output-param N] [--output-count N]\n"); + std::cerr << "Usage: warpforth-runner kernel.ptx --kernel NAME " + "[--param i64[]:V,...] [--param i64:V] [--grid X,Y,Z] " + "[--block X,Y,Z] [--output-param N] [--output-count N]\n"; return 1; } if (!kernelName) { - fprintf(stderr, "Error: --kernel NAME is required\n"); + std::cerr << "Error: --kernel NAME is required\n"; return 1; } if (params.empty()) { - fprintf(stderr, "Error: at least one --param is required\n"); + std::cerr << "Error: at least one --param is required\n"; return 1; } if (outputParam < 0 || outputParam >= static_cast(params.size())) { - fprintf(stderr, "Error: output-param %d out of range (have %zu params)\n", - outputParam, params.size()); + std::cerr << "Error: output-param " << outputParam << " out of range (have " + << params.size() << " params)\n"; + return 1; + } + + if (params[outputParam].kind == ParamKind::Scalar) { + std::cerr << "Error: output-param " << outputParam + << " is a scalar (cannot read back)\n"; return 1; } @@ -166,18 +233,23 @@ int main(int argc, char **argv) { CUfunction func; CHECK_CU(cuModuleGetFunction(&func, module, kernelName)); - // Allocate device buffers and copy data - std::vector devicePtrs(params.size()); - for (size_t i = 0; i < params.size(); ++i) { - size_t bytes = params[i].values.size() * sizeof(int64_t); - CHECK_CU(cuMemAlloc(&devicePtrs[i], bytes)); - CHECK_CU(cuMemcpyHtoD(devicePtrs[i], params[i].values.data(), bytes)); + // Allocate device buffers (arrays) or store scalar values + for (auto &p : params) { + if (p.kind == ParamKind::Array) { + size_t bytes = p.values.size() * sizeof(int64_t); + CHECK_CU(cuMemAlloc(&p.devicePtr, bytes)); + CHECK_CU(cuMemcpyHtoD(p.devicePtr, p.values.data(), bytes)); + } else { + p.scalarValue = p.values[0]; + } } // Set up kernel parameters — Driver API expects array of pointers to args std::vector kernelArgs(params.size()); for (size_t i = 0; i < params.size(); ++i) { - kernelArgs[i] = &devicePtrs[i]; + kernelArgs[i] = (params[i].kind == ParamKind::Array) + ? static_cast(¶ms[i].devicePtr) + : static_cast(¶ms[i].scalarValue); } // Launch kernel @@ -189,21 +261,22 @@ int main(int argc, char **argv) { // Copy back output param size_t outSize = params[outputParam].values.size(); std::vector output(outSize); - CHECK_CU(cuMemcpyDtoH(output.data(), devicePtrs[outputParam], + CHECK_CU(cuMemcpyDtoH(output.data(), params[outputParam].devicePtr, outSize * sizeof(int64_t))); // Print CSV to stdout size_t count = outputCount >= 0 ? static_cast(outputCount) : outSize; for (size_t i = 0; i < count; ++i) { if (i > 0) - printf(","); - printf("%ld", static_cast(output[i])); + std::cout << ","; + std::cout << output[i]; } - printf("\n"); + std::cout << "\n"; - // Cleanup - for (auto &ptr : devicePtrs) { - cuMemFree(ptr); + // Cleanup — only free device memory for array params + for (auto &p : params) { + if (p.kind == ParamKind::Array) + cuMemFree(p.devicePtr); } cuModuleUnload(module); cuCtxDestroy(ctx);