Skip to content

Commit 404c398

Browse files
evanrittenhouse authored and youknowone committed
Implement strict_mode keyword for binascii.a2b_base64
1 parent d7f65cb commit 404c398

File tree

2 files changed

+110
-24
lines changed

2 files changed

+110
-24
lines changed

Lib/test/test_binascii.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,8 +114,6 @@ def addnoise(line):
114114
# empty strings. TBD: shouldn't it raise an exception instead ?
115115
self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'')
116116

117-
# TODO: RUSTPYTHON
118-
@unittest.expectedFailure
119117
def test_base64_strict_mode(self):
120118
# Test base64 with strict mode on
121119
def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes):

stdlib/src/binascii.rs

Lines changed: 110 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,8 @@ pub(crate) use decl::make_module;
22

33
pub(super) use decl::crc32;
44

5-
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
6-
base64::decode_config(input, base64::STANDARD.decode_allow_trailing_bits(true))
7-
}
8-
95
#[pymodule(name = "binascii")]
106
mod decl {
11-
use super::decode;
127
use crate::vm::{
138
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
149
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
@@ -148,9 +143,20 @@ mod decl {
148143
vm.new_exception_msg(error_type(vm), msg)
149144
}
150145

146+
#[derive(FromArgs)]
147+
struct A2bBase64Args {
148+
#[pyarg(any)]
149+
s: ArgAsciiBuffer,
150+
#[pyarg(named, default = "false")]
151+
strict_mode: bool,
152+
}
153+
151154
#[pyfunction]
152-
fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
155+
fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
153156
#[rustfmt::skip]
157+
// Maps ASCII byte values to base-64 values: the index into the table is a byte's ASCII code
158+
// and the entry at that index is its base-64 value (-1 if the byte is not a base-64 character).
159+
// For example, "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
154160
const BASE64_TABLE: [i8; 256] = [
155161
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
156162
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
@@ -171,25 +177,107 @@ mod decl {
171177
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
172178
];
173179

180+
const PAD: u8 = 61u8;
181+
182+
let A2bBase64Args { s, strict_mode } = args;
174183
s.with_ref(|b| {
175-
let decoded = if b.len() % 4 == 0 {
176-
decode(b)
177-
} else {
178-
Err(base64::DecodeError::InvalidLength)
179-
};
180-
decoded.or_else(|_| {
181-
let buf: Vec<_> = b
182-
.iter()
183-
.copied()
184-
.filter(|&c| BASE64_TABLE[c as usize] != -1)
185-
.collect();
186-
if buf.len() % 4 != 0 {
187-
return Err(base64::DecodeError::InvalidLength);
184+
if b.is_empty() {
185+
return Ok(vec![]);
186+
}
187+
188+
if strict_mode && b[0] == PAD {
189+
return Err(base64::DecodeError::InvalidByte(0, 61));
190+
}
191+
192+
let mut decoded: Vec<u8> = vec![];
193+
194+
let mut quad_pos = 0; // position in the nibble
195+
let mut pads = 0;
196+
let mut left_char: u8 = 0;
197+
let mut padding_started = false;
198+
for (i, &el) in b.iter().enumerate() {
199+
if el == PAD {
200+
padding_started = true;
201+
202+
pads += 1;
203+
if quad_pos >= 2 && quad_pos + pads >= 4 {
204+
if strict_mode && i + 1 < b.len() {
205+
// Represents excess data after padding error
206+
return Err(base64::DecodeError::InvalidLastSymbol(i, el));
207+
}
208+
209+
return Ok(decoded);
210+
}
211+
212+
continue;
188213
}
189-
decode(&buf)
190-
})
214+
215+
let binary_char = BASE64_TABLE[el as usize];
216+
if binary_char >= 64 || binary_char == -1 {
217+
if strict_mode {
218+
// Represents non-base64 data error
219+
return Err(base64::DecodeError::InvalidByte(i, el));
220+
}
221+
continue;
222+
}
223+
224+
if strict_mode && padding_started {
225+
// Represents discontinuous padding error
226+
return Err(base64::DecodeError::InvalidByte(i, PAD));
227+
}
228+
pads = 0;
229+
230+
// Decode individual ASCII character
231+
if quad_pos == 0 {
232+
quad_pos = 1;
233+
left_char = binary_char as u8;
234+
} else if quad_pos == 1 {
235+
quad_pos = 2;
236+
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
237+
left_char = (binary_char & 0x0f) as u8;
238+
} else if quad_pos == 2 {
239+
quad_pos = 3;
240+
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
241+
left_char = (binary_char & 0x03) as u8;
242+
} else if quad_pos == 3 {
243+
quad_pos = 0;
244+
decoded.push((left_char << 6) | binary_char as u8);
245+
left_char = 0;
246+
}
247+
}
248+
249+
if quad_pos == 1 {
250+
// Ensure that a PAD never gets passed, since that'd mistakenly cause an excess
251+
// data after padding error
252+
return Err(base64::DecodeError::InvalidLastSymbol(
253+
decoded.len() / 3 * 4 + 1,
254+
0,
255+
));
256+
} else if quad_pos > 1 {
257+
return Err(base64::DecodeError::InvalidLength);
258+
}
259+
260+
Ok(decoded)
261+
})
262+
.map_err(|err| {
263+
let python_error = match err {
264+
base64::DecodeError::InvalidByte(0, PAD) => {
265+
String::from("Leading padding not allowed")
266+
}
267+
base64::DecodeError::InvalidByte(_, PAD) => {
268+
String::from("Discontinuous padding not allowed")
269+
}
270+
base64::DecodeError::InvalidByte(_, _) => {
271+
String::from("Only base64 data is allowed")
272+
}
273+
base64::DecodeError::InvalidLastSymbol(_, _) => {
274+
String::from("Excess data after padding")
275+
}
276+
base64::DecodeError::InvalidLength => String::from("Not implemented (yet)"),
277+
};
278+
279+
new_binascii_error(format!("error decoding base64: {python_error}"), vm)
191280
})
192-
.map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm))
193281
}
194282

195283
#[pyfunction]

0 commit comments

Comments
 (0)