@@ -19,18 +19,72 @@ mod decl {
1919 use num_bigint:: { BigInt , Sign } ;
2020 use num_traits:: Zero ;
2121
22- const STR_BYTE : u8 = b's' ;
23- const INT_BYTE : u8 = b'i' ;
24- const FLOAT_BYTE : u8 = b'f' ;
25- const TRUE_BYTE : u8 = b'T' ;
26- const FALSE_BYTE : u8 = b'F' ;
27- const LIST_BYTE : u8 = b'[' ;
28- const TUPLE_BYTE : u8 = b'(' ;
29- const DICT_BYTE : u8 = b',' ;
30- const SET_BYTE : u8 = b'~' ;
31- const FROZEN_SET_BYTE : u8 = b'<' ;
32- const BYTE_ARRAY : u8 = b'>' ;
33- const TYPE_CODE : u8 = b'c' ;
22+ #[ repr( u8 ) ]
23+ enum Type {
24+ // Null = b'0',
25+ // None = b'N',
26+ False = b'F' ,
27+ True = b'T' ,
28+ // StopIter = b'S',
29+ // Ellipsis = b'.',
30+ Int = b'i' ,
31+ Float = b'g' ,
32+ // Complex = b'y',
33+ // Long = b'l', // i32
34+ Bytes = b's' , // = TYPE_STRING
35+ // Interned = b't',
36+ // Ref = b'r',
37+ Tuple = b'(' ,
38+ List = b'[' ,
39+ Dict = b'{' ,
40+ Code = b'c' ,
41+ Str = b'u' , // = TYPE_UNICODE
42+ // Unknown = b'?',
43+ Set = b'<' ,
44+ FrozenSet = b'>' ,
45+ // Ascii = b'a',
46+ // AsciiInterned = b'A',
47+ // SmallTuple = b')',
48+ // ShortAscii = b'z',
49+ // ShortAsciiInterned = b'Z',
50+ }
51+ // const FLAG_REF: u8 = b'\x80';
52+
53+ impl TryFrom < u8 > for Type {
54+ type Error = u8 ;
55+ fn try_from ( value : u8 ) -> Result < Self , u8 > {
56+ use Type :: * ;
57+ Ok ( match value {
58+ // b'0' => Null,
59+ // b'N' => None,
60+ b'F' => False ,
61+ b'T' => True ,
62+ // b'S' => StopIter,
63+ // b'.' => Ellipsis,
64+ b'i' => Int ,
65+ b'g' => Float ,
66+ // b'y' => Complex,
67+ // b'l' => Long,
68+ b's' => Bytes ,
69+ // b't' => Interned,
70+ // b'r' => Ref,
71+ b'(' => Tuple ,
72+ b'[' => List ,
73+ b'{' => Dict ,
74+ b'c' => Code ,
75+ b'u' => Str ,
76+ // b'?' => Unknown,
77+ b'<' => Set ,
78+ b'>' => FrozenSet ,
79+ // b'a' => Ascii,
80+ // b'A' => AsciiInterned,
81+ // b')' => SmallTuple,
82+ // b'z' => ShortAscii,
83+ // b'Z' => ShortAsciiInterned,
84+ c => return Err ( c) ,
85+ } )
86+ }
87+ }
3488
3589 fn too_short_error ( vm : & VirtualMachine ) -> PyBaseExceptionRef {
3690 vm. new_exception_msg (
@@ -59,13 +113,13 @@ mod decl {
59113 pyint @ PyInt => {
60114 if pyint. class( ) . is( vm. ctx. types. bool_type) {
61115 let typ = if pyint. as_bigint( ) . is_zero( ) {
62- FALSE_BYTE
116+ Type :: False
63117 } else {
64- TRUE_BYTE
118+ Type :: True
65119 } ;
66- buf. push( typ) ;
120+ buf. push( typ as u8 ) ;
67121 } else {
68- buf. push( INT_BYTE ) ;
122+ buf. push( Type :: Int as u8 ) ;
69123 let ( sign, int_bytes) = pyint. as_bigint( ) . to_bytes_le( ) ;
70124 let mut len = int_bytes. len( ) as i32 ;
71125 if sign == Sign :: Minus {
@@ -76,49 +130,49 @@ mod decl {
76130 }
77131 }
78132 pyfloat @ PyFloat => {
79- buf. push( FLOAT_BYTE ) ;
133+ buf. push( Type :: Float as u8 ) ;
80134 buf. extend( pyfloat. to_f64( ) . to_le_bytes( ) ) ;
81135 }
82136 pystr @ PyStr => {
83- buf. push( STR_BYTE ) ;
137+ buf. push( Type :: Str as u8 ) ;
84138 write_size( buf, pystr. as_str( ) . len( ) , vm) ?;
85139 buf. extend( pystr. as_str( ) . as_bytes( ) ) ;
86140 }
87141 pylist @ PyList => {
88- buf. push( LIST_BYTE ) ;
142+ buf. push( Type :: List as u8 ) ;
89143 let pylist_items = pylist. borrow_vec( ) ;
90144 dump_seq( buf, pylist_items. iter( ) , vm) ?;
91145 }
92146 pyset @ PySet => {
93- buf. push( SET_BYTE ) ;
147+ buf. push( Type :: Set as u8 ) ;
94148 let elements = pyset. elements( ) ;
95149 dump_seq( buf, elements. iter( ) , vm) ?;
96150 }
97151 pyfrozen @ PyFrozenSet => {
98- buf. push( FROZEN_SET_BYTE ) ;
152+ buf. push( Type :: FrozenSet as u8 ) ;
99153 let elements = pyfrozen. elements( ) ;
100154 dump_seq( buf, elements. iter( ) , vm) ?;
101155 }
102156 pytuple @ PyTuple => {
103- buf. push( TUPLE_BYTE ) ;
157+ buf. push( Type :: Tuple as u8 ) ;
104158 dump_seq( buf, pytuple. iter( ) , vm) ?;
105159 }
106160 pydict @ PyDict => {
107- buf. push( DICT_BYTE ) ;
161+ buf. push( Type :: Dict as u8 ) ;
108162 write_size( buf, pydict. len( ) , vm) ?;
109163 for ( key, value) in pydict {
110164 dump_obj( buf, key, vm) ?;
111165 dump_obj( buf, value, vm) ?;
112166 }
113167 }
114168 bytes @ PyByteArray => {
115- buf. push( BYTE_ARRAY ) ;
169+ buf. push( Type :: Bytes as u8 ) ;
116170 let data = bytes. borrow_buf( ) ;
117171 write_size( buf, data. len( ) , vm) ?;
118172 buf. extend( & * data) ;
119173 }
120174 co @ PyCode => {
121- buf. push( TYPE_CODE ) ;
175+ buf. push( Type :: Code as u8 ) ;
122176 let bytes = co. code. map_clone_bag( & bytecode:: BasicBag ) . to_bytes( ) ;
123177 write_size( buf, bytes. len( ) , vm) ?;
124178 buf. extend( bytes) ;
@@ -191,10 +245,12 @@ mod decl {
191245
192246 fn load_obj < ' b > ( buf : & ' b [ u8 ] , vm : & VirtualMachine ) -> PyResult < ( PyObjectRef , & ' b [ u8 ] ) > {
193247 let ( type_indicator, buf) = buf. split_first ( ) . ok_or_else ( || too_short_error ( vm) ) ?;
194- let ( obj, buf) = match * type_indicator {
195- TRUE_BYTE => ( ( true ) . to_pyobject ( vm) , buf) ,
196- FALSE_BYTE => ( ( false ) . to_pyobject ( vm) , buf) ,
197- INT_BYTE => {
248+ let typ = Type :: try_from ( * type_indicator)
249+ . map_err ( |_| vm. new_value_error ( "bad marshal data (unknown type code)" . to_owned ( ) ) ) ?;
250+ let ( obj, buf) = match typ {
251+ Type :: True => ( ( true ) . to_pyobject ( vm) , buf) ,
252+ Type :: False => ( ( false ) . to_pyobject ( vm) , buf) ,
253+ Type :: Int => {
198254 if buf. len ( ) < 4 {
199255 return Err ( too_short_error ( vm) ) ;
200256 }
@@ -212,15 +268,15 @@ mod decl {
212268 let int = BigInt :: from_bytes_le ( sign, bytes) ;
213269 ( int. to_pyobject ( vm) , buf)
214270 }
215- FLOAT_BYTE => {
271+ Type :: Float => {
216272 if buf. len ( ) < 8 {
217273 return Err ( too_short_error ( vm) ) ;
218274 }
219275 let ( bytes, buf) = buf. split_at ( 8 ) ;
220276 let number = f64:: from_le_bytes ( bytes. try_into ( ) . unwrap ( ) ) ;
221277 ( vm. ctx . new_float ( number) . into ( ) , buf)
222278 }
223- STR_BYTE => {
279+ Type :: Str => {
224280 let ( len, buf) = read_size ( buf, vm) ?;
225281 if buf. len ( ) < len {
226282 return Err ( too_short_error ( vm) ) ;
@@ -230,28 +286,28 @@ mod decl {
230286 . map_err ( |_| vm. new_value_error ( "invalid utf8 data" . to_owned ( ) ) ) ?;
231287 ( s. to_pyobject ( vm) , buf)
232288 }
233- LIST_BYTE => {
289+ Type :: List => {
234290 let ( elements, buf) = load_seq ( buf, vm) ?;
235291 ( vm. ctx . new_list ( elements) . into ( ) , buf)
236292 }
237- SET_BYTE => {
293+ Type :: Set => {
238294 let ( elements, buf) = load_seq ( buf, vm) ?;
239295 let set = PySet :: new_ref ( & vm. ctx ) ;
240296 for element in elements {
241297 set. add ( element, vm) ?;
242298 }
243299 ( set. to_pyobject ( vm) , buf)
244300 }
245- FROZEN_SET_BYTE => {
301+ Type :: FrozenSet => {
246302 let ( elements, buf) = load_seq ( buf, vm) ?;
247303 let set = PyFrozenSet :: from_iter ( vm, elements. into_iter ( ) ) ?;
248304 ( set. to_pyobject ( vm) , buf)
249305 }
250- TUPLE_BYTE => {
306+ Type :: Tuple => {
251307 let ( elements, buf) = load_seq ( buf, vm) ?;
252308 ( vm. ctx . new_tuple ( elements) . into ( ) , buf)
253309 }
254- DICT_BYTE => {
310+ Type :: Dict => {
255311 let ( len, mut buf) = read_size ( buf, vm) ?;
256312 let dict = vm. ctx . new_dict ( ) ;
257313 for _ in 0 ..len {
@@ -262,7 +318,7 @@ mod decl {
262318 }
263319 ( dict. into ( ) , buf)
264320 }
265- BYTE_ARRAY => {
321+ Type :: Bytes => {
266322 // Following CPython, after marshaling, byte arrays are converted into bytes.
267323 let ( len, buf) = read_size ( buf, vm) ?;
268324 if buf. len ( ) < len {
@@ -271,7 +327,7 @@ mod decl {
271327 let ( bytes, buf) = buf. split_at ( len) ;
272328 ( vm. ctx . new_bytes ( bytes. to_vec ( ) ) . into ( ) , buf)
273329 }
274- TYPE_CODE => {
330+ Type :: Code => {
275331 // If prefix is not identifiable, assume CodeObject, error out if it doesn't match.
276332 let ( len, buf) = read_size ( buf, vm) ?;
277333 if buf. len ( ) < len {
@@ -287,7 +343,6 @@ mod decl {
287343 } ) ?;
288344 ( vm. ctx . new_code ( code) . into ( ) , buf)
289345 }
290- _ => return Err ( vm. new_value_error ( "bad marshal data (unknown type code)" . to_owned ( ) ) ) ,
291346 } ;
292347 Ok ( ( obj, buf) )
293348 }
0 commit comments