diff --git a/.github/workflows/sha2.yml b/.github/workflows/sha2.yml index ad843edf0..f3ccbcd7a 100644 --- a/.github/workflows/sha2.yml +++ b/.github/workflows/sha2.yml @@ -86,7 +86,7 @@ jobs: RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" - run: cargo test --all-features env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft-compact" + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" --cfg sha2_backend_soft="compact" # macOS tests macos: @@ -112,7 +112,7 @@ jobs: RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" - run: cargo test --all-features env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft-compact" + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" --cfg sha2_backend_soft="compact" # Windows tests windows: @@ -139,7 +139,7 @@ jobs: RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" - run: cargo test --all-features env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft-compact" + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" --cfg sha2_backend_soft="compact" # Cross-compiled tests cross: @@ -186,16 +186,16 @@ jobs: toolchain: nightly - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft-compact" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" --cfg sha2_backend_soft="compact" - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" --cfg sha2_backend_riscv_zknh="compact" -C 
target-feature=+zknh,+zbkb riscv32-zknh: runs-on: ubuntu-latest @@ -208,16 +208,16 @@ jobs: components: rust-src - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft-compact" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" --cfg sha2_backend_soft="compact" - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" --cfg sha2_backend_riscv_zknh="compact" -C target-feature=+zknh,+zbkb # wasmtime tests wasm: diff --git a/sha2/Cargo.toml b/sha2/Cargo.toml index 0abfd4867..f75ce9a2a 100644 --- a/sha2/Cargo.toml +++ b/sha2/Cargo.toml @@ -35,7 +35,9 @@ oid = ["digest/oid"] [lints.rust.unexpected_cfgs] level = "warn" check-cfg = [ - 'cfg(sha2_backend, values("soft", "soft-compact", "riscv-zknh", "riscv-zknh-compact"))', + 'cfg(sha2_backend, values("aarch64-sha2", "aarch64-sha3", "soft", "riscv-zknh", "riscv-zknh-compact", "x86-avx2", "x86-shani"))', + 'cfg(sha2_backend_soft, values("compact"))', + 'cfg(sha2_backend_riscv_zknh, values("compact"))', ] [package.metadata.docs.rs] diff --git a/sha2/src/sha256.rs b/sha2/src/sha256.rs index 2dc5e15e2..386b4df08 100644 --- a/sha2/src/sha256.rs +++ b/sha2/src/sha256.rs @@ -2,31 +2,53 @@ cfg_if::cfg_if! 
{ if #[cfg(sha2_backend = "soft")] { mod soft; use soft::compress; - } else if #[cfg(sha2_backend = "soft-compact")] { - mod soft_compact; - use soft_compact::compress; - } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - mod soft; + } else if #[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + sha2_backend = "x86-shani", + ))] { + #[cfg(not(all( + target_feature = "sha", + target_feature = "sse2", + target_feature = "ssse3", + target_feature = "sse4.1", + )))] + compile_error!("x86-shani backend requires sha, sse2, ssse3, sse4.1 target features"); + mod x86_shani; - use x86_shani::compress; + + fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { x86_shani::compress(state, blocks) } + } } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), - sha2_backend = "riscv-zknh" + any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"), ))] { + #[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") + )))] + compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); + mod riscv_zknh; - mod riscv_zknh_utils; - use riscv_zknh::compress; + + fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { riscv_zknh::compress(state, blocks) } + } } else if #[cfg(all( - any(target_arch = "riscv32", target_arch = "riscv64"), - sha2_backend = "riscv-zknh-compact" + target_arch = "aarch64", + sha2_backend = "aarch64-sha2", ))] { - mod riscv_zknh_compact; - mod riscv_zknh_utils; - use riscv_zknh_compact::compress; - } else if #[cfg(target_arch = "aarch64")] { - mod soft; + #[cfg(not(target_feature = "sha2"))] + compile_error!("aarch64-sha2 backend requires sha2 target feature"); + mod aarch64_sha2; - use aarch64_sha2::compress; + + fn compress(state: &mut [u32; 8], blocks: &[[u8; 
64]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { aarch64_sha2::compress(state, blocks) } + } } else if #[cfg(target_arch = "loongarch64")] { mod loongarch64_asm; use loongarch64_asm::compress; @@ -35,17 +57,35 @@ cfg_if::cfg_if! { use wasm32_simd128::compress; } else { mod soft; - use soft::compress; - } -} -#[inline(always)] -#[allow(dead_code)] -fn to_u32s(block: &[u8; 64]) -> [u32; 16] { - core::array::from_fn(|i| { - let chunk = block[4 * i..][..4].try_into().unwrap(); - u32::from_be_bytes(chunk) - }) + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + mod x86_shani; + cpufeatures::new!(shani_cpuid, "sha", "sse2", "ssse3", "sse4.1"); + } else if #[cfg(target_arch = "aarch64")] { + mod aarch64_sha2; + cpufeatures::new!(sha2_hwcap, "sha2"); + } + } + + fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + if shani_cpuid::get() { + // SAFETY: we checked that required target features are available + return unsafe { x86_shani::compress(state, blocks) }; + } + } else if #[cfg(target_arch = "aarch64")] { + if sha2_hwcap::get() { + // SAFETY: we checked that `sha2` target feature is available + return unsafe { aarch64_sha2::compress(state, blocks) }; + } + } + } + + soft::compress(state, blocks); + } + } } /// Raw SHA-256 compression function. diff --git a/sha2/src/sha256/aarch64_sha2.rs b/sha2/src/sha256/aarch64_sha2.rs index 89d6f8cc4..51185dfbf 100644 --- a/sha2/src/sha256/aarch64_sha2.rs +++ b/sha2/src/sha256/aarch64_sha2.rs @@ -1,26 +1,13 @@ //! SHA-256 `aarch64` backend. +//! +//! Implementation adapted from mbedtls. #![allow(unsafe_op_in_unsafe_fn)] -// Implementation adapted from mbedtls. 
- -use core::arch::aarch64::*; - use crate::consts::K32; - -cpufeatures::new!(sha2_hwcap, "sha2"); - -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization - if sha2_hwcap::get() { - unsafe { sha256_compress(state, blocks) } - } else { - super::soft::compress(state, blocks); - } -} +use core::arch::aarch64::*; #[target_feature(enable = "sha2")] -unsafe fn sha256_compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { +pub(super) unsafe fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { // SAFETY: Requires the sha2 feature. // Load state into vectors. diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index acdadeed1..d52f85fb8 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -1,108 +1,16 @@ -use crate::consts::K32; +mod utils; #[cfg(target_arch = "riscv32")] -use core::arch::riscv32::*; +use core::arch::riscv32::{sha256sig0, sha256sig1, sha256sum0, sha256sum1}; #[cfg(target_arch = "riscv64")] -use core::arch::riscv64::*; - -#[cfg(not(all( - target_feature = "zknh", - any(target_feature = "zbb", target_feature = "zbkb") -)))] -compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); - -#[inline(always)] -fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (!x & z) -} - -#[inline(always)] -fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -fn round(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { - let n = K32.len() - R; - #[allow(clippy::identity_op)] - let a = (n + 0) % 8; - let b = (n + 1) % 8; - let c = (n + 2) % 8; - let d = (n + 3) % 8; - let e = (n + 4) % 8; - let f = (n + 5) % 8; - let g = (n + 6) % 8; - let h = (n + 7) % 8; - - state[h] = state[h] - .wrapping_add(unsafe { sha256sum1(state[e]) }) - .wrapping_add(ch(state[e], state[f], state[g])) - .wrapping_add(super::riscv_zknh_utils::opaque_load::(k)) - .wrapping_add(block[R]); - 
state[d] = state[d].wrapping_add(state[h]); - state[h] = state[h] - .wrapping_add(unsafe { sha256sum0(state[a]) }) - .wrapping_add(maj(state[a], state[b], state[c])) -} - -fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) { - round::(state, block, k); - - block[R] = block[R] - .wrapping_add(unsafe { sha256sig1(block[(R + 14) % 16]) }) - .wrapping_add(block[(R + 9) % 16]) - .wrapping_add(unsafe { sha256sig0(block[(R + 1) % 16]) }); -} - -#[inline(always)] -fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { - let s = &mut state.clone(); - let b = &mut block; - - for i in 0..3 { - let k = &K32[16 * i..]; - round_schedule::<0>(s, b, k); - round_schedule::<1>(s, b, k); - round_schedule::<2>(s, b, k); - round_schedule::<3>(s, b, k); - round_schedule::<4>(s, b, k); - round_schedule::<5>(s, b, k); - round_schedule::<6>(s, b, k); - round_schedule::<7>(s, b, k); - round_schedule::<8>(s, b, k); - round_schedule::<9>(s, b, k); - round_schedule::<10>(s, b, k); - round_schedule::<11>(s, b, k); - round_schedule::<12>(s, b, k); - round_schedule::<13>(s, b, k); - round_schedule::<14>(s, b, k); - round_schedule::<15>(s, b, k); - } - - let k = &K32[48..]; - round::<0>(s, b, k); - round::<1>(s, b, k); - round::<2>(s, b, k); - round::<3>(s, b, k); - round::<4>(s, b, k); - round::<5>(s, b, k); - round::<6>(s, b, k); - round::<7>(s, b, k); - round::<8>(s, b, k); - round::<9>(s, b, k); - round::<10>(s, b, k); - round::<11>(s, b, k); - round::<12>(s, b, k); - round::<13>(s, b, k); - round::<14>(s, b, k); - round::<15>(s, b, k); - - for i in 0..8 { - state[i] = state[i].wrapping_add(s[i]); - } -} - -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { - compress_block(state, block); +use core::arch::riscv64::{sha256sig0, sha256sig1, sha256sum0, sha256sum1}; + +cfg_if::cfg_if! 
{ + if #[cfg(sha2_backend_riscv_zknh = "compact")] { + mod compact; + pub(super) use compact::compress; + } else { + mod unroll; + pub(super) use unroll::compress; } } diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh/compact.rs similarity index 62% rename from sha2/src/sha256/riscv_zknh_compact.rs rename to sha2/src/sha256/riscv_zknh/compact.rs index bc49584b9..83840560b 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ b/sha2/src/sha256/riscv_zknh/compact.rs @@ -1,27 +1,38 @@ +use super::{sha256sig0, sha256sig1, sha256sum0, sha256sum1}; use crate::consts::K32; -#[cfg(target_arch = "riscv32")] -use core::arch::riscv32::*; -#[cfg(target_arch = "riscv64")] -use core::arch::riscv64::*; +#[target_feature(enable = "zknh")] +pub(in super::super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + for block in blocks.iter().map(super::utils::load_block) { + compress_block(state, block); + } +} -#[cfg(not(all( - target_feature = "zknh", - any(target_feature = "zbb", target_feature = "zbkb") -)))] -compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); +#[target_feature(enable = "zknh")] +fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { + let mut s = *state; -#[inline(always)] -fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (!x & z) + for r in 0..64 { + round(&mut s, &block, r); + if r < 48 { + schedule(&mut block, r) + } + } + + for i in 0..8 { + state[i] = state[i].wrapping_add(s[i]); + } } -#[inline(always)] -fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) +#[target_feature(enable = "zknh")] +fn schedule(block: &mut [u32; 16], r: usize) { + block[r % 16] = block[r % 16] + .wrapping_add(sha256sig1(block[(r + 14) % 16])) + .wrapping_add(block[(r + 9) % 16]) + .wrapping_add(sha256sig0(block[(r + 1) % 16])); } -#[inline(always)] +#[target_feature(enable = "zknh")] fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) { let n = K32.len() - r; 
#[allow(clippy::identity_op)] @@ -35,42 +46,22 @@ fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) { let h = (n + 7) % 8; state[h] = state[h] - .wrapping_add(unsafe { sha256sum1(state[e]) }) + .wrapping_add(sha256sum1(state[e])) .wrapping_add(ch(state[e], state[f], state[g])) .wrapping_add(K32[r]) .wrapping_add(block[r % 16]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] - .wrapping_add(unsafe { sha256sum0(state[a]) }) + .wrapping_add(sha256sum0(state[a])) .wrapping_add(maj(state[a], state[b], state[c])) } #[inline(always)] -fn schedule(block: &mut [u32; 16], r: usize) { - block[r % 16] = block[r % 16] - .wrapping_add(unsafe { sha256sig1(block[(r + 14) % 16]) }) - .wrapping_add(block[(r + 9) % 16]) - .wrapping_add(unsafe { sha256sig0(block[(r + 1) % 16]) }); +fn ch(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ (!x & z) } #[inline(always)] -fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { - let mut s = *state; - - for r in 0..64 { - round(&mut s, &block, r); - if r < 48 { - schedule(&mut block, r) - } - } - - for i in 0..8 { - state[i] = state[i].wrapping_add(s[i]); - } -} - -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { - compress_block(state, block); - } +fn maj(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ (x & z) ^ (y & z) } diff --git a/sha2/src/sha256/riscv_zknh/unroll.rs b/sha2/src/sha256/riscv_zknh/unroll.rs new file mode 100644 index 000000000..e8b702e58 --- /dev/null +++ b/sha2/src/sha256/riscv_zknh/unroll.rs @@ -0,0 +1,101 @@ +use super::{sha256sig0, sha256sig1, sha256sum0, sha256sum1}; +use crate::consts::K32; + +#[target_feature(enable = "zknh")] +pub(in super::super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + for block in blocks.iter().map(super::utils::load_block) { + compress_block(state, block); + } +} + +#[target_feature(enable = "zknh")] +fn compress_block(state: &mut [u32; 8], mut block: [u32; 
16]) { + let s = &mut state.clone(); + let b = &mut block; + + for i in 0..3 { + let k = &K32[16 * i..]; + round_schedule::<0>(s, b, k); + round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K32[48..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); + + for i in 0..8 { + state[i] = state[i].wrapping_add(s[i]); + } +} + +#[target_feature(enable = "zknh")] +fn round_schedule<const R: usize>(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) { + round::<R>(state, block, k); + + block[R] = block[R] + .wrapping_add(sha256sig1(block[(R + 14) % 16])) + .wrapping_add(block[(R + 9) % 16]) + .wrapping_add(sha256sig0(block[(R + 1) % 16])); +} + +#[target_feature(enable = "zknh")] +fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { + let n = K32.len() - R; + #[allow(clippy::identity_op)] + let a = (n + 0) % 8; + let b = (n + 1) % 8; + let c = (n + 2) % 8; + let d = (n + 3) % 8; + let e = (n + 4) % 8; + let f = (n + 5) % 8; + let g = (n + 6) % 8; + let h = (n + 7) % 8; + + state[h] = state[h] + .wrapping_add(sha256sum1(state[e])) + .wrapping_add(ch(state[e], state[f], state[g])) + .wrapping_add(super::utils::opaque_load::<R>(k)) + .wrapping_add(block[R]); + state[d] = 
state[d].wrapping_add(state[h]); + state[h] = state[h] 
+ .wrapping_add(sha256sum0(state[a])) + .wrapping_add(maj(state[a], state[b], state[c])) +} + +#[inline(always)] +fn ch(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ (!x & z) +} + +#[inline(always)] +fn maj(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ (x & z) ^ (y & z) +} diff --git a/sha2/src/sha256/riscv_zknh_utils.rs b/sha2/src/sha256/riscv_zknh/utils.rs similarity index 98% rename from sha2/src/sha256/riscv_zknh_utils.rs rename to sha2/src/sha256/riscv_zknh/utils.rs index d5c072679..2ec54977e 100644 --- a/sha2/src/sha256/riscv_zknh_utils.rs +++ b/sha2/src/sha256/riscv_zknh/utils.rs @@ -80,7 +80,7 @@ fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { } /// This function returns `k[R]`, but prevents compiler from inlining the indexed value -#[cfg(sha2_backend = "riscv-zknh")] +#[cfg(not(sha2_backend_riscv_zknh = "compact"))] pub(super) fn opaque_load<const R: usize>(k: &[u32]) -> u32 { assert!(R < k.len()); let dst; diff --git a/sha2/src/sha256/soft.rs b/sha2/src/sha256/soft.rs index 965aa216c..438ecff67 100644 --- a/sha2/src/sha256/soft.rs +++ b/sha2/src/sha256/soft.rs @@ -1,100 +1,16 @@ -use crate::consts::K32; - -#[rustfmt::skip] -macro_rules! 
repeat64 { - ($i:ident, $b:block) => { - let $i = 0; $b; let $i = 1; $b; let $i = 2; $b; let $i = 3; $b; - let $i = 4; $b; let $i = 5; $b; let $i = 6; $b; let $i = 7; $b; - let $i = 8; $b; let $i = 9; $b; let $i = 10; $b; let $i = 11; $b; - let $i = 12; $b; let $i = 13; $b; let $i = 14; $b; let $i = 15; $b; - let $i = 16; $b; let $i = 17; $b; let $i = 18; $b; let $i = 19; $b; - let $i = 20; $b; let $i = 21; $b; let $i = 22; $b; let $i = 23; $b; - let $i = 24; $b; let $i = 25; $b; let $i = 26; $b; let $i = 27; $b; - let $i = 28; $b; let $i = 29; $b; let $i = 30; $b; let $i = 31; $b; - let $i = 32; $b; let $i = 33; $b; let $i = 34; $b; let $i = 35; $b; - let $i = 36; $b; let $i = 37; $b; let $i = 38; $b; let $i = 39; $b; - let $i = 40; $b; let $i = 41; $b; let $i = 42; $b; let $i = 43; $b; - let $i = 44; $b; let $i = 45; $b; let $i = 46; $b; let $i = 47; $b; - let $i = 48; $b; let $i = 49; $b; let $i = 50; $b; let $i = 51; $b; - let $i = 52; $b; let $i = 53; $b; let $i = 54; $b; let $i = 55; $b; - let $i = 56; $b; let $i = 57; $b; let $i = 58; $b; let $i = 59; $b; - let $i = 60; $b; let $i = 61; $b; let $i = 62; $b; let $i = 63; $b; - }; -} - -/// Read round constant -fn rk(i: usize) -> u32 { - // `read_volatile` forces compiler to read round constants from the static - // instead of inlining them, which improves codegen and performance on some platforms. - // On x86 targets 32-bit constants can be encoded using immediate argument on the `add` - // instruction, so it's more efficient to inline them. - cfg_if::cfg_if! { - if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - use core::ptr::read as r; - } else { - use core::ptr::read_volatile as r; - } +cfg_if::cfg_if! { + if #[cfg(sha2_backend_soft = "compact")] { + mod compact; + pub(super) use compact::compress; + } else { + mod unroll; + pub(super) use unroll::compress; } - - unsafe { - let p = K32.as_ptr().add(i); - r(p) - } -} - -/// Process a block with the SHA-256 algorithm. 
-fn compress_block(state: &mut [u32; 8], block: &[u8; 64]) { - let mut block = super::to_u32s(block); - let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state; - - repeat64!(i, { - let w = if i < 16 { - block[i] - } else { - let w15 = block[(i - 15) % 16]; - let s0 = (w15.rotate_right(7)) ^ (w15.rotate_right(18)) ^ (w15 >> 3); - let w2 = block[(i - 2) % 16]; - let s1 = (w2.rotate_right(17)) ^ (w2.rotate_right(19)) ^ (w2 >> 10); - block[i % 16] = block[i % 16] - .wrapping_add(s0) - .wrapping_add(block[(i - 7) % 16]) - .wrapping_add(s1); - block[i % 16] - }; - - let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); - let ch = (e & f) ^ ((!e) & g); - let t1 = s1 - .wrapping_add(ch) - .wrapping_add(rk(i)) - .wrapping_add(w) - .wrapping_add(h); - let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); - let maj = (a & b) ^ (a & c) ^ (b & c); - let t2 = s0.wrapping_add(maj); - - h = g; - g = f; - f = e; - e = d.wrapping_add(t1); - d = c; - c = b; - b = a; - a = t1.wrapping_add(t2); - }); - - state[0] = state[0].wrapping_add(a); - state[1] = state[1].wrapping_add(b); - state[2] = state[2].wrapping_add(c); - state[3] = state[3].wrapping_add(d); - state[4] = state[4].wrapping_add(e); - state[5] = state[5].wrapping_add(f); - state[6] = state[6].wrapping_add(g); - state[7] = state[7].wrapping_add(h); } -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks { - compress_block(state, block); - } +fn to_u32s(block: &[u8; 64]) -> [u32; 16] { + core::array::from_fn(|i| { + let chunk = block[4 * i..][..4].try_into().unwrap(); + u32::from_be_bytes(chunk) + }) } diff --git a/sha2/src/sha256/soft_compact.rs b/sha2/src/sha256/soft/compact.rs similarity index 95% rename from sha2/src/sha256/soft_compact.rs rename to sha2/src/sha256/soft/compact.rs index ef1793f7a..d82959a7a 100644 --- a/sha2/src/sha256/soft_compact.rs +++ b/sha2/src/sha256/soft/compact.rs @@ -50,7 +50,7 @@ fn compress_u32(state: &mut [u32; 
8], mut block: [u32; 16]) { state[7] = state[7].wrapping_add(h); } -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { +pub(in super::super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { for block in blocks.iter() { compress_u32(state, super::to_u32s(block)); } diff --git a/sha2/src/sha256/soft/unroll.rs b/sha2/src/sha256/soft/unroll.rs new file mode 100644 index 000000000..83eed2f3a --- /dev/null +++ b/sha2/src/sha256/soft/unroll.rs @@ -0,0 +1,100 @@ +use crate::consts::K32; + +#[rustfmt::skip] +macro_rules! repeat64 { + ($i:ident, $b:block) => { + let $i = 0; $b; let $i = 1; $b; let $i = 2; $b; let $i = 3; $b; + let $i = 4; $b; let $i = 5; $b; let $i = 6; $b; let $i = 7; $b; + let $i = 8; $b; let $i = 9; $b; let $i = 10; $b; let $i = 11; $b; + let $i = 12; $b; let $i = 13; $b; let $i = 14; $b; let $i = 15; $b; + let $i = 16; $b; let $i = 17; $b; let $i = 18; $b; let $i = 19; $b; + let $i = 20; $b; let $i = 21; $b; let $i = 22; $b; let $i = 23; $b; + let $i = 24; $b; let $i = 25; $b; let $i = 26; $b; let $i = 27; $b; + let $i = 28; $b; let $i = 29; $b; let $i = 30; $b; let $i = 31; $b; + let $i = 32; $b; let $i = 33; $b; let $i = 34; $b; let $i = 35; $b; + let $i = 36; $b; let $i = 37; $b; let $i = 38; $b; let $i = 39; $b; + let $i = 40; $b; let $i = 41; $b; let $i = 42; $b; let $i = 43; $b; + let $i = 44; $b; let $i = 45; $b; let $i = 46; $b; let $i = 47; $b; + let $i = 48; $b; let $i = 49; $b; let $i = 50; $b; let $i = 51; $b; + let $i = 52; $b; let $i = 53; $b; let $i = 54; $b; let $i = 55; $b; + let $i = 56; $b; let $i = 57; $b; let $i = 58; $b; let $i = 59; $b; + let $i = 60; $b; let $i = 61; $b; let $i = 62; $b; let $i = 63; $b; + }; +} + +/// Read round constant +fn rk(i: usize) -> u32 { + // `read_volatile` forces compiler to read round constants from the static + // instead of inlining them, which improves codegen and performance on some platforms. 
+ // On x86 targets 32-bit constants can be encoded using immediate argument on the `add` + // instruction, so it's more efficient to inline them. + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + use core::ptr::read as r; + } else { + use core::ptr::read_volatile as r; + } + } + + unsafe { + let p = K32.as_ptr().add(i); + r(p) + } +} + +/// Process a block with the SHA-256 algorithm. +fn compress_block(state: &mut [u32; 8], block: &[u8; 64]) { + let mut block = super::to_u32s(block); + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state; + + repeat64!(i, { + let w = if i < 16 { + block[i] + } else { + let w15 = block[(i - 15) % 16]; + let s0 = (w15.rotate_right(7)) ^ (w15.rotate_right(18)) ^ (w15 >> 3); + let w2 = block[(i - 2) % 16]; + let s1 = (w2.rotate_right(17)) ^ (w2.rotate_right(19)) ^ (w2 >> 10); + block[i % 16] = block[i % 16] + .wrapping_add(s0) + .wrapping_add(block[(i - 7) % 16]) + .wrapping_add(s1); + block[i % 16] + }; + + let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); + let ch = (e & f) ^ ((!e) & g); + let t1 = s1 + .wrapping_add(ch) + .wrapping_add(rk(i)) + .wrapping_add(w) + .wrapping_add(h); + let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); + let maj = (a & b) ^ (a & c) ^ (b & c); + let t2 = s0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + }); + + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +pub(in super::super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + for block in blocks { + compress_block(state, block); + } +} diff --git a/sha2/src/sha256/x86_shani.rs b/sha2/src/sha256/x86_shani.rs 
index d40661d5c..55155dc6b 100644 --- a/sha2/src/sha256/x86_shani.rs +++ b/sha2/src/sha256/x86_shani.rs @@ -7,6 +7,7 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; +#[target_feature(enable = "sha,sse2,ssse3,sse4.1")] unsafe fn schedule(v0: __m128i, v1: __m128i, v2: __m128i, v3: __m128i) -> __m128i { let t1 = _mm_sha256msg1_epu32(v0, v1); let t2 = _mm_alignr_epi8(v3, v2, 4); @@ -39,7 +40,7 @@ macro_rules! schedule_rounds4 { // we use unaligned loads with `__m128i` pointers #[allow(clippy::cast_ptr_alignment)] #[target_feature(enable = "sha,sse2,ssse3,sse4.1")] -unsafe fn digest_blocks(state: &mut [u32; 8], blocks: &[[u8; 64]]) { +pub(super) unsafe fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { #[allow(non_snake_case)] let MASK: __m128i = _mm_set_epi64x( 0x0C0D_0E0F_0809_0A0Bu64 as i64, @@ -96,17 +97,3 @@ unsafe fn digest_blocks(state: &mut [u32; 8], blocks: &[[u8; 64]]) { _mm_storeu_si128(state_ptr_mut.add(0), dcba); _mm_storeu_si128(state_ptr_mut.add(1), hgef); } - -cpufeatures::new!(shani_cpuid, "sha", "sse2", "ssse3", "sse4.1"); - -pub(super) fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization - if shani_cpuid::get() { - unsafe { - digest_blocks(state, blocks); - } - } else { - super::soft::compress(state, blocks); - } -} diff --git a/sha2/src/sha512.rs b/sha2/src/sha512.rs index a29950fda..92b02a76b 100644 --- a/sha2/src/sha512.rs +++ b/sha2/src/sha512.rs @@ -2,31 +2,48 @@ cfg_if::cfg_if! 
{ if #[cfg(sha2_backend = "soft")] { mod soft; use soft::compress; - } else if #[cfg(sha2_backend = "soft-compact")] { - mod soft_compact; - use soft_compact::compress; - } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - mod soft; + } else if #[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + sha2_backend = "x86-avx2", + ))] { + #[cfg(not(target_feature = "avx2"))] + compile_error!("x86-avx2 backend requires avx2 target feature"); + mod x86_avx2; - use x86_avx2::compress; + + fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { x86_avx2::compress(state, blocks) } + } } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), sha2_backend = "riscv-zknh" ))] { + #[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") + )))] + compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); + mod riscv_zknh; - mod riscv_zknh_utils; - use riscv_zknh::compress; + + fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { riscv_zknh::compress(state, blocks) } + } } else if #[cfg(all( - any(target_arch = "riscv32", target_arch = "riscv64"), - sha2_backend = "riscv-zknh-compact" - ))] { - mod riscv_zknh_compact; - mod riscv_zknh_utils; - use riscv_zknh_compact::compress; - } else if #[cfg(target_arch = "aarch64")] { - mod soft; - mod aarch64_sha2; - use aarch64_sha2::compress; + target_arch = "aarch64", + sha2_backend = "aarch64-sha3", + ))] { + #[cfg(not(target_feature = "sha3"))] + compile_error!("aarch64-sha3 backend requires sha3 target feature"); + + mod aarch64_sha3; + + fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + // SAFETY: we checked above that the required target features are enabled + unsafe { aarch64_sha3::compress(state, blocks) } + } } else if 
#[cfg(target_arch = "loongarch64")] { mod loongarch64_asm; use loongarch64_asm::compress; @@ -35,17 +52,35 @@ cfg_if::cfg_if! { use wasm32_simd128::compress; } else { mod soft; - use soft::compress; - } -} -#[inline(always)] -#[allow(dead_code)] -fn to_u64s(block: &[u8; 128]) -> [u64; 16] { - core::array::from_fn(|i| { - let chunk = block[8 * i..][..8].try_into().unwrap(); - u64::from_be_bytes(chunk) - }) + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + mod x86_avx2; + cpufeatures::new!(avx2_cpuid, "avx2"); + } else if #[cfg(target_arch = "aarch64")] { + mod aarch64_sha3; + cpufeatures::new!(sha3_hwcap, "sha3"); + } + } + + fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { + if avx2_cpuid::get() { + // SAFETY: we checked that required target features are available + return unsafe { x86_avx2::compress(state, blocks) }; + } + } else if #[cfg(target_arch = "aarch64")] { + if sha3_hwcap::get() { + // SAFETY: we checked that `sha3` target feature is available + return unsafe { aarch64_sha3::compress(state, blocks) }; + } + } + } + + soft::compress(state, blocks); + } + } } /// Raw SHA-512 compression function. diff --git a/sha2/src/sha512/aarch64_sha2.rs b/sha2/src/sha512/aarch64_sha3.rs similarity index 95% rename from sha2/src/sha512/aarch64_sha2.rs rename to sha2/src/sha512/aarch64_sha3.rs index 5fb2ad0e4..968a2cd02 100644 --- a/sha2/src/sha512/aarch64_sha2.rs +++ b/sha2/src/sha512/aarch64_sha3.rs @@ -1,24 +1,11 @@ // Implementation adapted from mbedtls. 
#![allow(unsafe_op_in_unsafe_fn)] -use core::arch::aarch64::*; - use crate::consts::K64; - -cpufeatures::new!(sha3_hwcap, "sha3"); - -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization - if sha3_hwcap::get() { - unsafe { sha512_compress(state, blocks) } - } else { - super::soft::compress(state, blocks); - } -} +use core::arch::aarch64::*; #[target_feature(enable = "sha3")] -unsafe fn sha512_compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { +pub(super) unsafe fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { // SAFETY: Requires the sha3 feature. // Load state into vectors. diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index d30c92d9c..c7689aabf 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -1,139 +1,49 @@ -use crate::consts::K64; +mod utils; + +cfg_if::cfg_if! { + if #[cfg(sha2_backend_riscv_zknh = "compact")] { + mod compact; + pub(super) use compact::compress; + } else { + mod unroll; + pub(super) use unroll::compress; + } +} -#[cfg(target_arch = "riscv32")] -use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] -use core::arch::riscv64::*; +use core::arch::riscv64::{sha512sig0, sha512sig1, sha512sum0, sha512sum1}; -#[cfg(not(all( - target_feature = "zknh", - any(target_feature = "zbb", target_feature = "zbkb") -)))] -compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); +#[cfg(target_arch = "riscv32")] +use core::arch::riscv32::*; #[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sum0(x: u64) -> u64 { +#[target_feature(enable = "zknh")] +fn sha512sum0(x: u64) -> u64 { let a = sha512sum0r((x >> 32) as u32, x as u32); let b = sha512sum0r(x as u32, (x >> 32) as u32); ((a as u64) << 32) | (b as u64) } #[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sum1(x: u64) -> u64 { 
+#[target_feature(enable = "zknh")] +fn sha512sum1(x: u64) -> u64 { let a = sha512sum1r((x >> 32) as u32, x as u32); let b = sha512sum1r(x as u32, (x >> 32) as u32); ((a as u64) << 32) | (b as u64) } #[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sig0(x: u64) -> u64 { +#[target_feature(enable = "zknh")] +fn sha512sig0(x: u64) -> u64 { let a = sha512sig0h((x >> 32) as u32, x as u32); let b = sha512sig0l(x as u32, (x >> 32) as u32); ((a as u64) << 32) | (b as u64) } #[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sig1(x: u64) -> u64 { +#[target_feature(enable = "zknh")] +fn sha512sig1(x: u64) -> u64 { let a = sha512sig1h((x >> 32) as u32, x as u32); let b = sha512sig1l(x as u32, (x >> 32) as u32); ((a as u64) << 32) | (b as u64) } - -#[inline(always)] -fn ch(x: u64, y: u64, z: u64) -> u64 { - (x & y) ^ (!x & z) -} - -#[inline(always)] -fn maj(x: u64, y: u64, z: u64) -> u64 { - (x & y) ^ (x & z) ^ (y & z) -} - -fn round(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { - let n = K64.len() - R; - #[allow(clippy::identity_op)] - let a = (n + 0) % 8; - let b = (n + 1) % 8; - let c = (n + 2) % 8; - let d = (n + 3) % 8; - let e = (n + 4) % 8; - let f = (n + 5) % 8; - let g = (n + 6) % 8; - let h = (n + 7) % 8; - - state[h] = state[h] - .wrapping_add(unsafe { sha512sum1(state[e]) }) - .wrapping_add(ch(state[e], state[f], state[g])) - .wrapping_add(super::riscv_zknh_utils::opaque_load::(k)) - .wrapping_add(block[R]); - state[d] = state[d].wrapping_add(state[h]); - state[h] = state[h] - .wrapping_add(unsafe { sha512sum0(state[a]) }) - .wrapping_add(maj(state[a], state[b], state[c])) -} - -fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], k: &[u64]) { - round::(state, block, k); - - block[R] = block[R] - .wrapping_add(unsafe { sha512sig1(block[(R + 14) % 16]) }) - .wrapping_add(block[(R + 9) % 16]) - .wrapping_add(unsafe { sha512sig0(block[(R + 1) % 16]) }); -} - -fn compress_block(state: 
&mut [u64; 8], mut block: [u64; 16]) { - let s = &mut state.clone(); - let b = &mut block; - - for i in 0..4 { - let k = &K64[16 * i..]; - round_schedule::<0>(s, b, k); - round_schedule::<1>(s, b, k); - round_schedule::<2>(s, b, k); - round_schedule::<3>(s, b, k); - round_schedule::<4>(s, b, k); - round_schedule::<5>(s, b, k); - round_schedule::<6>(s, b, k); - round_schedule::<7>(s, b, k); - round_schedule::<8>(s, b, k); - round_schedule::<9>(s, b, k); - round_schedule::<10>(s, b, k); - round_schedule::<11>(s, b, k); - round_schedule::<12>(s, b, k); - round_schedule::<13>(s, b, k); - round_schedule::<14>(s, b, k); - round_schedule::<15>(s, b, k); - } - - let k = &K64[64..]; - round::<0>(s, b, k); - round::<1>(s, b, k); - round::<2>(s, b, k); - round::<3>(s, b, k); - round::<4>(s, b, k); - round::<5>(s, b, k); - round::<6>(s, b, k); - round::<7>(s, b, k); - round::<8>(s, b, k); - round::<9>(s, b, k); - round::<10>(s, b, k); - round::<11>(s, b, k); - round::<12>(s, b, k); - round::<13>(s, b, k); - round::<14>(s, b, k); - round::<15>(s, b, k); - - for i in 0..8 { - state[i] = state[i].wrapping_add(s[i]); - } -} - -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { - compress_block(state, block); - } -} diff --git a/sha2/src/sha512/riscv_zknh/compact.rs b/sha2/src/sha512/riscv_zknh/compact.rs new file mode 100644 index 000000000..865288ee4 --- /dev/null +++ b/sha2/src/sha512/riscv_zknh/compact.rs @@ -0,0 +1,67 @@ +use super::{sha512sig0, sha512sig1, sha512sum0, sha512sum1}; +use crate::consts::K64; + +#[target_feature(enable = "zknh")] +pub(in super::super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + for block in blocks.iter().map(super::utils::load_block) { + compress_block(state, block); + } +} + +#[target_feature(enable = "zknh")] +fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { + let mut s = *state; + + for r in 0..80 { + round(&mut s, &block, 
r); + if r < 64 { + schedule(&mut block, r) + } + } + + for i in 0..8 { + state[i] = state[i].wrapping_add(s[i]); + } +} + +#[target_feature(enable = "zknh")] +fn schedule(block: &mut [u64; 16], r: usize) { + block[r % 16] = block[r % 16] + .wrapping_add(sha512sig1(block[(r + 14) % 16])) + .wrapping_add(block[(r + 9) % 16]) + .wrapping_add(sha512sig0(block[(r + 1) % 16])); +} + +#[target_feature(enable = "zknh")] +fn round(state: &mut [u64; 8], block: &[u64; 16], r: usize) { + let n = K64.len() - r; + #[allow(clippy::identity_op)] + let a = (n + 0) % 8; + let b = (n + 1) % 8; + let c = (n + 2) % 8; + let d = (n + 3) % 8; + let e = (n + 4) % 8; + let f = (n + 5) % 8; + let g = (n + 6) % 8; + let h = (n + 7) % 8; + + state[h] = state[h] + .wrapping_add(sha512sum1(state[e])) + .wrapping_add(ch(state[e], state[f], state[g])) + .wrapping_add(K64[r]) + .wrapping_add(block[r % 16]); + state[d] = state[d].wrapping_add(state[h]); + state[h] = state[h] + .wrapping_add(sha512sum0(state[a])) + .wrapping_add(maj(state[a], state[b], state[c])) +} + +#[inline(always)] +fn ch(x: u64, y: u64, z: u64) -> u64 { + (x & y) ^ (!x & z) +} + +#[inline(always)] +fn maj(x: u64, y: u64, z: u64) -> u64 { + (x & y) ^ (x & z) ^ (y & z) +} diff --git a/sha2/src/sha512/riscv_zknh/unroll.rs b/sha2/src/sha512/riscv_zknh/unroll.rs new file mode 100644 index 000000000..9f21fa221 --- /dev/null +++ b/sha2/src/sha512/riscv_zknh/unroll.rs @@ -0,0 +1,101 @@ +use super::{sha512sig0, sha512sig1, sha512sum0, sha512sum1}; +use crate::consts::K64; + +#[target_feature(enable = "zknh")] +pub(in super::super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + for block in blocks.iter().map(super::utils::load_block) { + compress_block(state, block); + } +} + +#[target_feature(enable = "zknh")] +fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { + let s = &mut state.clone(); + let b = &mut block; + + for i in 0..4 { + let k = &K64[16 * i..]; + round_schedule::<0>(s, b, k); + 
round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K64[64..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); + + for i in 0..8 { + state[i] = state[i].wrapping_add(s[i]); + } +} + +#[target_feature(enable = "zknh")] +fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], k: &[u64]) { + round::(state, block, k); + + block[R] = block[R] + .wrapping_add(sha512sig1(block[(R + 14) % 16])) + .wrapping_add(block[(R + 9) % 16]) + .wrapping_add(sha512sig0(block[(R + 1) % 16])); +} + +#[target_feature(enable = "zknh")] +fn round(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { + let n = K64.len() - R; + #[allow(clippy::identity_op)] + let a = (n + 0) % 8; + let b = (n + 1) % 8; + let c = (n + 2) % 8; + let d = (n + 3) % 8; + let e = (n + 4) % 8; + let f = (n + 5) % 8; + let g = (n + 6) % 8; + let h = (n + 7) % 8; + + state[h] = state[h] + .wrapping_add(sha512sum1(state[e])) + .wrapping_add(ch(state[e], state[f], state[g])) + .wrapping_add(super::utils::opaque_load::(k)) + .wrapping_add(block[R]); + state[d] = state[d].wrapping_add(state[h]); + state[h] = state[h] + .wrapping_add(sha512sum0(state[a])) + .wrapping_add(maj(state[a], state[b], state[c])) +} + +#[inline(always)] +fn ch(x: u64, y: u64, z: 
u64) -> u64 { + (x & y) ^ (!x & z) +} + +#[inline(always)] +fn maj(x: u64, y: u64, z: u64) -> u64 { + (x & y) ^ (x & z) ^ (y & z) +} diff --git a/sha2/src/sha512/riscv_zknh_utils.rs b/sha2/src/sha512/riscv_zknh/utils.rs similarity index 98% rename from sha2/src/sha512/riscv_zknh_utils.rs rename to sha2/src/sha512/riscv_zknh/utils.rs index 41197d119..440682ea7 100644 --- a/sha2/src/sha512/riscv_zknh_utils.rs +++ b/sha2/src/sha512/riscv_zknh/utils.rs @@ -129,8 +129,10 @@ fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { } /// This function returns `k[R]`, but prevents compiler from inlining the indexed value -#[cfg(sha2_backend = "riscv-zknh")] +#[cfg(not(sha2_backend_riscv_zknh = "compact"))] pub(super) fn opaque_load(k: &[u64]) -> u64 { + use core::arch::asm; + assert!(R < k.len()); #[cfg(target_arch = "riscv64")] unsafe { diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs deleted file mode 100644 index 840b80ff0..000000000 --- a/sha2/src/sha512/riscv_zknh_compact.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::consts::K64; - -#[cfg(target_arch = "riscv32")] -use core::arch::riscv32::*; -#[cfg(target_arch = "riscv64")] -use core::arch::riscv64::*; - -#[cfg(not(all( - target_feature = "zknh", - any(target_feature = "zbb", target_feature = "zbkb") -)))] -compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); - -#[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sum0(x: u64) -> u64 { - let a = sha512sum0r((x >> 32) as u32, x as u32); - let b = sha512sum0r(x as u32, (x >> 32) as u32); - ((a as u64) << 32) | (b as u64) -} - -#[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sum1(x: u64) -> u64 { - let a = sha512sum1r((x >> 32) as u32, x as u32); - let b = sha512sum1r(x as u32, (x >> 32) as u32); - ((a as u64) << 32) | (b as u64) -} - -#[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn 
sha512sig0(x: u64) -> u64 { - let a = sha512sig0h((x >> 32) as u32, x as u32); - let b = sha512sig0l(x as u32, (x >> 32) as u32); - ((a as u64) << 32) | (b as u64) -} - -#[cfg(target_arch = "riscv32")] -#[allow(unsafe_op_in_unsafe_fn)] -unsafe fn sha512sig1(x: u64) -> u64 { - let a = sha512sig1h((x >> 32) as u32, x as u32); - let b = sha512sig1l(x as u32, (x >> 32) as u32); - ((a as u64) << 32) | (b as u64) -} - -#[inline(always)] -fn ch(x: u64, y: u64, z: u64) -> u64 { - (x & y) ^ (!x & z) -} - -#[inline(always)] -fn maj(x: u64, y: u64, z: u64) -> u64 { - (x & y) ^ (x & z) ^ (y & z) -} - -#[inline(always)] -fn round(state: &mut [u64; 8], block: &[u64; 16], r: usize) { - let n = K64.len() - r; - #[allow(clippy::identity_op)] - let a = (n + 0) % 8; - let b = (n + 1) % 8; - let c = (n + 2) % 8; - let d = (n + 3) % 8; - let e = (n + 4) % 8; - let f = (n + 5) % 8; - let g = (n + 6) % 8; - let h = (n + 7) % 8; - - state[h] = state[h] - .wrapping_add(unsafe { sha512sum1(state[e]) }) - .wrapping_add(ch(state[e], state[f], state[g])) - .wrapping_add(K64[r]) - .wrapping_add(block[r % 16]); - state[d] = state[d].wrapping_add(state[h]); - state[h] = state[h] - .wrapping_add(unsafe { sha512sum0(state[a]) }) - .wrapping_add(maj(state[a], state[b], state[c])) -} - -#[inline(always)] -fn schedule(block: &mut [u64; 16], r: usize) { - block[r % 16] = block[r % 16] - .wrapping_add(unsafe { sha512sig1(block[(r + 14) % 16]) }) - .wrapping_add(block[(r + 9) % 16]) - .wrapping_add(unsafe { sha512sig0(block[(r + 1) % 16]) }); -} - -#[inline(always)] -fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { - let mut s = *state; - - for r in 0..80 { - round(&mut s, &block, r); - if r < 64 { - schedule(&mut block, r) - } - } - - for i in 0..8 { - state[i] = state[i].wrapping_add(s[i]); - } -} - -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { - compress_block(state, block); - } -} diff --git 
a/sha2/src/sha512/soft.rs b/sha2/src/sha512/soft.rs index 8a75a8b44..713ac84ab 100644 --- a/sha2/src/sha512/soft.rs +++ b/sha2/src/sha512/soft.rs @@ -1,94 +1,17 @@ -use crate::consts::K64; - -#[rustfmt::skip] -macro_rules! repeat80 { - ($i:ident, $b:block) => { - let $i = 0; $b; let $i = 1; $b; let $i = 2; $b; let $i = 3; $b; - let $i = 4; $b; let $i = 5; $b; let $i = 6; $b; let $i = 7; $b; - let $i = 8; $b; let $i = 9; $b; let $i = 10; $b; let $i = 11; $b; - let $i = 12; $b; let $i = 13; $b; let $i = 14; $b; let $i = 15; $b; - let $i = 16; $b; let $i = 17; $b; let $i = 18; $b; let $i = 19; $b; - let $i = 20; $b; let $i = 21; $b; let $i = 22; $b; let $i = 23; $b; - let $i = 24; $b; let $i = 25; $b; let $i = 26; $b; let $i = 27; $b; - let $i = 28; $b; let $i = 29; $b; let $i = 30; $b; let $i = 31; $b; - let $i = 32; $b; let $i = 33; $b; let $i = 34; $b; let $i = 35; $b; - let $i = 36; $b; let $i = 37; $b; let $i = 38; $b; let $i = 39; $b; - let $i = 40; $b; let $i = 41; $b; let $i = 42; $b; let $i = 43; $b; - let $i = 44; $b; let $i = 45; $b; let $i = 46; $b; let $i = 47; $b; - let $i = 48; $b; let $i = 49; $b; let $i = 50; $b; let $i = 51; $b; - let $i = 52; $b; let $i = 53; $b; let $i = 54; $b; let $i = 55; $b; - let $i = 56; $b; let $i = 57; $b; let $i = 58; $b; let $i = 59; $b; - let $i = 60; $b; let $i = 61; $b; let $i = 62; $b; let $i = 63; $b; - let $i = 64; $b; let $i = 65; $b; let $i = 66; $b; let $i = 67; $b; - let $i = 68; $b; let $i = 69; $b; let $i = 70; $b; let $i = 71; $b; - let $i = 72; $b; let $i = 73; $b; let $i = 74; $b; let $i = 75; $b; - let $i = 76; $b; let $i = 77; $b; let $i = 78; $b; let $i = 79; $b; - }; -} - -/// Read round constant -fn rk(i: usize) -> u64 { - // `read_volatile` forces the compiler to read round constants from the static - // instead of inlining them, which improves codegen and performance - unsafe { - let p = K64.as_ptr().add(i); - core::ptr::read_volatile(p) +cfg_if::cfg_if! 
{ + if #[cfg(sha2_backend_soft = "compact")] { + mod compact; + pub(super) use compact::compress; + } else { + mod unroll; + pub(super) use unroll::compress; } } -/// Process a block with the SHA-512 algorithm. -fn compress_block(state: &mut [u64; 8], block: &[u8; 128]) { - let mut block = super::to_u64s(block); - let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state; - - repeat80!(i, { - let w = if i < 16 { - block[i] - } else { - let w15 = block[(i - 15) % 16]; - let s0 = (w15.rotate_right(1)) ^ (w15.rotate_right(8)) ^ (w15 >> 7); - let w2 = block[(i - 2) % 16]; - let s1 = (w2.rotate_right(19)) ^ (w2.rotate_right(61)) ^ (w2 >> 6); - block[i % 16] = block[i % 16] - .wrapping_add(s0) - .wrapping_add(block[(i - 7) % 16]) - .wrapping_add(s1); - block[i % 16] - }; - - let s1 = e.rotate_right(14) ^ e.rotate_right(18) ^ e.rotate_right(41); - let ch = (e & f) ^ ((!e) & g); - let t1 = s1 - .wrapping_add(ch) - .wrapping_add(rk(i)) - .wrapping_add(w) - .wrapping_add(h); - let s0 = a.rotate_right(28) ^ a.rotate_right(34) ^ a.rotate_right(39); - let maj = (a & b) ^ (a & c) ^ (b & c); - let t2 = s0.wrapping_add(maj); - - h = g; - g = f; - f = e; - e = d.wrapping_add(t1); - d = c; - c = b; - b = a; - a = t1.wrapping_add(t2); - }); - - state[0] = state[0].wrapping_add(a); - state[1] = state[1].wrapping_add(b); - state[2] = state[2].wrapping_add(c); - state[3] = state[3].wrapping_add(d); - state[4] = state[4].wrapping_add(e); - state[5] = state[5].wrapping_add(f); - state[6] = state[6].wrapping_add(g); - state[7] = state[7].wrapping_add(h); -} - -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks { - compress_block(state, block); - } +#[inline(always)] +fn to_u64s(block: &[u8; 128]) -> [u64; 16] { + core::array::from_fn(|i| { + let chunk = block[8 * i..][..8].try_into().unwrap(); + u64::from_be_bytes(chunk) + }) } diff --git a/sha2/src/sha512/soft_compact.rs b/sha2/src/sha512/soft/compact.rs similarity index 95% rename from 
sha2/src/sha512/soft_compact.rs rename to sha2/src/sha512/soft/compact.rs index 7ba83fa3b..e7cc82372 100644 --- a/sha2/src/sha512/soft_compact.rs +++ b/sha2/src/sha512/soft/compact.rs @@ -49,7 +49,7 @@ fn compress_u64(state: &mut [u64; 8], block: [u64; 16]) { state[7] = state[7].wrapping_add(h); } -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { +pub(in super::super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { for block in blocks.iter() { compress_u64(state, super::to_u64s(block)); } diff --git a/sha2/src/sha512/soft/unroll.rs b/sha2/src/sha512/soft/unroll.rs new file mode 100644 index 000000000..77a9db71c --- /dev/null +++ b/sha2/src/sha512/soft/unroll.rs @@ -0,0 +1,94 @@ +use crate::consts::K64; + +#[rustfmt::skip] +macro_rules! repeat80 { + ($i:ident, $b:block) => { + let $i = 0; $b; let $i = 1; $b; let $i = 2; $b; let $i = 3; $b; + let $i = 4; $b; let $i = 5; $b; let $i = 6; $b; let $i = 7; $b; + let $i = 8; $b; let $i = 9; $b; let $i = 10; $b; let $i = 11; $b; + let $i = 12; $b; let $i = 13; $b; let $i = 14; $b; let $i = 15; $b; + let $i = 16; $b; let $i = 17; $b; let $i = 18; $b; let $i = 19; $b; + let $i = 20; $b; let $i = 21; $b; let $i = 22; $b; let $i = 23; $b; + let $i = 24; $b; let $i = 25; $b; let $i = 26; $b; let $i = 27; $b; + let $i = 28; $b; let $i = 29; $b; let $i = 30; $b; let $i = 31; $b; + let $i = 32; $b; let $i = 33; $b; let $i = 34; $b; let $i = 35; $b; + let $i = 36; $b; let $i = 37; $b; let $i = 38; $b; let $i = 39; $b; + let $i = 40; $b; let $i = 41; $b; let $i = 42; $b; let $i = 43; $b; + let $i = 44; $b; let $i = 45; $b; let $i = 46; $b; let $i = 47; $b; + let $i = 48; $b; let $i = 49; $b; let $i = 50; $b; let $i = 51; $b; + let $i = 52; $b; let $i = 53; $b; let $i = 54; $b; let $i = 55; $b; + let $i = 56; $b; let $i = 57; $b; let $i = 58; $b; let $i = 59; $b; + let $i = 60; $b; let $i = 61; $b; let $i = 62; $b; let $i = 63; $b; + let $i = 64; $b; let $i = 65; $b; let $i = 66; $b; let $i = 67; $b; + 
let $i = 68; $b; let $i = 69; $b; let $i = 70; $b; let $i = 71; $b; + let $i = 72; $b; let $i = 73; $b; let $i = 74; $b; let $i = 75; $b; + let $i = 76; $b; let $i = 77; $b; let $i = 78; $b; let $i = 79; $b; + }; +} + +/// Read round constant +fn rk(i: usize) -> u64 { + // `read_volatile` forces the compiler to read round constants from the static + // instead of inlining them, which improves codegen and performance + unsafe { + let p = K64.as_ptr().add(i); + core::ptr::read_volatile(p) + } +} + +/// Process a block with the SHA-512 algorithm. +fn compress_block(state: &mut [u64; 8], block: &[u8; 128]) { + let mut block = super::to_u64s(block); + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state; + + repeat80!(i, { + let w = if i < 16 { + block[i] + } else { + let w15 = block[(i - 15) % 16]; + let s0 = (w15.rotate_right(1)) ^ (w15.rotate_right(8)) ^ (w15 >> 7); + let w2 = block[(i - 2) % 16]; + let s1 = (w2.rotate_right(19)) ^ (w2.rotate_right(61)) ^ (w2 >> 6); + block[i % 16] = block[i % 16] + .wrapping_add(s0) + .wrapping_add(block[(i - 7) % 16]) + .wrapping_add(s1); + block[i % 16] + }; + + let s1 = e.rotate_right(14) ^ e.rotate_right(18) ^ e.rotate_right(41); + let ch = (e & f) ^ ((!e) & g); + let t1 = s1 + .wrapping_add(ch) + .wrapping_add(rk(i)) + .wrapping_add(w) + .wrapping_add(h); + let s0 = a.rotate_right(28) ^ a.rotate_right(34) ^ a.rotate_right(39); + let maj = (a & b) ^ (a & c) ^ (b & c); + let t2 = s0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + }); + + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +pub(in super::super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + for 
block in blocks { + compress_block(state, block); + } +} diff --git a/sha2/src/sha512/x86_avx2.rs b/sha2/src/sha512/x86_avx2.rs index 8c36952d4..f53d5a7ef 100644 --- a/sha2/src/sha512/x86_avx2.rs +++ b/sha2/src/sha512/x86_avx2.rs @@ -11,22 +11,8 @@ use core::arch::x86_64::*; use crate::consts::K64; -cpufeatures::new!(avx2_cpuid, "avx2"); - -pub(super) fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization - if avx2_cpuid::get() { - unsafe { - sha512_compress_x86_64_avx2(state, blocks); - } - } else { - super::soft::compress(state, blocks); - } -} - #[target_feature(enable = "avx2")] -unsafe fn sha512_compress_x86_64_avx2(state: &mut [u64; 8], blocks: &[[u8; 128]]) { +pub(super) unsafe fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { let mut start_block = 0; if blocks.len() & 0b1 != 0 {