Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 18 additions & 25 deletions xdis/unmarsh_graal.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def graal_readBigInteger(self):
Reads a marshaled big integer from the input stream.
"""
negative = False
sz = self.graal_readInt() # Get the size in shorts
sz = self.read_uint32() # Get the size in shorts
if sz < 0:
negative = True
sz = -sz
Expand Down Expand Up @@ -277,7 +277,7 @@ def graal_readBooleanArray(self) -> tuple[bool, ...]:
Python equivalent of Python Graal's readBooleanArray() from
MarshalModuleBuiltins.java
"""
length: int = int(unpack("<i", self.fp.read(4))[0])
length: int = self.read_uint32()
return tuple([bool(self.graal_readByte()) for _ in range(length)])

def graal_readByte(self) -> int:
Expand All @@ -292,7 +292,7 @@ def graal_readBytes(self) -> bytes:
Python equivalent of Python Graal's readBytes() from
MarshalModuleBuiltins.java
"""
length: int = unpack("<i", self.fp.read(4))[0]
length: int = self.read_uint32()
return bytes([self.graal_readByte() for _ in range(length)])

def graal_readDouble(self) -> float:
Expand All @@ -307,23 +307,16 @@ def graal_readDoubleArray(self) -> tuple[float, ...]:
Python equivalent of Python Graal's readDoubleArray() from
MarshalModuleBuiltins.java
"""
length: int = int(unpack("<i", self.fp.read(4))[0])
length: int = self.read_uint32()
return tuple([self.graal_readDouble() for _ in range(length)])

def graal_readInt(self) -> int:
    """
    Read one little-endian signed 32-bit integer from ``self.fp``.

    Python equivalent of Python Graal's readInt() from
    MarshalModuleBuiltins.java
    """
    raw = self.fp.read(4)
    (value,) = unpack("<i", raw)
    return int(value)

def graal_readIntArray(self) -> tuple[int, ...]:
    """
    Read a length-prefixed array of signed 32-bit integers.

    Python equivalent of Python Graal's readIntArray() from
    MarshalModuleBuiltins.java

    The stream holds a signed 32-bit element count followed by that
    many signed 32-bit values.  A non-positive count yields an empty
    tuple.
    """
    # The count is read as a signed value, as before; only the helper
    # used to read it changes, so it matches the element reads below.
    length: int = self.read_int32()
    return tuple(self.read_int32() for _ in range(length))

def graal_readLong(self) -> int:
"""
Expand Down Expand Up @@ -370,18 +363,18 @@ def graal_readStringArray(self) -> tuple[str, ...]:
Python equivalent of Python Graal's readObjectArray() from
MarshalModuleBuiltins.java
"""
length: int = self.graal_readInt()
length: int = self.read_uint32()
return tuple([self.graal_readString() for _ in range(length)])

def graal_readSparseTable(self) -> Dict[int, tuple]:
    """
    Read a sparse table of int arrays, modelled on the reader in
    Python Graal's MarshalModuleBuiltins.java.

    The stream holds a 32-bit length, then a sequence of records,
    each a signed 32-bit index followed by an int array (read with
    ``graal_readIntArray``).  An index of -1 terminates the sequence.

    Returns a dict mapping each index to its int array.
    """
    # The length must be consumed to stay in sync with the stream,
    # but its value isn't used: the -1 sentinel ends the table.
    self.read_uint32()
    table = {}  # new int[length][];
    while (index := self.read_int32()) != -1:
        table[index] = self.graal_readIntArray()
    return table
Expand Down Expand Up @@ -494,15 +487,15 @@ def t_graal_CodeUnit(self, save_ref, bytes_for_s: bool = False):

co_name = self.graal_readString()
co_qualname = self.graal_readString()
co_argcount = self.graal_readInt()
co_kwonlyargcount = self.graal_readInt()
co_posonlyargcount = self.graal_readInt()
co_argcount = self.read_uint32()
co_kwonlyargcount = self.read_uint32()
co_posonlyargcount = self.read_uint32()

co_stacksize = self.graal_readInt()
co_stacksize = self.read_uint32()
co_code_offset_in_file = self.fp.tell()
co_code = self.graal_readBytes()
other_fields["srcOffsetTable"] = self.graal_readBytes()
co_flags = self.graal_readInt()
co_flags = self.read_uint32()

# writeStringArray(code.names);
# writeStringArray(code.varnames);
Expand Down Expand Up @@ -542,11 +535,11 @@ def t_graal_CodeUnit(self, save_ref, bytes_for_s: bool = False):

