|
| 1 | +pub(super) use decl::crc32; |
1 | 2 | pub(crate) use decl::make_module; |
| 3 | +use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine}; |
2 | 4 |
|
3 | | -pub(super) use decl::crc32; |
| 5 | +const PAD: u8 = 61u8; |
| 6 | +const MAXLINESIZE: usize = 76; // Excluding the CRLF |
4 | 7 |
|
5 | 8 | #[pymodule(name = "binascii")] |
6 | 9 | mod decl { |
| 10 | + use super::{MAXLINESIZE, PAD}; |
7 | 11 | use crate::vm::{ |
8 | | - builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef}, |
| 12 | + builtins::{PyIntRef, PyTypeRef}, |
| 13 | + convert::ToPyException, |
9 | 14 | function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg}, |
10 | 15 | PyResult, VirtualMachine, |
11 | 16 | }; |
12 | 17 | use itertools::Itertools; |
13 | 18 |
|
14 | | - const MAXLINESIZE: usize = 76; |
15 | | - |
16 | 19 | #[pyattr(name = "Error", once)] |
17 | | - fn error_type(vm: &VirtualMachine) -> PyTypeRef { |
| 20 | + pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef { |
18 | 21 | vm.ctx.new_exception_type( |
19 | 22 | "binascii", |
20 | 23 | "Error", |
@@ -62,15 +65,18 @@ mod decl { |
62 | 65 | fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> { |
63 | 66 | data.with_ref(|hex_bytes| { |
64 | 67 | if hex_bytes.len() % 2 != 0 { |
65 | | - return Err(new_binascii_error("Odd-length string".to_owned(), vm)); |
| 68 | + return Err(super::new_binascii_error( |
| 69 | + "Odd-length string".to_owned(), |
| 70 | + vm, |
| 71 | + )); |
66 | 72 | } |
67 | 73 |
|
68 | 74 | let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2); |
69 | 75 | for (n1, n2) in hex_bytes.iter().tuples() { |
70 | 76 | if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) { |
71 | 77 | unhex.push(n1 << 4 | n2); |
72 | 78 | } else { |
73 | | - return Err(new_binascii_error( |
| 79 | + return Err(super::new_binascii_error( |
74 | 80 | "Non-hexadecimal digit found".to_owned(), |
75 | 81 | vm, |
76 | 82 | )); |
@@ -139,10 +145,6 @@ mod decl { |
139 | 145 | newline: bool, |
140 | 146 | } |
141 | 147 |
|
142 | | - fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef { |
143 | | - vm.new_exception_msg(error_type(vm), msg) |
144 | | - } |
145 | | - |
146 | 148 | #[derive(FromArgs)] |
147 | 149 | struct A2bBase64Args { |
148 | 150 | #[pyarg(any)] |
@@ -177,8 +179,6 @@ mod decl { |
177 | 179 | -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, |
178 | 180 | ]; |
179 | 181 |
|
180 | | - const PAD: u8 = 61u8; |
181 | | - |
182 | 182 | let A2bBase64Args { s, strict_mode } = args; |
183 | 183 | s.with_ref(|b| { |
184 | 184 | if b.is_empty() { |
@@ -228,52 +228,43 @@ mod decl { |
228 | 228 | pads = 0; |
229 | 229 |
|
230 | 230 | // Decode individual ASCII character |
231 | | - if quad_pos == 0 { |
232 | | - quad_pos = 1; |
233 | | - left_char = binary_char as u8; |
234 | | - } else if quad_pos == 1 { |
235 | | - quad_pos = 2; |
236 | | - decoded.push((left_char << 2) | (binary_char >> 4) as u8); |
237 | | - left_char = (binary_char & 0x0f) as u8; |
238 | | - } else if quad_pos == 2 { |
239 | | - quad_pos = 3; |
240 | | - decoded.push((left_char << 4) | (binary_char >> 2) as u8); |
241 | | - left_char = (binary_char & 0x03) as u8; |
242 | | - } else if quad_pos == 3 { |
243 | | - quad_pos = 0; |
244 | | - decoded.push((left_char << 6) | binary_char as u8); |
245 | | - left_char = 0; |
| 231 | + match quad_pos { |
| 232 | + 0 => { |
| 233 | + quad_pos = 1; |
| 234 | + left_char = binary_char as u8; |
| 235 | + } |
| 236 | + 1 => { |
| 237 | + quad_pos = 2; |
| 238 | + decoded.push((left_char << 2) | (binary_char >> 4) as u8); |
| 239 | + left_char = (binary_char & 0x0f) as u8; |
| 240 | + } |
| 241 | + 2 => { |
| 242 | + quad_pos = 3; |
| 243 | + decoded.push((left_char << 4) | (binary_char >> 2) as u8); |
| 244 | + left_char = (binary_char & 0x03) as u8; |
| 245 | + } |
| 246 | + 3 => { |
| 247 | + quad_pos = 0; |
| 248 | + decoded.push((left_char << 6) | binary_char as u8); |
| 249 | + left_char = 0; |
| 250 | + } |
| 251 | + _ => unsafe { |
| 252 | + // quad_pos is only assigned in this match statement to constants |
| 253 | + std::hint::unreachable_unchecked() |
| 254 | + }, |
246 | 255 | } |
247 | 256 | } |
248 | 257 |
|
249 | | - return match quad_pos { |
| 258 | + match quad_pos { |
250 | 259 | 0 => Ok(decoded), |
251 | | - 1 => Err(base64::DecodeError::InvalidLastSymbol(decoded.len() / 3 * 4 + 1, 0)), |
252 | | - _ => Err(base64::DecodeError::InvalidLength) |
253 | | - }; |
254 | | - }) |
255 | | - .map_err(|err| { |
256 | | - let python_error = match err { |
257 | | - base64::DecodeError::InvalidByte(0, PAD) => { |
258 | | - String::from("Leading padding not allowed") |
259 | | - } |
260 | | - base64::DecodeError::InvalidByte(_, PAD) => { |
261 | | - String::from("Discontinuous padding not allowed") |
262 | | - } |
263 | | - base64::DecodeError::InvalidByte(_, _) => { |
264 | | - String::from("Only base64 data is allowed") |
265 | | - } |
266 | | - base64::DecodeError::InvalidLastSymbol(_, PAD) => { |
267 | | - String::from("Excess data after padding") |
268 | | - } |
269 | | - base64::DecodeError::InvalidLastSymbol(length, _) => { |
270 | | - format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length) |
271 | | - } |
272 | | - base64::DecodeError::InvalidLength => String::from("Incorrect padding"), |
273 | | - }; |
274 | | - |
275 | | - new_binascii_error(format!("error decoding base64: {python_error}"), vm) |
| 260 | + 1 => Err(base64::DecodeError::InvalidLastSymbol( |
| 261 | + decoded.len() / 3 * 4 + 1, |
| 262 | + 0, |
| 263 | + )), |
| 264 | + _ => Err(base64::DecodeError::InvalidLength), |
| 265 | + } |
276 | 266 | }) |
| 267 | + .map_err(|err| super::Base64DecodeError(err).to_pyexception(vm)) |
277 | 268 | } |
278 | 269 |
|
279 | 270 | #[pyfunction] |
@@ -738,3 +729,26 @@ mod decl { |
738 | 729 | }) |
739 | 730 | } |
740 | 731 | } |
| 732 | + |
| 733 | +struct Base64DecodeError(base64::DecodeError); |
| 734 | + |
| 735 | +fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef { |
| 736 | + vm.new_exception_msg(decl::error_type(vm), msg) |
| 737 | +} |
| 738 | + |
| 739 | +impl ToPyException for Base64DecodeError { |
| 740 | + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { |
| 741 | + use base64::DecodeError::*; |
| 742 | + let message = match self.0 { |
| 743 | + InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(), |
| 744 | + InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(), |
| 745 | + InvalidByte(_, _) => "Only base64 data is allowed".to_owned(), |
| 746 | + InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(), |
| 747 | + InvalidLastSymbol(length, _) => { |
| 748 | + format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length) |
| 749 | + } |
| 750 | + InvalidLength => "Incorrect padding".to_owned(), |
| 751 | + }; |
| 752 | + new_binascii_error(format!("error decoding base64: {message}"), vm) |
| 753 | + } |
| 754 | +} |
0 commit comments