Skip to content

Commit 4e19be7

Browse files
authored
Merge pull request RustPython#4460 from evanrittenhouse/strict_mode_new
Implement `strict_mode` keyword for `binascii.a2b_base64()`
2 parents 1f92212 + ff973ca commit 4e19be7

File tree

2 files changed

+129
-37
lines changed

2 files changed

+129
-37
lines changed

Lib/test/test_binascii.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ def test_base64valid(self):
7575
res += b
7676
self.assertEqual(res, self.rawdata)
7777

78-
# TODO: RUSTPYTHON
79-
@unittest.expectedFailure
8078
def test_base64invalid(self):
8179
# Test base64 with random invalid characters sprinkled throughout
8280
# (This requires a new version of binascii.)
@@ -114,8 +112,6 @@ def addnoise(line):
114112
# empty strings. TBD: shouldn't it raise an exception instead ?
115113
self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'')
116114

117-
# TODO: RUSTPYTHON
118-
@unittest.expectedFailure
119115
def test_base64_strict_mode(self):
120116
# Test base64 with strict mode on
121117
def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes):
@@ -159,8 +155,6 @@ def assertDiscontinuousPadding(data, non_strict_mode_expected_result: bytes):
159155
assertDiscontinuousPadding(b'ab=c=', b'i\xb7')
160156
assertDiscontinuousPadding(b'ab=ab==', b'i\xb6\x9b')
161157

162-
# TODO: RUSTPYTHON
163-
@unittest.expectedFailure
164158
def test_base64errors(self):
165159
# Test base64 with invalid padding
166160
def assertIncorrectPadding(data):

stdlib/src/binascii.rs

Lines changed: 129 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
1+
pub(super) use decl::crc32;
12
pub(crate) use decl::make_module;
3+
use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine};
24

3-
pub(super) use decl::crc32;
4-
5-
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
6-
base64::decode_config(input, base64::STANDARD.decode_allow_trailing_bits(true))
7-
}
5+
const PAD: u8 = 61u8;
6+
const MAXLINESIZE: usize = 76; // Excluding the CRLF
87