other_fields["primitiveConstants"] = self.graal_readLongArray()
other_fields["exception_handler_ranges"] = self.graal_readIntArray()
other_fields["condition_profileCount"] = self.graal_readInt()
other_fields["startLine"] = self.graal_readInt()
other_fields["startColumn"] = self.graal_readInt()
other_fields["endLine"] = self.graal_readInt()
other_fields["endColumn"] = self.graal_readInt()
other_fields["condition_profileCount"] = self.read_uint32()
other_fields["startLine"] = self.read_uint32()
other_fields["startColumn"] = self.read_uint32()
other_fields["endLine"] = self.read_uint32()
other_fields["endColumn"] = self.read_uint32()
other_fields["outputCanQuicken"] = self.graal_readBytes()
other_fields["variableShouldUnbox"] = self.graal_readBytes()
other_fields["generalizeInputsMap"] = self.graal_readSparseTable()
Expand Down
13 changes: 0 additions & 13 deletions xdis/unmarsh_rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
object.
"""

from struct import unpack
from typing import Any, Dict, List, Tuple, Union

from xdis.codetype.code313rust import Code313Rust, SourceLocation
Expand Down Expand Up @@ -277,18 +276,6 @@ def t_bigint(self, save_ref: bool=False, bytes_for_s: bool=False):
value = int.from_bytes(byte_data, byteorder='little')
return value if is_positive else -value

def read_int16(self):
    """
    Read one little-endian signed 16-bit integer from ``self.fp``.
    """
    (value,) = unpack("<h", self.fp.read(2))
    return int(value)

def read_int32(self):
    """
    Read one little-endian signed 32-bit integer from ``self.fp``.
    """
    (value,) = unpack("<i", self.fp.read(4))
    return int(value)

def read_slice(self, n: int) -> bytes:
    """
    Return the result of reading ``n`` bytes from ``self.fp``.
    """
    chunk = self.fp.read(n)
    return chunk

def read_uint32(self):
    """
    Read one little-endian unsigned 32-bit integer from ``self.fp``.
    """
    (value,) = unpack("<I", self.fp.read(4))
    return int(value)

def read_string(self, n: int, bytes_for_s: bool=False) -> Union[bytes, str]:
s = self.read_slice(n)
if not bytes_for_s:
Expand Down
69 changes: 42 additions & 27 deletions xdis/unmarshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ def __init__(self, fp, magic_int, bytes_for_s, code_objects={}) -> None:

self.UNMARSHAL_DISPATCH_TABLE = UNMARSHAL_DISPATCH_TABLE

def read_int16(self) -> int:
    """
    Read one little-endian signed 16-bit integer from ``self.fp``.
    """
    (value,) = unpack("<h", self.fp.read(2))
    return int(value)

def read_int32(self) -> int:
    """
    Read one little-endian signed 32-bit integer from ``self.fp``.
    """
    (value,) = unpack("<i", self.fp.read(4))
    return int(value)

def read_int64(self) -> int:
    """
    Read one little-endian signed 64-bit integer from ``self.fp``.
    """
    (value,) = unpack("<q", self.fp.read(8))
    return int(value)

def read_slice(self, n: int) -> bytes:
    """
    Return the result of reading ``n`` bytes from ``self.fp``.
    """
    chunk = self.fp.read(n)
    return chunk

def read_uint32(self) -> int:
    """
    Read one little-endian unsigned 32-bit integer from ``self.fp``.
    """
    (value,) = unpack("<I", self.fp.read(4))
    return int(value)

def load(self):
"""
``marshal.load()`` written in Python. When the Python bytecode magic loaded is the
Expand Down Expand Up @@ -302,16 +317,16 @@ def t_True(self, save_ref, bytes_for_s: bool = False) -> bool:
return True

def t_int32(self, save_ref, bytes_for_s: bool = False):
    """
    Read a signed 32-bit integer and pass it through ``self.r_ref``.

    ``save_ref`` is forwarded to ``self.r_ref`` unchanged;
    ``bytes_for_s`` is accepted but not used here.
    """
    value = self.read_int32()
    return self.r_ref(value, save_ref)

def t_long(self, save_ref, bytes_for_s: bool = False):
n = unpack("<i", self.fp.read(4))[0]
n = self.read_uint32()
if n == 0:
return long(0)
size = abs(n)
d = long(0)
for j in range(0, size):
md = int(unpack("<h", self.fp.read(2))[0])
md = self.read_int16()
# This operation can turn "d" from a long back
# into an int.
d += md << j * 15
Expand All @@ -323,7 +338,7 @@ def t_long(self, save_ref, bytes_for_s: bool = False):

