Skip to content

Commit aa58cd2

Browse files
committed
Use enum marshal::Type instead of u8
1 parent a353d5e commit aa58cd2

File tree

1 file changed

+94
-39
lines changed

1 file changed

+94
-39
lines changed

vm/src/stdlib/marshal.rs

Lines changed: 94 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,72 @@ mod decl {
1919
use num_bigint::{BigInt, Sign};
2020
use num_traits::Zero;
2121

22-
const STR_BYTE: u8 = b's';
23-
const INT_BYTE: u8 = b'i';
24-
const FLOAT_BYTE: u8 = b'f';
25-
const TRUE_BYTE: u8 = b'T';
26-
const FALSE_BYTE: u8 = b'F';
27-
const LIST_BYTE: u8 = b'[';
28-
const TUPLE_BYTE: u8 = b'(';
29-
const DICT_BYTE: u8 = b',';
30-
const SET_BYTE: u8 = b'~';
31-
const FROZEN_SET_BYTE: u8 = b'<';
32-
const BYTE_ARRAY: u8 = b'>';
33-
const TYPE_CODE: u8 = b'c';
22+
#[repr(u8)]
23+
enum Type {
24+
// Null = b'0',
25+
// None = b'N',
26+
False = b'F',
27+
True = b'T',
28+
// StopIter = b'S',
29+
// Ellipsis = b'.',
30+
Int = b'i',
31+
Float = b'g',
32+
// Complex = b'y',
33+
// Long = b'l', // i32
34+
Bytes = b's', // = TYPE_STRING
35+
// Interned = b't',
36+
// Ref = b'r',
37+
Tuple = b'(',
38+
List = b'[',
39+
Dict = b'{',
40+
Code = b'c',
41+
Str = b'u', // = TYPE_UNICODE
42+
// Unknown = b'?',
43+
Set = b'<',
44+
FrozenSet = b'>',
45+
// Ascii = b'a',
46+
// AsciiInterned = b'A',
47+
// SmallTuple = b')',
48+
// ShortAscii = b'z',
49+
// ShortAsciiInterned = b'Z',
50+
}
51+
// const FLAG_REF: u8 = b'\x80';
52+
53+
impl TryFrom<u8> for Type {
54+
type Error = u8;
55+
fn try_from(value: u8) -> Result<Self, u8> {
56+
use Type::*;
57+
Ok(match value {
58+
// b'0' => Null,
59+
// b'N' => None,
60+
b'F' => False,
61+
b'T' => True,
62+
// b'S' => StopIter,
63+
// b'.' => Ellipsis,
64+
b'i' => Int,
65+
b'g' => Float,
66+
// b'y' => Complex,
67+
// b'l' => Long,
68+
b's' => Bytes,
69+
// b't' => Interned,
70+
// b'r' => Ref,
71+
b'(' => Tuple,
72+
b'[' => List,
73+
b'{' => Dict,
74+
b'c' => Code,
75+
b'u' => Str,
76+
// b'?' => Unknown,
77+
b'<' => Set,
78+
b'>' => FrozenSet,
79+
// b'a' => Ascii,
80+
// b'A' => AsciiInterned,
81+
// b')' => SmallTuple,
82+
// b'z' => ShortAscii,
83+
// b'Z' => ShortAsciiInterned,
84+
c => return Err(c),
85+
})
86+
}
87+
}
3488

3589
fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
3690
vm.new_exception_msg(
@@ -59,13 +113,13 @@ mod decl {
59113
pyint @ PyInt => {
60114
if pyint.class().is(vm.ctx.types.bool_type) {
61115
let typ = if pyint.as_bigint().is_zero() {
62-
FALSE_BYTE
116+
Type::False
63117
} else {
64-
TRUE_BYTE
118+
Type::True
65119
};
66-
buf.push(typ);
120+
buf.push(typ as u8);
67121
} else {
68-
buf.push(INT_BYTE);
122+
buf.push(Type::Int as u8);
69123
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
70124
let mut len = int_bytes.len() as i32;
71125
if sign == Sign::Minus {
@@ -76,49 +130,49 @@ mod decl {
76130
}
77131
}
78132
pyfloat @ PyFloat => {
79-
buf.push(FLOAT_BYTE);
133+
buf.push(Type::Float as u8);
80134
buf.extend(pyfloat.to_f64().to_le_bytes());
81135
}
82136
pystr @ PyStr => {
83-
buf.push(STR_BYTE);
137+
buf.push(Type::Str as u8);
84138
write_size(buf, pystr.as_str().len(), vm)?;
85139
buf.extend(pystr.as_str().as_bytes());
86140
}
87141
pylist @ PyList => {
88-
buf.push(LIST_BYTE);
142+
buf.push(Type::List as u8);
89143
let pylist_items = pylist.borrow_vec();
90144
dump_seq(buf, pylist_items.iter(), vm)?;
91145
}
92146
pyset @ PySet => {
93-
buf.push(SET_BYTE);
147+
buf.push(Type::Set as u8);
94148
let elements = pyset.elements();
95149
dump_seq(buf, elements.iter(), vm)?;
96150
}
97151
pyfrozen @ PyFrozenSet => {
98-
buf.push(FROZEN_SET_BYTE);
152+
buf.push(Type::FrozenSet as u8);
99153
let elements = pyfrozen.elements();
100154
dump_seq(buf, elements.iter(), vm)?;
101155
}
102156
pytuple @ PyTuple => {
103-
buf.push(TUPLE_BYTE);
157+
buf.push(Type::Tuple as u8);
104158
dump_seq(buf, pytuple.iter(), vm)?;
105159
}
106160
pydict @ PyDict => {
107-
buf.push(DICT_BYTE);
161+
buf.push(Type::Dict as u8);
108162
write_size(buf, pydict.len(), vm)?;
109163
for (key, value) in pydict {
110164
dump_obj(buf, key, vm)?;
111165
dump_obj(buf, value, vm)?;
112166
}
113167
}
114168
bytes @ PyByteArray => {
115-
buf.push(BYTE_ARRAY);
169+
buf.push(Type::Bytes as u8);
116170
let data = bytes.borrow_buf();
117171
write_size(buf, data.len(), vm)?;
118172
buf.extend(&*data);
119173
}
120174
co @ PyCode => {
121-
buf.push(TYPE_CODE);
175+
buf.push(Type::Code as u8);
122176
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
123177
write_size(buf, bytes.len(), vm)?;
124178
buf.extend(bytes);
@@ -191,10 +245,12 @@ mod decl {
191245

192246
fn load_obj<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(PyObjectRef, &'b [u8])> {
193247
let (type_indicator, buf) = buf.split_first().ok_or_else(|| too_short_error(vm))?;
194-
let (obj, buf) = match *type_indicator {
195-
TRUE_BYTE => ((true).to_pyobject(vm), buf),
196-
FALSE_BYTE => ((false).to_pyobject(vm), buf),
197-
INT_BYTE => {
248+
let typ = Type::try_from(*type_indicator)
249+
.map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?;
250+
let (obj, buf) = match typ {
251+
Type::True => ((true).to_pyobject(vm), buf),
252+
Type::False => ((false).to_pyobject(vm), buf),
253+
Type::Int => {
198254
if buf.len() < 4 {
199255
return Err(too_short_error(vm));
200256
}
@@ -212,15 +268,15 @@ mod decl {
212268
let int = BigInt::from_bytes_le(sign, bytes);
213269
(int.to_pyobject(vm), buf)
214270
}
215-
FLOAT_BYTE => {
271+
Type::Float => {
216272
if buf.len() < 8 {
217273
return Err(too_short_error(vm));
218274
}
219275
let (bytes, buf) = buf.split_at(8);
220276
let number = f64::from_le_bytes(bytes.try_into().unwrap());
221277
(vm.ctx.new_float(number).into(), buf)
222278
}
223-
STR_BYTE => {
279+
Type::Str => {
224280
let (len, buf) = read_size(buf, vm)?;
225281
if buf.len() < len {
226282
return Err(too_short_error(vm));
@@ -230,28 +286,28 @@ mod decl {
230286
.map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?;
231287
(s.to_pyobject(vm), buf)
232288
}
233-
LIST_BYTE => {
289+
Type::List => {
234290
let (elements, buf) = load_seq(buf, vm)?;
235291
(vm.ctx.new_list(elements).into(), buf)
236292
}
237-
SET_BYTE => {
293+
Type::Set => {
238294
let (elements, buf) = load_seq(buf, vm)?;
239295
let set = PySet::new_ref(&vm.ctx);
240296
for element in elements {
241297
set.add(element, vm)?;
242298
}
243299
(set.to_pyobject(vm), buf)
244300
}
245-
FROZEN_SET_BYTE => {
301+
Type::FrozenSet => {
246302
let (elements, buf) = load_seq(buf, vm)?;
247303
let set = PyFrozenSet::from_iter(vm, elements.into_iter())?;
248304
(set.to_pyobject(vm), buf)
249305
}
250-
TUPLE_BYTE => {
306+
Type::Tuple => {
251307
let (elements, buf) = load_seq(buf, vm)?;
252308
(vm.ctx.new_tuple(elements).into(), buf)
253309
}
254-
DICT_BYTE => {
310+
Type::Dict => {
255311
let (len, mut buf) = read_size(buf, vm)?;
256312
let dict = vm.ctx.new_dict();
257313
for _ in 0..len {
@@ -262,7 +318,7 @@ mod decl {
262318
}
263319
(dict.into(), buf)
264320
}
265-
BYTE_ARRAY => {
321+
Type::Bytes => {
266322
// Following CPython, after marshaling, byte arrays are converted into bytes.
267323
let (len, buf) = read_size(buf, vm)?;
268324
if buf.len() < len {
@@ -271,7 +327,7 @@ mod decl {
271327
let (bytes, buf) = buf.split_at(len);
272328
(vm.ctx.new_bytes(bytes.to_vec()).into(), buf)
273329
}
274-
TYPE_CODE => {
330+
Type::Code => {
275331
// If prefix is not identifiable, assume CodeObject, error out if it doesn't match.
276332
let (len, buf) = read_size(buf, vm)?;
277333
if buf.len() < len {
@@ -287,7 +343,6 @@ mod decl {
287343
})?;
288344
(vm.ctx.new_code(code).into(), buf)
289345
}
290-
_ => return Err(vm.new_value_error("bad marshal data (unknown type code)".to_owned())),
291346
};
292347
Ok((obj, buf))
293348
}

0 commit comments

Comments
 (0)