1+ pub ( super ) use decl:: crc32;
12pub ( crate ) use decl:: make_module;
3+ use rustpython_vm:: { builtins:: PyBaseExceptionRef , convert:: ToPyException , VirtualMachine } ;
24
3- pub ( super ) use decl:: crc32;
4-
5- pub fn decode < T : AsRef < [ u8 ] > > ( input : T ) -> Result < Vec < u8 > , base64:: DecodeError > {
6- base64:: decode_config ( input, base64:: STANDARD . decode_allow_trailing_bits ( true ) )
7- }
5+ const PAD : u8 = 61u8 ;
6+ const MAXLINESIZE : usize = 76 ; // Excluding the CRLF
87
98#[ pymodule( name = "binascii" ) ]
109mod decl {
11- use super :: decode ;
10+ use super :: { MAXLINESIZE , PAD } ;
1211 use crate :: vm:: {
13- builtins:: { PyBaseExceptionRef , PyIntRef , PyTypeRef } ,
12+ builtins:: { PyIntRef , PyTypeRef } ,
13+ convert:: ToPyException ,
1414 function:: { ArgAsciiBuffer , ArgBytesLike , OptionalArg } ,
1515 PyResult , VirtualMachine ,
1616 } ;
1717 use itertools:: Itertools ;
1818
19- const MAXLINESIZE : usize = 76 ;
20-
2119 #[ pyattr( name = "Error" , once) ]
22- fn error_type ( vm : & VirtualMachine ) -> PyTypeRef {
20+ pub ( super ) fn error_type ( vm : & VirtualMachine ) -> PyTypeRef {
2321 vm. ctx . new_exception_type (
2422 "binascii" ,
2523 "Error" ,
@@ -67,15 +65,18 @@ mod decl {
6765 fn unhexlify ( data : ArgAsciiBuffer , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
6866 data. with_ref ( |hex_bytes| {
6967 if hex_bytes. len ( ) % 2 != 0 {
70- return Err ( new_binascii_error ( "Odd-length string" . to_owned ( ) , vm) ) ;
68+ return Err ( super :: new_binascii_error (
69+ "Odd-length string" . to_owned ( ) ,
70+ vm,
71+ ) ) ;
7172 }
7273
7374 let mut unhex = Vec :: < u8 > :: with_capacity ( hex_bytes. len ( ) / 2 ) ;
7475 for ( n1, n2) in hex_bytes. iter ( ) . tuples ( ) {
7576 if let ( Some ( n1) , Some ( n2) ) = ( unhex_nibble ( * n1) , unhex_nibble ( * n2) ) {
7677 unhex. push ( n1 << 4 | n2) ;
7778 } else {
78- return Err ( new_binascii_error (
79+ return Err ( super :: new_binascii_error (
7980 "Non-hexadecimal digit found" . to_owned ( ) ,
8081 vm,
8182 ) ) ;
@@ -144,13 +145,20 @@ mod decl {
144145 newline : bool ,
145146 }
146147
147- fn new_binascii_error ( msg : String , vm : & VirtualMachine ) -> PyBaseExceptionRef {
148- vm. new_exception_msg ( error_type ( vm) , msg)
148+ #[ derive( FromArgs ) ]
149+ struct A2bBase64Args {
150+ #[ pyarg( any) ]
151+ s : ArgAsciiBuffer ,
152+ #[ pyarg( named, default = "false" ) ]
153+ strict_mode : bool ,
149154 }
150155
151156 #[ pyfunction]
152- fn a2b_base64 ( s : ArgAsciiBuffer , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
157+ fn a2b_base64 ( args : A2bBase64Args , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
153158 #[ rustfmt:: skip]
159+ // Converts between ASCII and base-64 characters. The index of a given number yields the
160+ // number in ASCII while the value of said index yields the number in base-64. For example
161+ // "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
154162 const BASE64_TABLE : [ i8 ; 256 ] = [
155163 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
156164 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
@@ -171,25 +179,92 @@ mod decl {
171179 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
172180 ] ;
173181
182+ let A2bBase64Args { s, strict_mode } = args;
174183 s. with_ref ( |b| {
175- let decoded = if b. len ( ) % 4 == 0 {
176- decode ( b)
177- } else {
178- Err ( base64:: DecodeError :: InvalidLength )
179- } ;
180- decoded. or_else ( |_| {
181- let buf: Vec < _ > = b
182- . iter ( )
183- . copied ( )
184- . filter ( |& c| BASE64_TABLE [ c as usize ] != -1 )
185- . collect ( ) ;
186- if buf. len ( ) % 4 != 0 {
187- return Err ( base64:: DecodeError :: InvalidLength ) ;
184+ if b. is_empty ( ) {
185+ return Ok ( vec ! [ ] ) ;
186+ }
187+
188+ if strict_mode && b[ 0 ] == PAD {
189+ return Err ( base64:: DecodeError :: InvalidByte ( 0 , 61 ) ) ;
190+ }
191+
192+ let mut decoded: Vec < u8 > = vec ! [ ] ;
193+
194+ let mut quad_pos = 0 ; // position in the nibble
195+ let mut pads = 0 ;
196+ let mut left_char: u8 = 0 ;
197+ let mut padding_started = false ;
198+ for ( i, & el) in b. iter ( ) . enumerate ( ) {
199+ if el == PAD {
200+ padding_started = true ;
201+
202+ pads += 1 ;
203+ if quad_pos >= 2 && quad_pos + pads >= 4 {
204+ if strict_mode && i + 1 < b. len ( ) {
205+ // Represents excess data after padding error
206+ return Err ( base64:: DecodeError :: InvalidLastSymbol ( i, PAD ) ) ;
207+ }
208+
209+ return Ok ( decoded) ;
210+ }
211+
212+ continue ;
188213 }
189- decode ( & buf)
190- } )
214+
215+ let binary_char = BASE64_TABLE [ el as usize ] ;
216+ if binary_char >= 64 || binary_char == -1 {
217+ if strict_mode {
218+ // Represents non-base64 data error
219+ return Err ( base64:: DecodeError :: InvalidByte ( i, el) ) ;
220+ }
221+ continue ;
222+ }
223+
224+ if strict_mode && padding_started {
225+ // Represents discontinuous padding error
226+ return Err ( base64:: DecodeError :: InvalidByte ( i, PAD ) ) ;
227+ }
228+ pads = 0 ;
229+
230+ // Decode individual ASCII character
231+ match quad_pos {
232+ 0 => {
233+ quad_pos = 1 ;
234+ left_char = binary_char as u8 ;
235+ }
236+ 1 => {
237+ quad_pos = 2 ;
238+ decoded. push ( ( left_char << 2 ) | ( binary_char >> 4 ) as u8 ) ;
239+ left_char = ( binary_char & 0x0f ) as u8 ;
240+ }
241+ 2 => {
242+ quad_pos = 3 ;
243+ decoded. push ( ( left_char << 4 ) | ( binary_char >> 2 ) as u8 ) ;
244+ left_char = ( binary_char & 0x03 ) as u8 ;
245+ }
246+ 3 => {
247+ quad_pos = 0 ;
248+ decoded. push ( ( left_char << 6 ) | binary_char as u8 ) ;
249+ left_char = 0 ;
250+ }
251+ _ => unsafe {
252+ // quad_pos is only assigned in this match statement to constants
253+ std:: hint:: unreachable_unchecked ( )
254+ } ,
255+ }
256+ }
257+
258+ match quad_pos {
259+ 0 => Ok ( decoded) ,
260+ 1 => Err ( base64:: DecodeError :: InvalidLastSymbol (
261+ decoded. len ( ) / 3 * 4 + 1 ,
262+ 0 ,
263+ ) ) ,
264+ _ => Err ( base64:: DecodeError :: InvalidLength ) ,
265+ }
191266 } )
192- . map_err ( |err| new_binascii_error ( format ! ( "error decoding base64: { err}" ) , vm) )
267+ . map_err ( |err| super :: Base64DecodeError ( err) . to_pyexception ( vm) )
193268 }
194269
195270 #[ pyfunction]
@@ -654,3 +729,26 @@ mod decl {
654729 } )
655730 }
656731}
732+
733+ struct Base64DecodeError ( base64:: DecodeError ) ;
734+
735+ fn new_binascii_error ( msg : String , vm : & VirtualMachine ) -> PyBaseExceptionRef {
736+ vm. new_exception_msg ( decl:: error_type ( vm) , msg)
737+ }
738+
739+ impl ToPyException for Base64DecodeError {
740+ fn to_pyexception ( & self , vm : & VirtualMachine ) -> PyBaseExceptionRef {
741+ use base64:: DecodeError :: * ;
742+ let message = match self . 0 {
743+ InvalidByte ( 0 , PAD ) => "Leading padding not allowed" . to_owned ( ) ,
744+ InvalidByte ( _, PAD ) => "Discontinuous padding not allowed" . to_owned ( ) ,
745+ InvalidByte ( _, _) => "Only base64 data is allowed" . to_owned ( ) ,
746+ InvalidLastSymbol ( _, PAD ) => "Excess data after padding" . to_owned ( ) ,
747+ InvalidLastSymbol ( length, _) => {
748+ format ! ( "Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4" , length)
749+ }
750+ InvalidLength => "Incorrect padding" . to_owned ( ) ,
751+ } ;
752+ new_binascii_error ( format ! ( "error decoding base64: {message}" ) , vm)
753+ }
754+ }
0 commit comments