Skip to content

Commit ff973ca

Browse files
committed
ToPyException for base64::DecodeError
1 parent 362be9f commit ff973ca

File tree

1 file changed

+68
-54
lines changed

1 file changed

+68
-54
lines changed

stdlib/src/binascii.rs

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1+
pub(super) use decl::crc32;
12
pub(crate) use decl::make_module;
3+
use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine};
24

3-
pub(super) use decl::crc32;
5+
const PAD: u8 = 61u8;
6+
const MAXLINESIZE: usize = 76; // Excluding the CRLF
47

58
#[pymodule(name = "binascii")]
69
mod decl {
10+
use super::{MAXLINESIZE, PAD};
711
use crate::vm::{
8-
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
12+
builtins::{PyIntRef, PyTypeRef},
13+
convert::ToPyException,
914
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
1015
PyResult, VirtualMachine,
1116
};
1217
use itertools::Itertools;
1318

14-
const MAXLINESIZE: usize = 76;
15-
1619
#[pyattr(name = "Error", once)]
17-
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
20+
pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef {
1821
vm.ctx.new_exception_type(
1922
"binascii",
2023
"Error",
@@ -62,15 +65,18 @@ mod decl {
6265
fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
6366
data.with_ref(|hex_bytes| {
6467
if hex_bytes.len() % 2 != 0 {
65-
return Err(new_binascii_error("Odd-length string".to_owned(), vm));
68+
return Err(super::new_binascii_error(
69+
"Odd-length string".to_owned(),
70+
vm,
71+
));
6672
}
6773

6874
let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
6975
for (n1, n2) in hex_bytes.iter().tuples() {
7076
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
7177
unhex.push(n1 << 4 | n2);
7278
} else {
73-
return Err(new_binascii_error(
79+
return Err(super::new_binascii_error(
7480
"Non-hexadecimal digit found".to_owned(),
7581
vm,
7682
));
@@ -139,10 +145,6 @@ mod decl {
139145
newline: bool,
140146
}
141147

142-
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
143-
vm.new_exception_msg(error_type(vm), msg)
144-
}
145-
146148
#[derive(FromArgs)]
147149
struct A2bBase64Args {
148150
#[pyarg(any)]
@@ -177,8 +179,6 @@ mod decl {
177179
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
178180
];
179181

180-
const PAD: u8 = 61u8;
181-
182182
let A2bBase64Args { s, strict_mode } = args;
183183
s.with_ref(|b| {
184184
if b.is_empty() {
@@ -228,52 +228,43 @@ mod decl {
228228
pads = 0;
229229

230230
// Decode individual ASCII character
231-
if quad_pos == 0 {
232-
quad_pos = 1;
233-
left_char = binary_char as u8;
234-
} else if quad_pos == 1 {
235-
quad_pos = 2;
236-
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
237-
left_char = (binary_char & 0x0f) as u8;
238-
} else if quad_pos == 2 {
239-
quad_pos = 3;
240-
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
241-
left_char = (binary_char & 0x03) as u8;
242-
} else if quad_pos == 3 {
243-
quad_pos = 0;
244-
decoded.push((left_char << 6) | binary_char as u8);
245-
left_char = 0;
231+
match quad_pos {
232+
0 => {
233+
quad_pos = 1;
234+
left_char = binary_char as u8;
235+
}
236+
1 => {
237+
quad_pos = 2;
238+
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
239+
left_char = (binary_char & 0x0f) as u8;
240+
}
241+
2 => {
242+
quad_pos = 3;
243+
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
244+
left_char = (binary_char & 0x03) as u8;
245+
}
246+
3 => {
247+
quad_pos = 0;
248+
decoded.push((left_char << 6) | binary_char as u8);
249+
left_char = 0;
250+
}
251+
_ => unsafe {
252+
// quad_pos is only assigned in this match statement to constants
253+
std::hint::unreachable_unchecked()
254+
},
246255
}
247256
}
248257

249-
return match quad_pos {
258+
match quad_pos {
250259
0 => Ok(decoded),
251-
1 => Err(base64::DecodeError::InvalidLastSymbol(decoded.len() / 3 * 4 + 1, 0)),
252-
_ => Err(base64::DecodeError::InvalidLength)
253-
};
254-
})
255-
.map_err(|err| {
256-
let python_error = match err {
257-
base64::DecodeError::InvalidByte(0, PAD) => {
258-
String::from("Leading padding not allowed")
259-
}
260-
base64::DecodeError::InvalidByte(_, PAD) => {
261-
String::from("Discontinuous padding not allowed")
262-
}
263-
base64::DecodeError::InvalidByte(_, _) => {
264-
String::from("Only base64 data is allowed")
265-
}
266-
base64::DecodeError::InvalidLastSymbol(_, PAD) => {
267-
String::from("Excess data after padding")
268-
}
269-
base64::DecodeError::InvalidLastSymbol(length, _) => {
270-
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
271-
}
272-
base64::DecodeError::InvalidLength => String::from("Incorrect padding"),
273-
};
274-
275-
new_binascii_error(format!("error decoding base64: {python_error}"), vm)
260+
1 => Err(base64::DecodeError::InvalidLastSymbol(
261+
decoded.len() / 3 * 4 + 1,
262+
0,
263+
)),
264+
_ => Err(base64::DecodeError::InvalidLength),
265+
}
276266
})
267+
.map_err(|err| super::Base64DecodeError(err).to_pyexception(vm))
277268
}
278269

279270
#[pyfunction]
@@ -738,3 +729,26 @@ mod decl {
738729
})
739730
}
740731
}
732+
733+
struct Base64DecodeError(base64::DecodeError);
734+
735+
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
736+
vm.new_exception_msg(decl::error_type(vm), msg)
737+
}
738+
739+
impl ToPyException for Base64DecodeError {
740+
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
741+
use base64::DecodeError::*;
742+
let message = match self.0 {
743+
InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(),
744+
InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(),
745+
InvalidByte(_, _) => "Only base64 data is allowed".to_owned(),
746+
InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(),
747+
InvalidLastSymbol(length, _) => {
748+
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
749+
}
750+
InvalidLength => "Incorrect padding".to_owned(),
751+
};
752+
new_binascii_error(format!("error decoding base64: {message}"), vm)
753+
}
754+
}

0 commit comments

Comments
 (0)