From ec02c4f1c89f3689b0218c341a5e0b418cedb0e5 Mon Sep 17 00:00:00 2001 From: Sarfaraz Nawaz Date: Fri, 31 Oct 2025 23:24:46 +0530 Subject: [PATCH 1/2] feat: Handle account shrinking/expansion in merge_diff_copy --- src/diff/algorithm.rs | 249 +++++++++++++++++++++++++++++++++++++----- src/error.rs | 2 + 2 files changed, 221 insertions(+), 30 deletions(-) diff --git a/src/diff/algorithm.rs b/src/diff/algorithm.rs index 1d7af04..0f06180 100644 --- a/src/diff/algorithm.rs +++ b/src/diff/algorithm.rs @@ -232,33 +232,113 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result }) } -/// This function constructs destination by merging original with diff such that destination -/// becomes the changed version of the original. +/// Constructs destination by applying the diff to original, such that destination becomes the +/// post-diff state of the original. /// /// Precondition: -/// - destination.len() == original.len() +/// - destination.len() == diffset.changed_len() +/// - original.len() may differ from destination.len() to allow Solana +/// account resizing (shrink or expand). +/// Assumption: +/// - destination is assumed to be zero-initialized. That automatically holds true for freshly +/// allocated Solana account data. The function does NOT validate this assumption for performance reason. +/// Returns: +/// - Ok(n) where n is number of bytes written to destination. +/// - if n < destination.len(), then the last (destination.len() - n) bytes are not written by this function +/// and are assumed to be already zero-initialized. Callers may write to those bytes starting at index `n`. +/// - else n == destination.len(). +/// Notes: +/// - Merge consists of: +/// - bytes covered by diff segments are written from diffset. +/// - unmodified regions are copied directly from original. +/// - In shrink case, extra trailing bytes from original are ignored. 
+/// - In expansion case, any remaining bytes beyond both the diff coverage +/// and original.len() stay unwritten and are assumed to be zero-initialized. +/// pub fn merge_diff_copy( destination: &mut [u8], original: &[u8], diffset: &DiffSet<'_>, -) -> Result<(), ProgramError> { - if destination.len() != original.len() { +) -> Result<usize, ProgramError> { + if destination.len() != diffset.changed_len() { return Err(DlpError::MergeDiffError.into()); } + let mut write_index = 0; for item in diffset.iter() { let (diff_segment, OffsetInData { start, end }) = item?; + if write_index < start { + if start > original.len() { + return Err(DlpError::InvalidDiff.into()); + } // copy the unchanged bytes destination[write_index..start].copy_from_slice(&original[write_index..start]); } + destination[start..end].copy_from_slice(diff_segment); write_index = end; } - if write_index < original.len() { - destination[write_index..].copy_from_slice(&original[write_index..]); - } - Ok(()) + + // Ensure we have overwritten all bytes in destination, otherwise "construction" of destination + // will be considered incomplete. + let num_bytes_written = match write_index.cmp(&destination.len()) { + Ordering::Equal => { + // It means the destination is fully constructed. + // Nothing to do here. + + // It is possible that destination.len() <= original.len() i.e destination might have shrunk + // in which case we do not care about those bytes of original which are not part of + // destination anymore. + write_index + } + Ordering::Less => { + // destination is NOT fully constructed yet. A few bytes in the destination are still unwritten. + // Let's say the number of these unwritten bytes is: N. + // + // Now how do we construct these N unwritten bytes? We have already processed the + // diffset, so now where could the values for these N bytes come from? 
+ // + // There are 3 scenarios: + // - All N bytes must be copied from remaining region of the original: + // - that means, destination.len() <= original.len() + // - and the destination might have shrunk, in which case we do not care about + // the extra bytes in the original: they're discarded. + // - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are + // "assumed" to be already zero-initialized. + // - that means, destination.len() > original.len() + // - write_index + (N-M) == original.len() + // - and the destination has expanded. + // - None of these N bytes come from original. It's basically a special case of + // the second scenario: when M = N i.e all N bytes stay unwritten. + // - that means, destination.len() > original.len() + // - and also, write_index >= original.len(). + // - the destination has expanded just like the above case. + // - all N bytes are "assumed" to be already zero-initialized (by the caller) + + if destination.len() <= original.len() { + // case: all n bytes come from original + let dest_len = destination.len(); + destination[write_index..].copy_from_slice(&original[write_index..dest_len]); + dest_len + } else if write_index < original.len() { + // case: some bytes come from original and the rest are "assumed" to be + // zero-initialized (by the caller). + destination[write_index..original.len()].copy_from_slice(&original[write_index..]); + original.len() + } else { + // case: all N bytes are "assumed" to be zero-initialized (by the caller). + write_index + } + } + Ordering::Greater => { + // It is an impossible scenario. Even if the diff is corrupt, or the lengths of destination and original are the same + // or different, we'll not encounter this case. It only implies a logic error. + return Err(DlpError::InfallibleError.into()); + } + }; + + Ok(num_bytes_written) } // private function that does the actual work. 
@@ -297,6 +377,58 @@ mod tests { ); } + fn get_example_expected_diff( + changed_len: usize, + // additional_changes must apply after index 78 (index-in-data) !! + additional_changes: Vec<(u32, &[u8])>, + ) -> Vec<u8> { + // expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 | + + let mut expected_diff = vec![]; + + // changed_len (u32) + expected_diff.extend_from_slice(&(changed_len as u32).to_le_bytes()); + + if additional_changes.is_empty() { + // 2 (u32) + expected_diff.extend_from_slice(&2u32.to_le_bytes()); + } else { + expected_diff + .extend_from_slice(&(2u32 + additional_changes.len() as u32).to_le_bytes()); + } + + // -- offsets + + // 0 11 (each u32) + expected_diff.extend_from_slice(&0u32.to_le_bytes()); + expected_diff.extend_from_slice(&11u32.to_le_bytes()); + + // 4 71 (each u32) + expected_diff.extend_from_slice(&4u32.to_le_bytes()); + expected_diff.extend_from_slice(&71u32.to_le_bytes()); + + let mut offset_in_diff = 12u32; + for (offset_in_data, diff) in additional_changes.iter() { + expected_diff.extend_from_slice(&offset_in_diff.to_le_bytes()); + expected_diff.extend_from_slice(&offset_in_data.to_le_bytes()); + offset_in_diff += diff.len() as u32; + } + + // -- segments -- + + // 11 12 13 14 (each u8) + expected_diff.extend_from_slice(&0x01020304u32.to_le_bytes()); + // 71 72 ... 78 (each u8) + expected_diff.extend_from_slice(&0x0102030405060708u64.to_le_bytes()); + + // append diff from additional_changes + for (_, diff) in additional_changes.iter() { + expected_diff.extend_from_slice(diff); + } + + expected_diff + } + #[test] fn test_using_example_data() { let original = [0; 100]; @@ -311,42 +443,99 @@ mod tests { let actual_diff = compute_diff(&original, &changed); let actual_diffset = DiffSet::try_new(&actual_diff).unwrap(); - let expected_diff = { - // expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 
78 | + let expected_diff = get_example_expected_diff(changed.len(), vec![]); - let mut serialized = vec![]; + assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8)); + assert_eq!(actual_diff.as_slice(), expected_diff.as_slice()); - // 100 (u32) - serialized.extend_from_slice(&(changed.len() as u32).to_le_bytes()); + let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap(); - // 2 (u32) - serialized.extend_from_slice(&2u32.to_le_bytes()); + assert_eq!(changed.as_slice(), expected_changed.as_slice()); + + let expected_changed = { + let mut destination = vec![255; original.len()]; + merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); + destination + }; + + assert_eq!(changed.as_slice(), expected_changed.as_slice()); + } - // 0 11 (each u32) - serialized.extend_from_slice(&0u32.to_le_bytes()); - serialized.extend_from_slice(&11u32.to_le_bytes()); + #[test] + fn test_shrunk_account_data() { + // Note that changed_len cannot be lower than 79 because the last "changed" index is + // 78 in the diff. + const CHANGED_LEN: usize = 80; - // 4 71 (each u32) - serialized.extend_from_slice(&4u32.to_le_bytes()); - serialized.extend_from_slice(&71u32.to_le_bytes()); + let original = vec![0; 100]; + let changed = { + let mut copy = original.clone(); + copy.truncate(CHANGED_LEN); - // 11 12 13 14 (each u8) - serialized.extend_from_slice(&0x01020304u32.to_le_bytes()); - // 71 72 ... 
78 (each u8) - serialized.extend_from_slice(&0x0102030405060708u64.to_le_bytes()); - serialized + // | 11 | 12 | 13 | 14 | + copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes()); + // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | + copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes()); + copy }; + let actual_diff = compute_diff(&original, &changed); + + let actual_diffset = DiffSet::try_new(&actual_diff).unwrap(); + + let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![]); + assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8)); assert_eq!(actual_diff.as_slice(), expected_diff.as_slice()); - let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap(); + let expected_changed = { + let mut destination = vec![255; CHANGED_LEN]; + merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); + destination + }; assert_eq!(changed.as_slice(), expected_changed.as_slice()); + } + + #[test] + fn test_expanded_account_data() { + const CHANGED_LEN: usize = 120; + + let original = vec![0; 100]; + let changed = { + let mut copy = original.clone(); + copy.resize(CHANGED_LEN, 0); // new bytes are zero-initialized + + // | 11 | 12 | 13 | 14 | + copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes()); + // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | + copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes()); + copy + }; + + let actual_diff = compute_diff(&original, &changed); + + let actual_diffset = DiffSet::try_new(&actual_diff).unwrap(); + + // When an account expands, the extra bytes at the end become part of the diff, even if + // all of them are zeroes, that is why (100, &[0; 20]) is passed as additional_changes to + // the following function. + // + // TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are + // part of the expansion. 
+ let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![(100, &[0; 20])]); + + assert_eq!(actual_diff.len(), 4 + 4 + (8 + 8) + (4 + 8) + (4 + 4 + 20)); + assert_eq!(actual_diff.as_slice(), expected_diff.as_slice()); let expected_changed = { - let mut destination = vec![255; original.len()]; - merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); + let mut destination = vec![255; CHANGED_LEN]; + let written = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); + + // TODO (snawaz): written == 120, is because currently the expanded bytes are part of the diff. + // Once compute_diff is optimized further, written must be 100. + assert_eq!(written, 120); + destination }; diff --git a/src/error.rs b/src/error.rs index 99565d4..0931b4d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -81,6 +81,8 @@ pub enum DlpError { UndelegateBufferAlreadyInitialized = 36, #[error("Undelegate buffer PDA immutable")] UndelegateBufferImmutable = 37, + #[error("An infallible error is encountered possibly due to logic error")] + InfallibleError = 100, } impl From<DlpError> for ProgramError { From f860453c435459b584a5df1947a30e75c0193d38 Mon Sep 17 00:00:00 2001 From: Sarfaraz Nawaz Date: Sat, 1 Nov 2025 22:59:39 +0530 Subject: [PATCH 2/2] Return unwritten bytes instead of the length of written bytes --- src/diff/algorithm.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diff/algorithm.rs b/src/diff/algorithm.rs index 0f06180..80aa700 100644 --- a/src/diff/algorithm.rs +++ b/src/diff/algorithm.rs @@ -243,10 +243,8 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result /// - destination is assumed to be zero-initialized. That automatically holds true for freshly /// allocated Solana account data. The function does NOT validate this assumption for performance reason. /// Returns: -/// - Ok(n) where n is number of bytes written to destination. 
-/// - if n < destination.len(), then the last (destination.len() - n) bytes are not written by this function -/// and are assumed to be already zero-initialized. Callers may write to those bytes starting at index `n`. -/// - else n == destination.len(). +/// - Ok(&mut [u8]) where the slice contains the trailing unwritten bytes in destination, which are +/// assumed to be already zero-initialized. Callers may write to those bytes or validate them. /// Notes: /// - Merge consists of: /// - bytes covered by diff segments are written from diffset. @@ -255,11 +253,11 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result /// - In expansion case, any remaining bytes beyond both the diff coverage /// and original.len() stay unwritten and are assumed to be zero-initialized. /// -pub fn merge_diff_copy( - destination: &mut [u8], +pub fn merge_diff_copy<'a>( + destination: &'a mut [u8], original: &[u8], diffset: &DiffSet<'_>, -) -> Result<usize, ProgramError> { +) -> Result<&'a mut [u8], ProgramError> { if destination.len() != diffset.changed_len() { return Err(DlpError::MergeDiffError.into()); } @@ -338,7 +336,9 @@ pub fn merge_diff_copy( } }; - Ok(num_bytes_written) + let (_, unwritten_bytes) = destination.split_at_mut(num_bytes_written); + + Ok(unwritten_bytes) } // private function that does the actual work. @@ -530,11 +530,11 @@ mod tests { let expected_changed = { let mut destination = vec![255; CHANGED_LEN]; - let written = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); + let unwritten = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap(); - // TODO (snawaz): written == 120, is because currently the expanded bytes are part of the diff. - // Once compute_diff is optimized further, written must be 100. - assert_eq!(written, 120); + // TODO (snawaz): unwritten == &mut [], is because currently the expanded bytes are part of the diff. + // Once compute_diff is optimized further, unwritten must be &mut [0; 20]. 
+ assert_eq!(unwritten, &mut []); destination };