# Python 3.4 removed this.
def t_int64(self, save_ref, bytes_for_s: bool = False):
    """
    Read a signed 64-bit integer from the stream and return it.

    When ``save_ref`` is true the value is also appended to
    ``self.intern_objects``.  ``bytes_for_s`` is accepted but not
    used here.
    """
    # Read exactly once: each read consumes 8 bytes of the stream.
    obj = self.read_int64()
    if save_ref:
        self.intern_objects.append(obj)
    return obj
Expand All @@ -342,7 +357,7 @@ def unpack_pre_24() -> float:
return float(self.fp.read(unpack("B", self.fp.read(1))[0]))

def unpack_newer() -> float:
return float(self.fp.read(unpack("<i", self.fp.read(4))[0]))
return float(self.fp.read(self.read_int32()))

get_float = unpack_pre_24 if self.magic_int <= 62061 else unpack_newer

Expand All @@ -363,7 +378,7 @@ def t_string(self, save_ref, bytes_for_s: bool):
In Python3, this is a ``bytes`` type. In Python2, it is a string type;
``bytes_for_s`` is True when a Python 3 interpreter is reading Python 2 bytecode.
"""
strsize = unpack("<i", self.fp.read(4))[0]
strsize = self.read_uint32()
s = self.fp.read(strsize)
if not bytes_for_s:
s = compat_str(s)
Expand All @@ -377,7 +392,7 @@ def t_ASCII_interned(self, save_ref, bytes_for_s: bool = False):
the string.
"""
# FIXME: check
strsize = unpack("<i", self.fp.read(4))[0]
strsize = self.read_uint32()
interned = compat_str(self.fp.read(strsize))
self.intern_strings.append(interned)
return self.r_ref(interned, save_ref)
Expand All @@ -388,7 +403,7 @@ def t_ASCII(self, save_ref, bytes_for_s: bool = False):
There are true strings in Python3 as opposed to
bytes.
"""
strsize = unpack("<i", self.fp.read(4))[0]
strsize = self.read_uint32()
s = self.fp.read(strsize)
s = compat_str(s)
return self.r_ref(s, save_ref)
Expand All @@ -407,13 +422,13 @@ def t_short_ASCII_interned(self, save_ref, bytes_for_s: bool = False):
return self.r_ref(interned, save_ref)

def t_interned(self, save_ref, bytes_for_s: bool = False):
    """
    Read a length-prefixed string and record it as interned.

    The stream holds a 32-bit size followed by that many bytes.  The
    bytes are converted with ``compat_str``, appended to
    ``self.intern_strings``, and the result of
    ``self.r_ref(interned, save_ref)`` is returned.  ``bytes_for_s``
    is accepted but not used here.
    """
    # Read the size exactly once so the payload starts at the right
    # stream offset.
    strsize = self.read_uint32()
    interned = compat_str(self.fp.read(strsize))
    self.intern_strings.append(interned)
    return self.r_ref(interned, save_ref)

def t_unicode(self, save_ref, bytes_for_s: bool = False):
strsize = unpack("<i", self.fp.read(4))[0]
strsize = self.read_uint32()
unicodestring = self.fp.read(strsize)
if self.version_triple < (3, 0):
string = UnicodeForPython3(unicodestring)
Expand All @@ -434,7 +449,7 @@ def t_small_tuple(self, save_ref, bytes_for_s: bool = False):
return self.r_ref_insert(ret, i)

def t_tuple(self, save_ref, bytes_for_s: bool = False):
tuplesize = unpack("<i", self.fp.read(4))[0]
tuplesize = self.read_uint32()
ret = self.r_ref(tuple(), save_ref)
while tuplesize > 0:
ret += (self.r_object(bytes_for_s=bytes_for_s),)
Expand All @@ -443,15 +458,15 @@ def t_tuple(self, save_ref, bytes_for_s: bool = False):

def t_list(self, save_ref, bytes_for_s: bool = False):
    """
    Read a list: a 32-bit element count followed by that many
    marshaled objects.

    A fresh list is first passed through ``self.r_ref`` with
    ``save_ref``, then filled by repeated ``self.r_object`` calls,
    each forwarded ``bytes_for_s``.
    """
    # FIXME: check me
    n = self.read_uint32()
    ret = self.r_ref(list(), save_ref)
    for _ in range(n):
        ret.append(self.r_object(bytes_for_s=bytes_for_s))
    return ret

def t_frozenset(self, save_ref, bytes_for_s: bool = False):
setsize = unpack("<i", self.fp.read(4))[0]
setsize = self.read_uint32()
collection, i = self.r_ref_reserve([], save_ref)
while setsize > 0:
collection.append(self.r_object(bytes_for_s=bytes_for_s))
Expand All @@ -462,7 +477,7 @@ def t_frozenset(self, save_ref, bytes_for_s: bool = False):
return self.r_ref_insert(final_frozenset, i)

def t_set(self, save_ref, bytes_for_s: bool = False):
setsize = unpack("<i", self.fp.read(4))[0]
setsize = self.read_uint32()
ret, i = self.r_ref_reserve(tuple(), save_ref)
while setsize > 0:
ret += (self.r_object(bytes_for_s=bytes_for_s),)
Expand All @@ -484,7 +499,7 @@ def t_dict(self, save_ref, bytes_for_s: bool = False):
return ret

def t_python2_string_reference(self, save_ref, bytes_for_s: bool = False):
    """
    Read a 32-bit index and return the entry of
    ``self.intern_strings`` at that position.

    ``save_ref`` and ``bytes_for_s`` are accepted but not used here.
    Raises IndexError when the index is outside the table.
    """
    # Read the index exactly once; each read consumes 4 bytes.
    refnum = self.read_uint32()
    return self.intern_strings[refnum]

def t_slice(self, save_ref, bytes_for_s: bool = False):
Expand Down Expand Up @@ -522,9 +537,9 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
self.version_triple = magic_int2tuple(self.magic_int)

if self.version_triple >= (2, 3):
co_argcount = unpack("<i", self.fp.read(4))[0]
co_argcount = self.read_uint32()
elif self.version_triple >= (1, 3):
co_argcount = unpack("<h", self.fp.read(2))[0]
co_argcount = self.read_int16()
else:
co_argcount = 0

Expand All @@ -538,7 +553,7 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
co_posonlyargcount = None

if self.version_triple >= (3, 0):
kwonlyargcount = unpack("<i", self.fp.read(4))[0]
kwonlyargcount = self.read_uint32()
else:
kwonlyargcount = 0

Expand All @@ -547,21 +562,21 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
self.version_triple[:2] == (3, 11) and self.is_pypy
):
if self.version_triple >= (2, 3):
co_nlocals = unpack("<i", self.fp.read(4))[0]
co_nlocals = self.read_uint32()
elif self.version_triple >= (1, 3):
co_nlocals = unpack("<h", self.fp.read(2))[0]
co_nlocals = self.read_int16()

if self.version_triple >= (2, 3):
co_stacksize = unpack("<i", self.fp.read(4))[0]
co_stacksize = self.read_uint32()
elif self.version_triple >= (1, 5):
co_stacksize = unpack("<h", self.fp.read(2))[0]
co_stacksize = self.read_int16()
else:
co_stacksize = 0

if self.version_triple >= (2, 3):
co_flags = unpack("<i", self.fp.read(4))[0]
co_flags = self.read_uint32()
elif self.version_triple >= (1, 3):
co_flags = unpack("<h", self.fp.read(2))[0]
co_flags = self.read_int16()
else:
co_flags = 0

Expand Down Expand Up @@ -628,9 +643,9 @@ def t_code(self, save_ref, bytes_for_s: bool = False):

if self.version_triple >= (1, 5):
if self.version_triple >= (2, 3):
co_firstlineno = unpack("<i", self.fp.read(4))[0]
co_firstlineno = self.read_int32()
else:
co_firstlineno = unpack("<h", self.fp.read(2))[0]
co_firstlineno = self.read_int16()

if self.version_triple >= (3, 11) and not self.is_pypy:
co_linetable = self.r_object(bytes_for_s=bytes_for_s)
Expand Down Expand Up @@ -775,7 +790,7 @@ def t_code_old(self, _, bytes_for_s: bool = False):

# Since Python 3.4
def t_object_reference(self, save_ref=None, bytes_for_s: bool = False):
    """
    Read a 32-bit index and return the entry of
    ``self.intern_objects`` at that position.

    ``save_ref`` and ``bytes_for_s`` are accepted but not used here.
    Raises IndexError when the index is outside the table.
    """
    # Read the index exactly once; each read consumes 4 bytes.
    refnum = self.read_uint32()
    return self.intern_objects[refnum]

def t_unknown(self, save_ref=None, bytes_for_s: bool = False):
Expand Down