@@ -2,13 +2,8 @@ pub(crate) use decl::make_module;
22
33pub ( super ) use decl:: crc32;
44
5- pub fn decode < T : AsRef < [ u8 ] > > ( input : T ) -> Result < Vec < u8 > , base64:: DecodeError > {
6- base64:: decode_config ( input, base64:: STANDARD . decode_allow_trailing_bits ( true ) )
7- }
8-
95#[ pymodule( name = "binascii" ) ]
106mod decl {
11- use super :: decode;
127 use crate :: vm:: {
138 builtins:: { PyBaseExceptionRef , PyIntRef , PyTypeRef } ,
149 function:: { ArgAsciiBuffer , ArgBytesLike , OptionalArg } ,
@@ -148,9 +143,20 @@ mod decl {
148143 vm. new_exception_msg ( error_type ( vm) , msg)
149144 }
150145
146+ #[ derive( FromArgs ) ]
147+ struct A2bBase64Args {
148+ #[ pyarg( any) ]
149+ s : ArgAsciiBuffer ,
150+ #[ pyarg( named, default = "false" ) ]
151+ strict_mode : bool ,
152+ }
153+
151154 #[ pyfunction]
152- fn a2b_base64 ( s : ArgAsciiBuffer , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
155+ fn a2b_base64 ( args : A2bBase64Args , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
153156 #[ rustfmt:: skip]
157+ // Converts between ASCII and base-64 characters. The index of a given number yields the
158+ // number in ASCII while the value of said index yields the number in base-64. For example
159+ // "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
154160 const BASE64_TABLE : [ i8 ; 256 ] = [
155161 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
156162 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
@@ -171,25 +177,107 @@ mod decl {
171177 -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
172178 ] ;
173179
180+ const PAD : u8 = 61u8 ;
181+
182+ let A2bBase64Args { s, strict_mode } = args;
174183 s. with_ref ( |b| {
175- let decoded = if b. len ( ) % 4 == 0 {
176- decode ( b)
177- } else {
178- Err ( base64:: DecodeError :: InvalidLength )
179- } ;
180- decoded. or_else ( |_| {
181- let buf: Vec < _ > = b
182- . iter ( )
183- . copied ( )
184- . filter ( |& c| BASE64_TABLE [ c as usize ] != -1 )
185- . collect ( ) ;
186- if buf. len ( ) % 4 != 0 {
187- return Err ( base64:: DecodeError :: InvalidLength ) ;
184+ if b. is_empty ( ) {
185+ return Ok ( vec ! [ ] ) ;
186+ }
187+
188+ if strict_mode && b[ 0 ] == PAD {
189+ return Err ( base64:: DecodeError :: InvalidByte ( 0 , 61 ) ) ;
190+ }
191+
192+ let mut decoded: Vec < u8 > = vec ! [ ] ;
193+
194+ let mut quad_pos = 0 ; // position in the nibble
195+ let mut pads = 0 ;
196+ let mut left_char: u8 = 0 ;
197+ let mut padding_started = false ;
198+ for ( i, & el) in b. iter ( ) . enumerate ( ) {
199+ if el == PAD {
200+ padding_started = true ;
201+
202+ pads += 1 ;
203+ if quad_pos >= 2 && quad_pos + pads >= 4 {
204+ if strict_mode && i + 1 < b. len ( ) {
205+ // Represents excess data after padding error
206+ return Err ( base64:: DecodeError :: InvalidLastSymbol ( i, el) ) ;
207+ }
208+
209+ return Ok ( decoded) ;
210+ }
211+
212+ continue ;
188213 }
189- decode ( & buf)
190- } )
214+
215+ let binary_char = BASE64_TABLE [ el as usize ] ;
216+ if binary_char >= 64 || binary_char == -1 {
217+ if strict_mode {
218+ // Represents non-base64 data error
219+ return Err ( base64:: DecodeError :: InvalidByte ( i, el) ) ;
220+ }
221+ continue ;
222+ }
223+
224+ if strict_mode && padding_started {
225+ // Represents discontinuous padding error
226+ return Err ( base64:: DecodeError :: InvalidByte ( i, PAD ) ) ;
227+ }
228+ pads = 0 ;
229+
230+ // Decode individual ASCII character
231+ if quad_pos == 0 {
232+ quad_pos = 1 ;
233+ left_char = binary_char as u8 ;
234+ } else if quad_pos == 1 {
235+ quad_pos = 2 ;
236+ decoded. push ( ( left_char << 2 ) | ( binary_char >> 4 ) as u8 ) ;
237+ left_char = ( binary_char & 0x0f ) as u8 ;
238+ } else if quad_pos == 2 {
239+ quad_pos = 3 ;
240+ decoded. push ( ( left_char << 4 ) | ( binary_char >> 2 ) as u8 ) ;
241+ left_char = ( binary_char & 0x03 ) as u8 ;
242+ } else if quad_pos == 3 {
243+ quad_pos = 0 ;
244+ decoded. push ( ( left_char << 6 ) | binary_char as u8 ) ;
245+ left_char = 0 ;
246+ }
247+ }
248+
249+ if quad_pos == 1 {
250+ // Ensure that a PAD never gets passed, since that'd mistakenly cause an excess
251+ // data after padding error
252+ return Err ( base64:: DecodeError :: InvalidLastSymbol (
253+ decoded. len ( ) / 3 * 4 + 1 ,
254+ 0 ,
255+ ) ) ;
256+ } else if quad_pos > 1 {
257+ return Err ( base64:: DecodeError :: InvalidLength ) ;
258+ }
259+
260+ Ok ( decoded)
261+ } )
262+ . map_err ( |err| {
263+ let python_error = match err {
264+ base64:: DecodeError :: InvalidByte ( 0 , PAD ) => {
265+ String :: from ( "Leading padding not allowed" )
266+ }
267+ base64:: DecodeError :: InvalidByte ( _, PAD ) => {
268+ String :: from ( "Discontinuous padding not allowed" )
269+ }
270+ base64:: DecodeError :: InvalidByte ( _, _) => {
271+ String :: from ( "Only base64 data is allowed" )
272+ }
273+ base64:: DecodeError :: InvalidLastSymbol ( _, _) => {
274+ String :: from ( "Excess data after padding" )
275+ }
276+ base64:: DecodeError :: InvalidLength => String :: from ( "Not implemented (yet)" ) ,
277+ } ;
278+
279+ new_binascii_error ( format ! ( "error decoding base64: {python_error}" ) , vm)
191280 } )
192- . map_err ( |err| new_binascii_error ( format ! ( "error decoding base64: {err}" ) , vm) )
193281 }
194282
195283 #[ pyfunction]
0 commit comments