98
#[pymodule(name = "binascii")]
109
mod decl {
11-
use super::decode;
10+
use super::{MAXLINESIZE, PAD};
1211
use crate::vm::{
13-
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
12+
builtins::{PyIntRef, PyTypeRef},
13+
convert::ToPyException,
1414
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
1515
PyResult, VirtualMachine,
1616
};
1717
use itertools::Itertools;
1818

19-
const MAXLINESIZE: usize = 76;
20-
2119
#[pyattr(name = "Error", once)]
22-
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
20+
pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef {
2321
vm.ctx.new_exception_type(
2422
"binascii",
2523
"Error",
@@ -67,15 +65,18 @@ mod decl {
6765
fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
6866
data.with_ref(|hex_bytes| {
6967
if hex_bytes.len() % 2 != 0 {
70-
return Err(new_binascii_error("Odd-length string".to_owned(), vm));
68+
return Err(super::new_binascii_error(
69+
"Odd-length string".to_owned(),
70+
vm,
71+
));
7172
}
7273

7374
let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
7475
for (n1, n2) in hex_bytes.iter().tuples() {
7576
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
7677
unhex.push(n1 << 4 | n2);
7778
} else {
78-
return Err(new_binascii_error(
79+
return Err(super::new_binascii_error(
7980
"Non-hexadecimal digit found".to_owned(),
8081
vm,
8182
));
@@ -144,13 +145,20 @@ mod decl {
144145
newline: bool,
145146
}
146147

147-
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
148-
vm.new_exception_msg(error_type(vm), msg)
148+
#[derive(FromArgs)]
149+
struct A2bBase64Args {
150+
#[pyarg(any)]
151+
s: ArgAsciiBuffer,
152+
#[pyarg(named, default = "false")]
153+
strict_mode: bool,
149154
}
150155

151156
#[pyfunction]
152-
fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
157+
fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
153158
#[rustfmt::skip]
159+
// Converts between ASCII and base-64 characters. The index of a given number yields the
160+
// number in ASCII while the value of said index yields the number in base-64. For example
161+
// "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
154162
const BASE64_TABLE: [i8; 256] = [
155163
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
156164
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
@@ -171,25 +179,92 @@ mod decl {
171179
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
172180
];
173181

182+
let A2bBase64Args { s, strict_mode } = args;
174183
s.with_ref(|b| {
175-
let decoded = if b.len() % 4 == 0 {
176-
decode(b)
177-
} else {
178-
Err(base64::DecodeError::InvalidLength)
179-
};
180-
decoded.or_else(|_| {
181-
let buf: Vec<_> = b
182-
.iter()
183-
.copied()
184-
.filter(|&c| BASE64_TABLE[c as usize] != -1)
185-
.collect();
186-
if buf.len() % 4 != 0 {
187-
return Err(base64::DecodeError::InvalidLength);
184+
if b.is_empty() {
185+
return Ok(vec![]);
186+
}
187+
188+
if strict_mode && b[0] == PAD {
189+
return Err(base64::DecodeError::InvalidByte(0, 61));
190+
}
191+
192+
let mut decoded: Vec<u8> = vec![];
193+
194+
let mut quad_pos = 0; // position in the nibble
195+
let mut pads = 0;
196+
let mut left_char: u8 = 0;
197+
let mut padding_started = false;
198+
for (i, &el) in b.iter().enumerate() {
199+
if el == PAD {
200+
padding_started = true;
201+
202+
pads += 1;
203+
if quad_pos >= 2 && quad_pos + pads >= 4 {
204+
if strict_mode && i + 1 < b.len() {
205+
// Represents excess data after padding error
206+
return Err(base64::DecodeError::InvalidLastSymbol(i, PAD));
207+
}
208+
209+
return Ok(decoded);
210+
}
211+
212+
continue;
188213
}
189-
decode(&buf)
190-
})
214+
215+
let binary_char = BASE64_TABLE[el as usize];
216+
if binary_char >= 64 || binary_char == -1 {
217+
if strict_mode {
218+
// Represents non-base64 data error
219+
return Err(base64::DecodeError::InvalidByte(i, el));
220+
}
221+
continue;
222+
}
223+
224+
if strict_mode && padding_started {
225+
// Represents discontinuous padding error
226+
return Err(base64::DecodeError::InvalidByte(i, PAD));
227+
}
228+
pads = 0;
229+
230+
// Decode individual ASCII character
231+
match quad_pos {
232+
0 => {
233+
quad_pos = 1;
234+
left_char = binary_char as u8;
235+
}
236+
1 => {
237+
quad_pos = 2;
238+
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
239+
left_char = (binary_char & 0x0f) as u8;
240+
}
241+
2 => {
242+
quad_pos = 3;
243+
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
244+
left_char = (binary_char & 0x03) as u8;
245+
}
246+
3 => {
247+
quad_pos = 0;
248+
decoded.push((left_char << 6) | binary_char as u8);
249+
left_char = 0;
250+
}
251+
_ => unsafe {
252+
// quad_pos is only assigned in this match statement to constants
253+
std::hint::unreachable_unchecked()
254+
},
255+
}
256+
}
257+
258+
match quad_pos {
259+
0 => Ok(decoded),
260+
1 => Err(base64::DecodeError::InvalidLastSymbol(
261+
decoded.len() / 3 * 4 + 1,
262+
0,
263+
)),
264+
_ => Err(base64::DecodeError::InvalidLength),
265+
}
191266
})
192-
.map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm))
267+
.map_err(|err| super::Base64DecodeError(err).to_pyexception(vm))
193268
}
194269

195270
#[pyfunction]
@@ -654,3 +729,26 @@ mod decl {
654729
})
655730
}
656731
}
732+
733+
struct Base64DecodeError(base64::DecodeError);
734+
735+
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
736+
vm.new_exception_msg(decl::error_type(vm), msg)
737+
}
738+
739+
impl ToPyException for Base64DecodeError {
740+
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
741+
use base64::DecodeError::*;
742+
let message = match self.0 {
743+
InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(),
744+
InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(),
745+
InvalidByte(_, _) => "Only base64 data is allowed".to_owned(),
746+
InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(),
747+
InvalidLastSymbol(length, _) => {
748+
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
749+
}
750+
InvalidLength => "Incorrect padding".to_owned(),
751+
};
752+
new_binascii_error(format!("error decoding base64: {message}"), vm)
753+
}
754+
}

0 commit comments

Comments
 (0)