diff --git a/package/version b/package/version index 3a28972df..010fbfb88 100644 --- a/package/version +++ b/package/version @@ -1 +1 @@ -0.1.148 +0.1.149 diff --git a/pykwasm/pyproject.toml b/pykwasm/pyproject.toml index 2cd96eee5..f4ae91554 100644 --- a/pykwasm/pyproject.toml +++ b/pykwasm/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pykwasm" -version = "0.1.148" +version = "0.1.149" description = "" readme = "README.md" requires-python = "~=3.10" diff --git a/pykwasm/src/pykwasm/binary/__init__.py b/pykwasm/src/pykwasm/binary/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pykwasm/src/pykwasm/binary/combinators.py b/pykwasm/src/pykwasm/binary/combinators.py new file mode 100644 index 000000000..0adc46731 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/combinators.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .integers import u32 +from .utils import WasmParseError, reset + +if TYPE_CHECKING: + from .utils import A, InputStream, Parser + + +def iterate(f: Parser[A], s: InputStream) -> list[A]: + results = [] + while True: + try: + results.append(f(s)) + except (WasmParseError, IndexError, ValueError): + break + return results + + +def sized(p: Parser[A], s: InputStream) -> A: + size = u32(s) + start_pos = s.tell() + res = p(s) + end_pos = s.tell() + if end_pos - start_pos != size: + raise WasmParseError('Size mismatch') + return res + + +def parse_n(p: Parser[A], n: int, s: InputStream) -> list[A]: + results = [] + for _ in range(n): + x = p(s) + results.append(x) + return results + + +def list_of(p: Parser[A], s: InputStream) -> list[A]: + n = u32(s) + return parse_n(p, n, s) + + +def either(ps: list[Parser[A]], s: InputStream) -> A: + for p in ps: + pos = s.tell() + try: + return p(s) + except WasmParseError: + reset(pos, s) + continue + raise WasmParseError('None of the alternatives succeeded') diff --git a/pykwasm/src/pykwasm/binary/floats.py b/pykwasm/src/pykwasm/binary/floats.py new file mode 100644 index 000000000..de1dd5359 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/floats.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import struct +from typing import TYPE_CHECKING + +from .utils import read_bytes + +if TYPE_CHECKING: + from .utils import InputStream + + +def f32(s: InputStream) -> float: + bs = read_bytes(4, s) + f = struct.unpack(' float: + bs = read_bytes(8, s) + f = struct.unpack(' int: + return u32(s) + + +def funcidx(s: InputStream) -> int: + return u32(s) + + +def tableidx(s: InputStream) -> int: + return u32(s) + + +def memidx(s: InputStream) -> int: + return u32(s) + + +def globalidx(s: InputStream) -> int: + return u32(s) + + +def tagidx(s: InputStream) -> int: + return u32(s) + + +def elemidx(s: InputStream) -> int: + return u32(s) + + +def dataidx(s: InputStream) -> int: + return u32(s) + + +def localidx(s: InputStream) -> int: + return u32(s) + + +def labelidx(s: InputStream) -> int: + return u32(s) + + +def externidx(s: InputStream) -> KInner: + match read_byte(s): + case 0x00: + return wast.externidx_func(funcidx(s)) + case 0x01: + return wast.externidx_table(tableidx(s)) + case 0x02: + return wast.externidx_memory(memidx(s)) + case 0x03: + return wast.externidx_global(globalidx(s)) + case 0x04: + return wast.externidx_tag(tagidx(s)) + case x: + raise WasmParseError(f'Invalid externidx descriptor: {x}') diff --git a/pykwasm/src/pykwasm/binary/instructions.py b/pykwasm/src/pykwasm/binary/instructions.py new file mode 100644 index 000000000..6d95ad272 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/instructions.py @@ -0,0 +1,541 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pykwasm.kwasm_ast as wast +from pykwasm.binary.combinators import list_of +from pykwasm.binary.floats import f32, f64 +from pykwasm.binary.indices import elemidx, funcidx, globalidx, labelidx, localidx, memidx, tableidx, typeidx +from pykwasm.binary.integers import i32, i64, u32, u64 +from pykwasm.binary.types import heaptype, valtype +from pykwasm.binary.utils import WasmParseError, expect_bytes, peek_byte, read_byte, skip + +if TYPE_CHECKING: + from pyk.kast.inner import KInner + + from pykwasm.binary.utils import InputStream + + +def instr_seq(terminator: int, s: InputStream) -> list[KInner]: + res: list[KInner] = [] + + while peek_byte(s) != terminator: + i = instr(s) + res.append(i) + + expect_bytes(bytes([terminator]), s) + return res + + +def instr(s: InputStream) -> KInner: + def blocktype(s: InputStream) -> KInner: + if peek_byte(s) == 0x40: + skip(1, s) + return wast.vec_type(wast.val_types([])) + + try: + t = valtype(s) + return wast.vec_type(wast.val_types([t])) + except WasmParseError as e: + # TODO handle the `blocktype ::= i:i33` case + raise WasmParseError('Could not parse blocktype') from e + + # returns memory index, alignment, offset + def memarg(s: InputStream) -> tuple[int, int, int]: + n = u32(s) + if n < 2**6: + m = u64(s) + return 0, n, m + elif n < 2**7: + x = memidx(s) + m = u64(s) + return x, n - 2**6, m + else: + raise WasmParseError(f'Invalid memarg. n: {n}') + + opcode = read_byte(s) + # '0xFC' is a special opcode prefix shared by several instructions. + # These instructions are distinguished by a u32 value following the opcode. + additional_value: int | None = None + if opcode == 0xFC: + try: + additional_value = u32(s) + except Exception as e: + raise WasmParseError('Cannot parse opcode, expected u32 after 0xFC') from e + + match opcode: + # Parametric instructions + case 0x00: + return wast.UNREACHABLE + case 0x01: + return wast.NOP + case 0x1A: + return wast.DROP + case 0x1B: + return wast.SELECT + case 0x1C: + # TODO handle ts + ts = list_of(valtype, s) # noqa F841 + return wast.SELECT + + # Control instructions + case 0x02: + bt = blocktype(s) + ins = instr_seq(0x0B, s) + return wast.BLOCK(bt, wast.instrs(ins), wast.KInt(0)) # TODO either track block ids or deprecate them + case 0x03: + bt = blocktype(s) + ins = instr_seq(0x0B, s) + return wast.LOOP(bt, wast.instrs(ins), wast.KInt(0)) # TODO either track block ids or deprecate them + case 0x04: + bt = blocktype(s) + in1 = instr_seq(0x05, s) + in2 = instr_seq(0x0B, s) + return wast.IF( + bt, wast.instrs(in1), wast.instrs(in2), wast.KInt(0) + ) # TODO either track block ids or deprecate them + case 0x0C: + l = labelidx(s) + return wast.BR(l) + case 0x0D: + l = labelidx(s) + return wast.BR_IF(l) + case 0x0E: + ls = list_of(labelidx, s) + ln = labelidx(s) + return wast.BR_TABLE(tuple(ls), ln) + case 0x0F: + return wast.RETURN + case 0x10: + x = funcidx(s) + return wast.CALL(x) + case 0x11: + y = typeidx(s) + x = tableidx(s) + return wast.CALL_INDIRECT(x, y) + + # Variable instructions + case 0x20: + x = localidx(s) + return wast.GET_LOCAL(x) + case 0x21: + x = localidx(s) + return wast.SET_LOCAL(x) + case 0x22: + x = localidx(s) + return wast.TEE_LOCAL(x) + case 0x23: + x = globalidx(s) + return wast.GET_GLOBAL(x) + case 0x24: + x = globalidx(s) + return wast.SET_GLOBAL(x) + + # Table instructions + case 0x25: + x = tableidx(s) + return wast.TABLE_GET(x) + case 0x26: + x = tableidx(s) + return wast.TABLE_SET(x) + case 0xFC if additional_value == 12: + y = elemidx(s) + x = tableidx(s) + return wast.TABLE_INIT(x, y) + case 0xFC if additional_value == 13: + x = elemidx(s) + return wast.ELEM_DROP(x) + case 0xFC if additional_value == 14: + x1 = tableidx(s) + x2 = tableidx(s) + return wast.TABLE_COPY(x1, x2) + case 0xFC if additional_value == 15: + x = tableidx(s) + return wast.TABLE_GROW(x) + case 0xFC if additional_value == 16: + x = tableidx(s) + return wast.TABLE_SIZE(x) + case 0xFC if additional_value == 17: + x = tableidx(s) + return wast.TABLE_FILL(x) + + # Memory instructions + # TODO handle memory index and alignment + case 0x28: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_LOAD(offset) + case 0x29: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD(offset) + case 0x2A: + mx, align, offset = memarg(s) # noqa F841 + return wast.F32_LOAD(offset) + case 0x2B: + mx, align, offset = memarg(s) # noqa F841 + return wast.F64_LOAD(offset) + case 0x2C: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_LOAD8_S(offset) + case 0x2D: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_LOAD8_U(offset) + case 0x2E: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_LOAD16_S(offset) + case 0x2F: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_LOAD16_U(offset) + case 0x30: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD8_S(offset) + case 0x31: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD8_U(offset) + case 0x32: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD16_S(offset) + case 0x33: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD16_U(offset) + case 0x34: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD32_S(offset) + case 0x35: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_LOAD32_U(offset) + case 0x36: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_STORE(offset) + case 0x37: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_STORE(offset) + case 0x38: + mx, align, offset = memarg(s) # noqa F841 + return wast.F32_STORE(offset) + case 0x39: + mx, align, offset = memarg(s) # noqa F841 + return wast.F64_STORE(offset) + case 0x3A: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_STORE8(offset) + case 0x3B: + mx, align, offset = memarg(s) # noqa F841 + return wast.I32_STORE16(offset) + case 0x3C: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_STORE8(offset) + case 0x3D: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_STORE16(offset) + case 0x3E: + mx, align, offset = memarg(s) # noqa F841 + return wast.I64_STORE32(offset) + case 0x3F: + # TODO use memory index + _x = memidx(s) # noqa 841 + return wast.MEMORY_SIZE + case 0x40: + # TODO use memory index + _x = memidx(s) # noqa 841 + return wast.MEMORY_GROW + + # Reference instructions + case 0xD0: + ht = heaptype(s) + return wast.REF_NULL(ht) + case 0xD1: + return wast.REF_IS_NULL + case 0xD2: + x = funcidx(s) + return wast.REF_FUNC(x) + + # Numeric instructions + # Constants + case 0x41: + i = i32(s) + return wast.I32_CONST(i) + case 0x42: + i = i64(s) + return wast.I64_CONST(i) + case 0x43: + f = f32(s) + return wast.F32_CONST(f) + case 0x44: + f = f64(s) + return wast.F64_CONST(f) + + # i32 comparison + case 0x45: + return wast.I32_EQZ + case 0x46: + return wast.I32_EQ + case 0x47: + return wast.I32_NE + case 0x48: + return wast.I32_LT_S + case 0x49: + return wast.I32_LT_U + case 0x4A: + return wast.I32_GT_S + case 0x4B: + return wast.I32_GT_U + case 0x4C: + return wast.I32_LE_S + case 0x4D: + return wast.I32_LE_U + case 0x4E: + return wast.I32_GE_S + case 0x4F: + return wast.I32_GE_U + + # i64 comparison + case 0x50: + return wast.I64_EQZ + case 0x51: + return wast.I64_EQ + case 0x52: + return wast.I64_NE + case 0x53: + return wast.I64_LT_S + case 0x54: + return wast.I64_LT_U + case 0x55: + return wast.I64_GT_S + case 0x56: + return wast.I64_GT_U + case 0x57: + return wast.I64_LE_S + case 0x58: + return wast.I64_LE_U + case 0x59: + return wast.I64_GE_S + case 0x5A: + return wast.I64_GE_U + + # f32 comparison + case 0x5B: + return wast.F32_EQ + case 0x5C: + return wast.F32_NE + case 0x5D: + return wast.F32_LT + case 0x5E: + return wast.F32_GT + case 0x5F: + return wast.F32_LE + case 0x60: + return wast.F32_GE + + # f64 comparison + case 0x61: + return wast.F64_EQ + case 0x62: + return wast.F64_NE + case 0x63: + return wast.F64_LT + case 0x64: + return wast.F64_GT + case 0x65: + return wast.F64_LE + case 0x66: + return wast.F64_GE + + # i32 unary / binary + case 0x67: + return wast.I32_CLZ + case 0x68: + return wast.I32_CTZ + case 0x69: + return wast.I32_POPCNT + case 0x6A: + return wast.I32_ADD + case 0x6B: + return wast.I32_SUB + case 0x6C: + return wast.I32_MUL + case 0x6D: + return wast.I32_DIV_S + case 0x6E: + return wast.I32_DIV_U + case 0x6F: + return wast.I32_REM_S + case 0x70: + return wast.I32_REM_U + case 0x71: + return wast.I32_AND + case 0x72: + return wast.I32_OR + case 0x73: + return wast.I32_XOR + case 0x74: + return wast.I32_SHL + case 0x75: + return wast.I32_SHR_S + case 0x76: + return wast.I32_SHR_U + case 0x77: + return wast.I32_ROTL + case 0x78: + return wast.I32_ROTR + + # i64 unary / binary + case 0x79: + return wast.I64_CLZ + case 0x7A: + return wast.I64_CTZ + case 0x7B: + return wast.I64_POPCNT + case 0x7C: + return wast.I64_ADD + case 0x7D: + return wast.I64_SUB + case 0x7E: + return wast.I64_MUL + case 0x7F: + return wast.I64_DIV_S + case 0x80: + return wast.I64_DIV_U + case 0x81: + return wast.I64_REM_S + case 0x82: + return wast.I64_REM_U + case 0x83: + return wast.I64_AND + case 0x84: + return wast.I64_OR + case 0x85: + return wast.I64_XOR + case 0x86: + return wast.I64_SHL + case 0x87: + return wast.I64_SHR_S + case 0x88: + return wast.I64_SHR_U + case 0x89: + return wast.I64_ROTL + case 0x8A: + return wast.I64_ROTR + + # f32 unary / binary + case 0x8B: + return wast.F32_ABS + case 0x8C: + return wast.F32_NEG + case 0x8D: + return wast.F32_CEIL + case 0x8E: + return wast.F32_FLOOR + case 0x8F: + return wast.F32_TRUNC + case 0x90: + return wast.F32_NEAREST + case 0x91: + return wast.F32_SQRT + case 0x92: + return wast.F32_ADD + case 0x93: + return wast.F32_SUB + case 0x94: + return wast.F32_MUL + case 0x95: + return wast.F32_DIV + case 0x96: + return wast.F32_MIN + case 0x97: + return wast.F32_MAX + case 0x98: + return wast.F32_COPYSIGN + + # f64 unary / binary + case 0x99: + return wast.F64_ABS + case 0x9A: + return wast.F64_NEG + case 0x9B: + return wast.F64_CEIL + case 0x9C: + return wast.F64_FLOOR + case 0x9D: + return wast.F64_TRUNC + case 0x9E: + return wast.F64_NEAREST + case 0x9F: + return wast.F64_SQRT + case 0xA0: + return wast.F64_ADD + case 0xA1: + return wast.F64_SUB + case 0xA2: + return wast.F64_MUL + case 0xA3: + return wast.F64_DIV + case 0xA4: + return wast.F64_MIN + case 0xA5: + return wast.F64_MAX + case 0xA6: + return wast.F64_COPYSIGN + + # conversion instructions + case 0xA7: + return wast.I32_WRAP_I64 + case 0xA8: + return wast.I32_TRUNC_S_F32 + case 0xA9: + return wast.I32_TRUNC_U_F32 + case 0xAA: + return wast.I32_TRUNC_S_F64 + case 0xAB: + return wast.I32_TRUNC_U_F64 + case 0xAC: + return wast.I64_EXTEND_S_I32 + case 0xAD: + return wast.I64_EXTEND_U_I32 + case 0xAE: + return wast.I64_TRUNC_S_F32 + case 0xAF: + return wast.I64_TRUNC_U_F32 + case 0xB0: + return wast.I64_TRUNC_S_F64 + case 0xB1: + return wast.I64_TRUNC_U_F64 + case 0xB2: + return wast.F32_CONVERT_S_I32 + case 0xB3: + return wast.F32_CONVERT_U_I32 + case 0xB4: + return wast.F32_CONVERT_S_I64 + case 0xB5: + return wast.F32_CONVERT_U_I64 + case 0xB6: + return wast.F32_DEMOTE_F64 + case 0xB7: + return wast.F64_CONVERT_S_I32 + case 0xB8: + return wast.F64_CONVERT_U_I32 + case 0xB9: + return wast.F64_CONVERT_S_I64 + case 0xBA: + return wast.F64_CONVERT_U_I64 + case 0xBB: + return wast.F64_PROMOTE_F32 + + # sign-extension instructions + case 0xC0: + return wast.I32_EXTEND8_s + case 0xC1: + return wast.I32_EXTEND16_s + case 0xC2: + return wast.I64_EXTEND8_s + case 0xC3: + return wast.I64_EXTEND16_s + case 0xC4: + return wast.I64_EXTEND32_s + + # Handle unsupported opcodes + case _: + raise WasmParseError(f'Unsupported opcode: {opcode :#04x}, {additional_value}') + + +def expr(s: InputStream) -> list[KInner]: + return instr_seq(0x0B, s) diff --git a/pykwasm/src/pykwasm/binary/integers.py b/pykwasm/src/pykwasm/binary/integers.py new file mode 100644 index 000000000..0f4479178 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/integers.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .utils import WasmParseError, read_byte + +if TYPE_CHECKING: + from .utils import InputStream + + +def parse_uint(bits: int, s: InputStream) -> int: + n = read_byte(s) + + # No continuation bit + if n < 2**7: + if n >= 2**bits: + raise WasmParseError(f'Value {n} exceeds u{bits} range') + return n + + # Has continuation bit + if bits <= 7: # Not enough bits remaining + raise WasmParseError(f'LEB128 encoding too long for u{bits}') + + m = parse_uint(bits - 7, s) + return (n - 2**7) + (2**7) * m + + +def parse_sint(bits: int, s: InputStream) -> int: + n = read_byte(s) + + # No continuation bit and 0 sign bit + if n < 2**6: + if n >= 2 ** (bits - 1): + raise WasmParseError(f'Value {n} exceeds s{bits} range') + return n + + # No continuation bit and 1 sign bit + if 2**6 <= n and n < 2**7: + signed_value = n - 2**7 + if signed_value < -(2 ** (bits - 1)): + raise WasmParseError(f'Value {signed_value} exceeds s{bits} range') + return n - 2**7 + + # Has continuation bit + if bits <= 7: # Not enough bits remaining + raise WasmParseError(f'LEB128 encoding too long for s{bits}') + + i = parse_sint(bits - 7, s) + return (2**7) * i + (n - 2**7) + + +def u32(s: InputStream) -> int: + return parse_uint(32, s) + + +def u64(s: InputStream) -> int: + return parse_uint(64, s) + + +def to_uninterpreted(bits: int, x: int) -> int: + if 0 <= x: + return x + return x + 2 ** bits + +def parse_iint(bits: int, s: InputStream) -> int: + i = parse_sint(bits, s) + return to_uninterpreted(bits, i) + + +def i32(s: InputStream) -> int: + return parse_iint(32, s) + + +def i64(s: InputStream) -> int: + return parse_iint(64, s) diff --git a/pykwasm/src/pykwasm/binary/module.py b/pykwasm/src/pykwasm/binary/module.py new file mode 100644 index 000000000..38b991078 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/module.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pyk.kast.inner import KApply, KLabel, KSort, KToken + +import pykwasm.kwasm_ast as wast +from pykwasm.binary.indices import externidx, funcidx, memidx, tableidx, typeidx +from pykwasm.binary.instructions import expr + +from .combinators import list_of, sized +from .integers import u32 +from .types import externtype_as_import_desc, globaltype, memtype, rectype, reftype, tabletype, valtype +from .utils import WasmEOFError, WasmParseError, peek_byte, peek_bytes, read_byte, read_bytes, skip + +if TYPE_CHECKING: + from pyk.kast.inner import KInner + + from .types import RecType + from .utils import A, InputStream, Parser + +MAGIC = b'\x00\x61\x73\x6d' +VERSION = b'\x01\x00\x00\x00' + +EMPTY_SECTION: KInner = wast.EMPTY_DEFNS + + +def parse_magic(s: InputStream) -> None: + try: + val = s.read(4) + assert val == MAGIC + except Exception as e: + raise WasmParseError('Could not read magic value') from e + + +def parse_version(s: InputStream) -> None: + try: + val = s.read(4) + assert val == VERSION + except Exception as e: + raise WasmParseError('Could not read version') from e + + +def parse_custom_section(s: InputStream) -> bytes: + try: + n = peek_byte(s) + except WasmEOFError: + return b'' + + if n != 0: + return b'' + skip(1, s) + return read_bytes(n, s) + + +def parse_custom_sections(s: InputStream) -> list[bytes]: + sects = [] + while True: + sect = parse_custom_section(s) + if not sect: + break + sects.append(sect) + return sects + + +def section(sec_id: int, p: Parser[A], default: A, s: InputStream) -> A: + try: + n = peek_byte(s) + except WasmEOFError: + return default + + if n != sec_id: + return default + skip(1, s) + return sized(p, s) + + +# TODO support recursive types in type sections +def typesec(s: InputStream) -> KInner: + def rectype_to_type_decl(rc: RecType) -> KInner: + match rc: + case [(_, [], st)]: + return wast.type(st) + case _: + raise WasmParseError('recursive types are not supported') + + rectypes = list_of(rectype, s) + type_decls = [rectype_to_type_decl(t) for t in rectypes] + return wast.defns(type_decls) + + +def byte_list(s: InputStream) -> bytes: + n = u32(s) + return read_bytes(n, s) + + +def name(s: InputStream) -> str: + bs = byte_list(s) + return bs.decode('utf-8') + + +def importsec(s: InputStream) -> KInner: + def import_(s: InputStream) -> KInner: + nm1 = wast.wasm_string(name(s)) + nm2 = wast.wasm_string(name(s)) + desc = externtype_as_import_desc(s) + return wast.imp(nm1, nm2, desc) + + imports = list_of(import_, s) + return wast.defns(imports) + + +def funcsec(s: InputStream) -> list[int]: + return list_of(typeidx, s) + + +def tablesec(s: InputStream) -> KInner: + def table(s: InputStream) -> KInner: + match peek_bytes(2, s): + case b'\x40\x00': + raise WasmParseError('table_tt_with_initializer not implemented') + case _: + # TODO this should return '𝗍𝖺𝖻𝗅𝖾 tt (𝗋𝖾𝖿.π—‡π—Žπ—…π—… ht) if tt=at lim (𝗋𝖾𝖿 π—‡π—Žπ—…π—…? ht)' + _at, lim, rt = tabletype(s) + return wast.table(lim, rt) + + tables = list_of(table, s) + return wast.defns(tables) + + +def memsec(s: InputStream) -> KInner: + mems = list_of(memtype, s) + return wast.defns([wast.memory(lim) for _at, lim in mems]) + + +def globalsec(s: InputStream) -> KInner: + def glob(s: InputStream) -> KInner: + gt = globaltype(s) + e = wast.instrs(expr(s)) + return wast.glob(gt, e) + + return wast.defns(list_of(glob, s)) + + +def exportsec(s: InputStream) -> KInner: + def export(s: InputStream) -> KInner: + nm = name(s) + xx = externidx(s) + return wast.export(wast.wasm_string(nm), index=xx) + + return wast.defns(list_of(export, s)) + + +def startsec(s: InputStream) -> KInner: + start_func = funcidx(s) + return wast.defns([wast.start(start_func)]) + + +def elemsec(s: InputStream) -> KInner: + def elemkind(s: InputStream) -> KInner: + match read_byte(s): + case 0: + return wast.funcref + case x: + raise WasmParseError(f'Invalid elemkind: {x}') + + def elem_init(init: list[list[KInner]]) -> list[int | None]: + def expr_to_int(expr: list[KInner]) -> int | None: + # 'expr' must be a constant expression consisting of a reference instruction + match expr: + case [KApply(KLabel('aRef.func'), (KToken(sfuncidx, KSort('Int')),))]: + return int(sfuncidx) + case [KApply(KLabel('aRef.null'), _)]: + return None + case _: + raise WasmParseError(f'Expected a constant reference instruction as elem initializer, got {expr}') + + return [expr_to_int(e) for e in init] + + def elem(s: InputStream) -> KInner: + match read_byte(s): + case 0: + e0 = expr(s) + ys = list_of(funcidx, s) + return wast.elem(wast.funcref, wast.elem_active(0, wast.instrs(e0)), ys) + case 1: + rt = elemkind(s) + ys = list_of(funcidx, s) + return wast.elem(rt, wast.elem_passive(), ys) + case 2: + x = tableidx(s) + e = expr(s) + rt = elemkind(s) + ys = list_of(funcidx, s) + return wast.elem(rt, wast.elem_active(x, wast.instrs(e)), ys) + case 3: + rt = elemkind(s) + ys = list_of(funcidx, s) + return wast.elem(rt, wast.elem_declarative(), ys) + case 4: + e0 = expr(s) + es = list_of(expr, s) + return wast.elem(wast.funcref, wast.elem_active(0, wast.instrs(e0)), elem_init(es)) + case 5: + rt = reftype(s) + es = list_of(expr, s) + return wast.elem(rt, wast.elem_passive(), elem_init(es)) + case 6: + x = tableidx(s) + e0 = expr(s) + rt = reftype(s) + es = list_of(expr, s) + return wast.elem(rt, wast.elem_active(x, wast.instrs(e0)), elem_init(es)) + case 7: + rt = reftype(s) + es = list_of(expr, s) + return wast.elem(rt, wast.elem_declarative(), elem_init(es)) + case x: + raise WasmParseError(f'Invalid elem descriptor: {x}') + + return wast.defns(list_of(elem, s)) + + +def codesec(s: InputStream) -> list[tuple[list[KInner], list[KInner]]]: + def locals(s: InputStream) -> tuple[int, KInner]: + n = u32(s) + t = valtype(s) + return n, t + + def func(s: InputStream) -> tuple[list[KInner], list[KInner]]: + locs = list_of(locals, s) + body = expr(s) + + locs_flattened = [t for n, t in locs for _ in range(n)] + return locs_flattened, body + + def code(s: InputStream) -> tuple[list[KInner], list[KInner]]: + return sized(func, s) + + return list_of(code, s) + + +def datasec(s: InputStream) -> KInner: + def data(s: InputStream) -> KInner: + match read_byte(s): + case 0: + e = expr(s) + bs = byte_list(s) + return wast.data(bs, wast.datamode_active(0, wast.instrs(e))) + case 1: + bs = byte_list(s) + return wast.data(bs, wast.datamode_passive()) + case 2: + x = memidx(s) + e = expr(s) + bs = byte_list(s) + return wast.data(bs, wast.datamode_active(x, wast.instrs(e))) + case x: + raise WasmParseError(f'Invalid data segment descriptor: {x}') + + return wast.defns(list_of(data, s)) + + +def datacntsec(s: InputStream) -> int | None: + return u32(s) + + +def parse_module(s: InputStream) -> KInner: + parse_magic(s) + parse_version(s) + + custom_sections = [] + customs = lambda: custom_sections.extend(parse_custom_sections(s)) + + customs() + + type_section = section(1, typesec, EMPTY_SECTION, s) + + customs() + + import_section = section(2, importsec, EMPTY_SECTION, s) + + customs() + + function_section: list[int] = section(3, funcsec, [], s) + + customs() + + table_section = section(4, tablesec, EMPTY_SECTION, s) + + customs() + + memory_section = section(5, memsec, EMPTY_SECTION, s) + + customs() + + global_section = section(6, globalsec, EMPTY_SECTION, s) + + customs() + + export_section = section(7, exportsec, EMPTY_SECTION, s) + + customs() + + start_section = section(8, startsec, EMPTY_SECTION, s) + + customs() + + elem_section = section(9, elemsec, EMPTY_SECTION, s) + + customs() + + data_cnt_section = section(12, datacntsec, None, s) # noqa f841 + + customs() + + code_section: list[tuple[list[KInner], list[KInner]]] = section(10, codesec, [], s) + + customs() + + data_section = section(11, datasec, EMPTY_SECTION, s) + + customs() + + functions = [ + wast.func( + wast.KInt(typ_idx), + wast.vec_type(wast.val_types(locals)), + wast.instrs(code), + ) + for typ_idx, (locals, code) in zip(function_section, code_section, strict=True) + ] + + return wast.module( + types=type_section, + funcs=wast.defns(functions), + tables=table_section, + mems=memory_section, + globs=global_section, + elem=elem_section, + data=data_section, + start=start_section, + imports=import_section, + exports=export_section, + ) diff --git a/pykwasm/src/pykwasm/binary/types.py b/pykwasm/src/pykwasm/binary/types.py new file mode 100644 index 000000000..b5cecae18 --- /dev/null +++ b/pykwasm/src/pykwasm/binary/types.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pykwasm.binary.indices import typeidx + +from .. import kwasm_ast as wast +from .combinators import either, list_of +from .integers import u64 +from .utils import WasmParseError, peek_byte, read_byte, skip + +if TYPE_CHECKING: + from typing import TypeAlias + + from pyk.kast.inner import KInner + + from .utils import InputStream + +# subtype struct in recursive types +# subtype ::= sub final? typeuse* comptype +SubType: TypeAlias = tuple[bool, list[int], 'KInner'] + +# A group of mutually recursive composite types +# rectype :: = rec list(subtype) +RecType: TypeAlias = list[SubType] + + +def rectype(s: InputStream) -> RecType: + n = peek_byte(s) + if n == 0x4E: + skip(1, s) + sts = list_of(parse_subtype, s) + return sts + return [parse_subtype(s)] + + +def parse_subtype(s: InputStream) -> SubType: + n = peek_byte(s) + + if n == 0x4F: + skip(1, s) + xs = list_of(typeidx, s) + ct = parse_comptype(s) + return (True, xs, ct) + if n == 0x50: + skip(1, s) + xs = list_of(typeidx, s) + ct = parse_comptype(s) + return (False, xs, ct) + + ct = parse_comptype(s) + return (True, [], ct) + + +def parse_comptype(s: InputStream) -> KInner: + n = read_byte(s) + if n == 0x60: + ts1 = resulttype(s) + ts2 = resulttype(s) + return wast.func_type(wast.vec_type(wast.val_types(ts1)), wast.vec_type(wast.val_types(ts2))) + + if n == 0x5E: + raise WasmParseError('array composite type is not yet supported') + if n == 0x5F: + raise WasmParseError('struct composite type is not yet supported') + raise WasmParseError(f'composite type not yet supported: {n}') + + +def resulttype(s: InputStream) -> list[KInner]: + return list_of(valtype, s) + + +def valtype(s: InputStream) -> KInner: + d = peek_byte(s) + + try: + return either( + [ + numtype, + reftype, + # TODO implement vectypes + # vectype, + ], + s, + ) + except WasmParseError: + raise WasmParseError(f'Could not parse valtype. Descriptor: {d}') from None + + +def reftype(s: InputStream) -> KInner: + match read_byte(s): + case 0x70: + return wast.funcref + case 0x6F: + return wast.externref + case d: + raise WasmParseError(f'Unsupported reftype descriptor: {d}') + + +def numtype(s: InputStream) -> KInner: + match read_byte(s): + case 0x7C: + return wast.f64 + case 0x7D: + return wast.f32 + case 0x7E: + return wast.i64 + case 0x7F: + return wast.i32 + case d: + raise WasmParseError(f'Invalid numtype descriptor: {d}') + + +Limits: TypeAlias = tuple[int, int | None] + + +# Returns (addressType, (n, m)) +def limits(s: InputStream) -> tuple[KInner, Limits]: + match read_byte(s): + case 0x00: + n = u64(s) + return wast.i32, (n, None) + case 0x01: + n = u64(s) + m = u64(s) + return wast.i32, (n, m) + case 0x02: + n = u64(s) + return wast.i64, (n, None) + case 0x03: + n = u64(s) + m = u64(s) + return wast.i64, (n, m) + case d: + raise WasmParseError(f'Invalid limit descriptor: {d}') + + +# Returns (addressType, limits, reftype) +def tabletype(s: InputStream) -> tuple[KInner, Limits, KInner]: + rt = reftype(s) + at, lim = limits(s) + return at, lim, rt + + +# Returns (addressType, limits) +def memtype(s: InputStream) -> tuple[KInner, Limits]: + return limits(s) + + +def mut(s: InputStream) -> KInner: + match read_byte(s): + case 0x00: + return wast.MUT_CONST + case 0x01: + return wast.MUT_VAR + case d: + raise WasmParseError(f'Invalid mut descriptor. Expected 0x00 or 0x01, got: {d}') + + +def globaltype(s: InputStream) -> KInner: + t = valtype(s) + m = mut(s) + return wast.global_type(m, t) + + +def tagtype(s: InputStream) -> int: + match read_byte(s): + case 0x00: + return typeidx(s) + case d: + raise WasmParseError(f'Invalid tagtype descriptor. Expected 0x00, got: {d}') + + +# TODO this function parses externtype from wasm 3.0 but returns ImportDesc from Wasm 1.0. +# Update the ImportDefn K definition to reflect the changes. +def externtype_as_import_desc(s: InputStream) -> KInner: + match read_byte(s): + case 0x00: + x = typeidx(s) + return wast.func_desc(x) + case 0x01: + _at, lim, _rt = tabletype(s) + return wast.table_desc(lim) + case 0x02: + _at, lim = memtype(s) + return wast.memory_desc(lim) + case 0x03: + gt = globaltype(s) + return wast.global_desc(gt) + case 0x04: + # jt = tagtype(s) + raise WasmParseError('externtype tagtype variant is not yet supported.') + case d: + raise WasmParseError(f'Invalid externtype descriptor. Expected [0x00-0x04], got: {d}') + + +def heaptype(s: InputStream) -> KInner: + match read_byte(s): + case 0x6F: + return wast.HEAPTYPE_EXTERN + case 0x70: + return wast.HEAPTYPE_FUNC + case x: + raise WasmParseError(f'Unsupported heaptype descriptor: {x}') diff --git a/pykwasm/src/pykwasm/binary/utils.py b/pykwasm/src/pykwasm/binary/utils.py new file mode 100644 index 000000000..4efa5d44d --- /dev/null +++ b/pykwasm/src/pykwasm/binary/utils.py @@ -0,0 +1,53 @@ +from typing import BinaryIO, Callable, TypeAlias, TypeVar + +InputStream: TypeAlias = BinaryIO + +A = TypeVar('A') +Parser: TypeAlias = Callable[[InputStream], A] +NParser: TypeAlias = Callable[[int, InputStream], A] + + +def read_bytes(n: int, s: InputStream) -> bytes: + b = s.read(n) + if len(b) != n: + raise WasmEOFError(f'Unexpected EOF. Expected {n} bytes, got {len(b)}') + return b + + +def read_byte(s: InputStream) -> int: + return read_bytes(1, s)[0] + + +def peek_bytes(n: int, s: InputStream) -> bytes: + pos = s.tell() + b = s.read(n) + if len(b) != n: + raise WasmEOFError(f'Unexpected EOF. Expected {n} bytes, got {len(b)}') + reset(pos, s) + return b + + +def peek_byte(s: InputStream) -> int: + return peek_bytes(1, s)[0] + + +def skip(n: int, s: InputStream) -> None: + s.seek(n, 1) + + +def reset(n: int, s: InputStream) -> None: + s.seek(n, 0) + + +def expect_bytes(expected: bytes, s: InputStream) -> None: + bs = read_bytes(len(expected), s) + if bs != expected: + raise WasmParseError(f'Expected {expected!r}, got {bs!r}') + + +class WasmParseError(Exception): + pass + + +class WasmEOFError(WasmParseError): + pass diff --git a/pykwasm/src/pykwasm/kdist/wasm-semantics/test.md b/pykwasm/src/pykwasm/kdist/wasm-semantics/test.md index 06266dac5..26f726dc4 100644 --- a/pykwasm/src/pykwasm/kdist/wasm-semantics/test.md +++ b/pykwasm/src/pykwasm/kdist/wasm-semantics/test.md @@ -321,19 +321,19 @@ The conformance tests contain imports of the `"spectest"` module. start: .EmptyStmts , importDefns: .EmptyStmts , exports: - #export (... name: #token("\"global_i32\"" , "WasmStringToken") , index: 0 ) - #export (... name: #token("\"global_i64\"" , "WasmStringToken") , index: 1 ) - #export (... name: #token("\"global_f32\"" , "WasmStringToken") , index: 2 ) - #export (... name: #token("\"global_f64\"" , "WasmStringToken") , index: 3 ) - #export (... name: #token("\"table\"" , "WasmStringToken") , index: 0 ) - #export (... name: #token("\"memory\"" , "WasmStringToken") , index: 0 ) - #export (... name: #token("\"print\"" , "WasmStringToken") , index: 0 ) - #export (... name: #token("\"print_i32\"" , "WasmStringToken") , index: 1 ) - #export (... name: #token("\"print_i64\"" , "WasmStringToken") , index: 2 ) - #export (... name: #token("\"print_f32\"" , "WasmStringToken") , index: 3 ) - #export (... name: #token("\"print_f64\"" , "WasmStringToken") , index: 4 ) - #export (... name: #token("\"print_i32_f32\"" , "WasmStringToken") , index: 5 ) - #export (... name: #token("\"print_f64_f64\"" , "WasmStringToken") , index: 6 ) + #export (... name: #token("\"global_i32\"" , "WasmStringToken") , index: #externIdxGlobal(0) ) + #export (... name: #token("\"global_i64\"" , "WasmStringToken") , index: #externIdxGlobal(1) ) + #export (... name: #token("\"global_f32\"" , "WasmStringToken") , index: #externIdxGlobal(2) ) + #export (... name: #token("\"global_f64\"" , "WasmStringToken") , index: #externIdxGlobal(3) ) + #export (... name: #token("\"table\"" , "WasmStringToken") , index: #externIdxTable(0) ) + #export (... name: #token("\"memory\"" , "WasmStringToken") , index: #externIdxMemory(0) ) + #export (... name: #token("\"print\"" , "WasmStringToken") , index: #externIdxFunc(0) ) + #export (... name: #token("\"print_i32\"" , "WasmStringToken") , index: #externIdxFunc(1) ) + #export (... name: #token("\"print_i64\"" , "WasmStringToken") , index: #externIdxFunc(2) ) + #export (... name: #token("\"print_f32\"" , "WasmStringToken") , index: #externIdxFunc(3) ) + #export (... name: #token("\"print_f64\"" , "WasmStringToken") , index: #externIdxFunc(4) ) + #export (... name: #token("\"print_i32_f32\"" , "WasmStringToken") , index: #externIdxFunc(5) ) + #export (... name: #token("\"print_f64_f64\"" , "WasmStringToken") , index: #externIdxFunc(6) ) .EmptyStmts , metadata: #meta (... id: , funcIds: .Map , filename: .String ) ) diff --git a/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm-text.md b/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm-text.md index 4b4b3e233..3c259aad5 100644 --- a/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm-text.md +++ b/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm-text.md @@ -998,19 +998,19 @@ Wasm currently supports only one table, so we do not need to resolve any identif Wasm currently supports only one memory, so we do not need to resolve any identifiers. ```k - rule #t2aDefn(( data _:Index (offset IS) DS )) => #data(0, #t2aInstrs(IS), #DS2Bytes(DS)) + rule #t2aDefn(( data _:Index (offset IS) DS )) => #data(#DS2Bytes(DS), #active(0, #t2aInstrs(IS))) ``` #### Exports ```k - rule #t2aDefn(( export ENAME ( func IDENT:Identifier ) )) => #export(ENAME, {IDS[IDENT]}:>Int) requires IDENT in_keys(IDS) - rule #t2aDefn(( export ENAME ( global IDENT:Identifier ) )) => #export(ENAME, {IDS[IDENT]}:>Int) requires IDENT in_keys(IDS) - rule #t2aDefn<_>(( export ENAME ( func I:Int ) )) => #export(ENAME, I) - rule #t2aDefn<_>(( export ENAME ( global I:Int ) )) => #export(ENAME, I) + rule #t2aDefn(( export ENAME ( func IDENT:Identifier ) )) => #export(ENAME, #externIdxFunc({IDS[IDENT]}:>Int)) requires IDENT in_keys(IDS) + rule #t2aDefn(( export ENAME ( global IDENT:Identifier ) )) => #export(ENAME, #externIdxGlobal({IDS[IDENT]}:>Int)) requires IDENT in_keys(IDS) + rule #t2aDefn<_>(( export ENAME ( func I:Int ) )) => #export(ENAME, #externIdxFunc(I)) + rule #t2aDefn<_>(( export ENAME ( global I:Int ) )) => #export(ENAME, #externIdxGlobal(I)) - rule #t2aDefn<_>(( export ENAME ( table _ ) )) => #export(ENAME, 0) - rule #t2aDefn<_>(( export ENAME ( memory _ ) )) => #export(ENAME, 0) + rule #t2aDefn<_>(( export ENAME ( table _ ) )) => #export(ENAME, #externIdxTable(0)) + rule #t2aDefn<_>(( export ENAME ( memory _ ) )) => #export(ENAME, #externIdxMemory(0)) ``` #### Other Definitions diff --git a/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm.md b/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm.md index 8470abaff..b6270c398 100644 --- a/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm.md +++ b/pykwasm/src/pykwasm/kdist/wasm-semantics/wasm.md @@ -1766,11 +1766,14 @@ Memories can be initialized with data, specified as a list of bytes together wit The `data` initializer simply puts these bytes into the specified memory, starting at the offset. ```k - syntax DataDefn ::= #data(index : Int, offset : Instrs, data : Bytes) [symbol(aDataDefn)] - | "data" "{" Int Bytes "}" + syntax DataDefn ::= #data(data: Bytes, mode: DataMode) [symbol(aDataDefn)] + syntax DataMode ::= #active(memidx: Int, offset: Instrs) [symbol(aDataModeActive)] + | "#passive" [symbol(aDataModePassive)] // -------------------------------------------- - // Default to memory 0. - rule #data(IDX, IS, DATA) => sequenceInstrs(IS) ~> data { IDX DATA } ... + syntax KItem ::= "data" "{" Int Bytes "}" + + rule #data(DATA, #active(IDX, IS)) => sequenceInstrs(IS) ~> data { IDX DATA } ... + rule #data(_DATA, #passive) => .K ... rule data { MEMIDX DSBYTES } => trap ... < i32 > OFFSET : _STACK @@ -1787,7 +1790,6 @@ The `data` initializer simply puts these bytes into the specified memory, starti requires OFFSET +Int lengthBytes(DSBYTES) >Int SIZE *Int #pageSize() - // For now, deal only with memory 0. rule data { MEMIDX DSBYTES } => .K ... < i32 > OFFSET : STACK => STACK CUR @@ -1832,7 +1834,12 @@ Export Exports make functions, tables, memories and globals available for importing into other modules. ```k - syntax ExportDefn ::= #export(name : WasmString, index : Int) [symbol(aExportDefn)] + syntax ExportDefn ::= #export(name : WasmString, index : ExternIdx) [symbol(aExportDefn)] + syntax ExternIdx ::= #externIdxFunc(Int) [symbol(aExternIdxFunc)] + | #externIdxTable(Int) [symbol(aExternIdxTable)] + | #externIdxMemory(Int) [symbol(aExternIdxMemory)] + | #externIdxGlobal(Int) [symbol(aExternIdxGlobal)] + | #externIdxTag(Int) [symbol(aExternIdxTag)] syntax Alloc ::= ExportDefn // --------------------------- rule #export(ENAME, IDX) => .K ... diff --git a/pykwasm/src/pykwasm/kwasm_ast.py b/pykwasm/src/pykwasm/kwasm_ast.py index a818e0a35..a0480e8a8 100644 --- a/pykwasm/src/pykwasm/kwasm_ast.py +++ b/pykwasm/src/pykwasm/kwasm_ast.py @@ -53,6 +53,8 @@ ELEM = 'aElemDefn' DATA = 'aDataDefn' +DATAMODE_ACTIVE = 'aDataModeActive' +DATAMODE_PASSIVE = 'aDataModePassive' START = 'aStartDefn' @@ -164,6 +166,9 @@ def idx_to_ref(idx: int | None) -> KInner: MUT_CONST = KApply('mutConst', []) MUT_VAR = KApply('mutVar', []) +HEAPTYPE_FUNC = KApply('func', []) +HEAPTYPE_EXTERN = KApply('extern', []) + def vec_type(valtypes: KInner) -> KInner: return KApply(VEC_TYPE, [valtypes]) @@ -173,7 +178,7 @@ def func_type(params: KInner, results: KInner) -> KInner: return KApply(FUNC_TYPE, [params, results]) -def limits(tup: tuple[int, int]) -> KInner: +def limits(tup: tuple[int, int | None]) -> KInner: i = tup[0] j = tup[1] if j is None: @@ -228,9 +233,9 @@ def CALL(function_idx: int) -> KInner: return KApply('aCall', [KInt(function_idx)]) -def CALL_INDIRECT(type_idx: int) -> KInner: +def CALL_INDIRECT(table_idx: int, type_idx: int) -> KInner: type_use = KApply('aTypeUseIndex', [KInt(type_idx)]) - return KApply('aCall_indirect', [KInt(0), type_use]) + return KApply('aCall_indirect', [KInt(table_idx), type_use]) ########################## @@ -245,8 +250,8 @@ def REF_FUNC(func_idx: int) -> KInner: REF_IS_NULL: KInner = KApply('aRef.is_null', []) -def REF_NULL(t: str) -> KInner: - return KApply('aRef.null', [KApply(t, [])]) +def REF_NULL(heaptype: KInner) -> KInner: + return KApply('aRef.null', [heaptype]) ########################## @@ -626,11 +631,11 @@ def global_desc(global_type: KInner, id: KInner = EMPTY_ID) -> KInner: return KApply(GLOBAL_DESC, [id, global_type]) -def table_desc(lim: tuple[int, int], id: KInner = EMPTY_ID) -> KInner: +def table_desc(lim: tuple[int, int | None], id: KInner = EMPTY_ID) -> KInner: return KApply(TABLE_DESC, [id, limits(lim)]) -def memory_desc(lim: tuple[int, int], id: KInner = EMPTY_ID) -> KInner: +def memory_desc(lim: tuple[int, int | None], id: KInner = EMPTY_ID) -> KInner: return KApply(MEMORY_DESC, [id, limits(lim)]) @@ -647,11 +652,11 @@ def func(type: KInner, locals: KInner, body: KInner, metadata: KInner = EMPTY_FU return KApply(FUNC, [type, locals, body, metadata]) -def table(lim: tuple[int, int], typ: KInner, metadata: KInner = EMPTY_ID) -> KInner: +def table(lim: tuple[int, int | None], typ: KInner, metadata: KInner = EMPTY_ID) -> KInner: return KApply(TABLE, [limits(lim), typ, metadata]) -def memory(lim: tuple[int, int], metadata: KInner = EMPTY_ID) -> KInner: +def memory(lim: tuple[int, int | None], metadata: KInner = EMPTY_ID) -> KInner: return KApply(MEMORY, [limits(lim), metadata]) @@ -675,8 +680,16 @@ def elem(typ: KInner, elem_mode: KInner, init: Iterable[int | None], metadata: K return KApply(ELEM, [typ, refs(init), elem_mode, metadata]) -def data(memory_idx: int, offset: KInner, data: bytes) -> KInner: - return KApply(DATA, [KInt(memory_idx), offset, KBytes(data)]) +def data(bs: bytes, mode: KInner) -> KInner: + return KApply(DATA, [KBytes(bs), mode]) + + +def datamode_active(memidx: int, offset: KInner) -> KInner: + return KApply(DATAMODE_ACTIVE, [KInt(memidx), offset]) + + +def datamode_passive() -> KInner: + return KApply(DATAMODE_PASSIVE, []) def start(start_idx: int) -> KInner: @@ -687,8 +700,28 @@ def imp(mod_name: KInner, name: KInner, import_desc: KInner) -> KInner: return KApply(IMPORT, [mod_name, name, import_desc]) -def export(name: KInner, index: int) -> KInner: - return KApply(EXPORT, [name, KInt(index)]) +def export(name: KInner, index: KInner) -> KInner: + return KApply(EXPORT, [name, index]) + + +def externidx_func(i: int) -> KInner: + return KApply('aExternIdxFunc', [KInt(i)]) + + +def externidx_table(i: int) -> KInner: + return KApply('aExternIdxTable', [KInt(i)]) + + +def externidx_memory(i: int) -> KInner: + return KApply('aExternIdxMemory', [KInt(i)]) + + +def externidx_global(i: int) -> KInner: + return KApply('aExternIdxGlobal', [KInt(i)]) + + +def externidx_tag(i: int) -> KInner: + return KApply('aExternIdxTag', [KInt(i)]) def module_metadata(mid: None = None, fids: None = None, filename: str | None = None) -> KInner: diff --git a/pykwasm/src/pykwasm/wasm2kast.py b/pykwasm/src/pykwasm/wasm2kast.py index 91cea1ff6..d899f346e 100644 --- a/pykwasm/src/pykwasm/wasm2kast.py +++ b/pykwasm/src/pykwasm/wasm2kast.py @@ -10,7 +10,19 @@ from typing import TYPE_CHECKING from wasm import instructions -from wasm.datatypes import GlobalType, MemoryType, Mutability, TableType, TypeIdx, ValType, addresses +from wasm.datatypes import ( + FunctionIdx, + GlobalIdx, + GlobalType, + MemoryIdx, + MemoryType, + Mutability, + TableIdx, + TableType, + TypeIdx, + ValType, + addresses, +) from wasm.datatypes.element_segment import ElemModeActive, ElemModeDeclarative, ElemModePassive from wasm.opcodes import BinaryOpcode from wasm.parsers import parse_module @@ -162,7 +174,8 @@ def elem(e: ElementSegment): def data(d: DataSegment): offset = instrs(d.offset) - return a.data(d.memory_idx, offset, d.init) + mode = a.datamode_active(d.memory_idx, offset) + return a.data(d.init, mode) def start(s: StartFunction): @@ -188,7 +201,16 @@ def imp(i: Import): def export(e: Export): name = a.wasm_string(e.name) - idx = e.desc + if isinstance(e.desc, FunctionIdx): + idx = a.externidx_func(e.desc) + elif isinstance(e.desc, GlobalIdx): + idx = a.externidx_global(e.desc) + elif isinstance(e.desc, MemoryIdx): + idx = a.externidx_memory(e.desc) + elif isinstance(e.desc, TableIdx): + idx = a.externidx_table(e.desc) + else: + raise ValueError(f'Invalid extern index: {e.desc}') return a.export(name, idx) @@ -227,7 +249,7 @@ def instr(i): if i.opcode == B.CALL: return a.CALL(i.function_idx) if i.opcode == B.CALL_INDIRECT: - return a.CALL_INDIRECT(i.type_idx) + return a.CALL_INDIRECT(i.table_idx, i.type_idx) if i.opcode == B.ELSE: raise (ValueError('ELSE opcode: should have been filtered out.')) if i.opcode == B.END: @@ -321,9 +343,9 @@ def instr(i): return a.REF_FUNC(i.funcidx) if isinstance(i, instructions.RefNull): if i.reftype is addresses.FunctionAddress: - return a.REF_NULL('func') + return a.REF_NULL(a.HEAPTYPE_FUNC) if i.reftype is addresses.ExternAddress: - return a.REF_NULL('extern') + return a.REF_NULL(a.HEAPTYPE_EXTERN) raise ValueError(f'Unknown heap type: {i}, {i.reftype}') if isinstance(i, instructions.TableGet): return a.TABLE_GET(i.tableidx) diff --git a/pykwasm/src/tests/integration/test_binary_parser.py b/pykwasm/src/tests/integration/test_binary_parser.py index d0c76404f..20dd00a2e 100644 --- a/pykwasm/src/tests/integration/test_binary_parser.py +++ b/pykwasm/src/tests/integration/test_binary_parser.py @@ -7,18 +7,21 @@ from typing import TYPE_CHECKING import pytest -from pyk.kast.inner import KSequence, KSort, Subst +from pyk.kast.inner import KApply, KLabel, KSequence, KSort, Subst from pyk.kast.manip import split_config_from +from pyk.kast.prelude.utils import token +from pykwasm.binary.module import parse_module from pykwasm.wasm2kast import wasm2kast if TYPE_CHECKING: from pyk.kast import KInner + from pyk.kore.syntax import Pattern from pyk.ktool.krun import KRun BINARY_DIR = Path(__file__).parent / 'binary' -BINARY_WAT_FILES = BINARY_DIR.glob('*.wat') +BINARY_WAT_FILES = list(BINARY_DIR.glob('*.wat')) sys.setrecursionlimit(1500000000) @@ -39,7 +42,49 @@ def test_wasm2kast(krun_llvm: KRun, wat_path: Path) -> None: run_module(krun_llvm, module) -def run_module(krun: KRun, parsed_module: KInner) -> None: +@pytest.mark.parametrize('wat_path', BINARY_WAT_FILES, ids=str) +def test_self_binary_parser(krun_llvm: KRun, wat_path: Path) -> None: + # Given + wat2wasm_cmd = ['wat2wasm', str(wat_path), '--output=/dev/stdout'] + proc_res = run(wat2wasm_cmd, check=True, capture_output=True) + wasm_file = BytesIO(proc_res.stdout) + + assert not proc_res.returncode + + # When + module = parse_module(wasm_file) + + # Then + + # Can convert to Kore successfully + krun_llvm.kast_to_kore(module, KSort('ModuleDecl')) + + # Can run (initiate) successfully + run_module(krun_llvm, module) + + +@pytest.mark.parametrize('wat_path', BINARY_WAT_FILES, ids=str) +def test_diff(krun_llvm: KRun, wat_path: Path) -> None: + # Given + wat2wasm_cmd = ['wat2wasm', str(wat_path), '--output=/dev/stdout'] + proc_res = run(wat2wasm_cmd, check=True, capture_output=True) + + assert not proc_res.returncode + + # When + module_wasm2kast = remove_block_ids(wasm2kast(BytesIO(proc_res.stdout))) + + module_self = parse_module(BytesIO(proc_res.stdout)) + + # Then + assert module_wasm2kast == module_self + + run_wasm2kast = run_module(krun_llvm, module_wasm2kast) + run_self = run_module(krun_llvm, module_self) + assert run_wasm2kast == run_self + + +def run_module(krun: KRun, parsed_module: KInner) -> Pattern: try: # Create an initial config config_kast = krun.definition.init_config(KSort('GeneratedTopCell')) @@ -53,7 +98,23 @@ def run_module(krun: KRun, parsed_module: KInner) -> None: config_kore = krun.kast_to_kore(config_with_module, KSort('GeneratedTopCell')) # Run the config - krun.run_pattern(config_kore) + return krun.run_pattern(config_kore) except Exception as e: raise Exception('Received error while running') from e + + +def remove_block_ids(k: KInner) -> KInner: + match k: + + case KApply(KLabel('aBlock'), (vec_type, instrs, _)): + return KApply('aBlock', (vec_type, remove_block_ids(instrs), token(0))) + + case KApply(KLabel('aIf'), (vec_type, then_instrs, else_instrs, _)): + return KApply('aIf', (vec_type, remove_block_ids(then_instrs), remove_block_ids(else_instrs), token(0))) + + case KApply(KLabel('aLoop'), (vec_type, instrs, _)): + return KApply(KLabel('aLoop'), (vec_type, remove_block_ids(instrs), token(0))) + + case _: + return k.map_inner(remove_block_ids) diff --git a/pykwasm/src/tests/unit/test_binary_parser.py b/pykwasm/src/tests/unit/test_binary_parser.py new file mode 100644 index 000000000..f1f469356 --- /dev/null +++ b/pykwasm/src/tests/unit/test_binary_parser.py @@ -0,0 +1,86 @@ +import io +import struct +import pytest + +from pykwasm.binary.utils import WasmParseError, WasmEOFError +from pykwasm.binary import floats, integers + +def stream(data: bytes) -> io.BytesIO: + """Helper: wrap bytes in a seekable stream.""" + return io.BytesIO(data) + + +class TestFloats: + VALUES = [0.0, 3.14, -1.5, 1.23456789, -9.99, float('inf'), float('-inf')] + + @pytest.mark.parametrize("value", VALUES) + def test_f32(self, value): + encoded = struct.pack(' bytes: + buf = [] + while True: + b = value & 0x7F + value >>= 7 + if value: + buf.append(b | 0x80) + else: + buf.append(b) + break + return bytes(buf) + + @staticmethod + def encode_sleb128(value: int) -> bytes: + buf = [] + while True: + b = value & 0x7F + value >>= 7 + if (value == 0 and b & 0x40 == 0) or (value == -1 and b & 0x40 != 0): + buf.append(b) + break + buf.append(b | 0x80) + return bytes(buf) + + @pytest.mark.parametrize("value", U32_VALUES) + def test_unsigned_32(self, value): + encoded = self.encode_uleb128(value) + assert integers.u32(stream(encoded)) == value + + @pytest.mark.parametrize("value", U64_VALUES) + def test_unsigned_64(self, value): + encoded = self.encode_uleb128(value) + assert integers.u64(stream(encoded)) == value + + @pytest.mark.parametrize("value", I32_VALUES) + def test_uninterpreted_32(self, value): + expected = integers.to_uninterpreted(32, value) + encoded = self.encode_sleb128(value) + assert integers.i32(stream(encoded)) == expected + + @pytest.mark.parametrize("value", I64_VALUES) + def test_uninterpreted_64(self, value): + expected = integers.to_uninterpreted(64, value) + encoded = self.encode_sleb128(value) + assert integers.i64(stream(encoded)) == expected \ No newline at end of file