Skip to content

Commit c6c1898

Browse files
committed
feat: Handle account shrinking/expansion in merge_diff_copy
1 parent e8d0393 commit c6c1898

File tree

2 files changed

+221
-30
lines changed

2 files changed

+221
-30
lines changed

src/diff/algorithm.rs

Lines changed: 219 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -232,33 +232,113 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result<Vec<u8>
232232
})
233233
}
234234

235-
/// This function constructs destination by merging original with diff such that destination
236-
/// becomes the changed version of the original.
235+
/// Constructs destination by applying the diff to original, such that destination becomes the
236+
/// post-diff state of the original.
237237
///
238238
/// Precondition:
239-
/// - destination.len() == original.len()
239+
/// - destination.len() == diffset.changed_len()
240+
/// - original.len() may differ from destination.len() to allow Solana
241+
/// account resizing (shrink or expand).
242+
/// Assumption:
243+
/// - destination is assumed to be zero-initialized. That automatically holds true for freshly
244+
/// allocated Solana account data. The function does NOT validate this assumption for performance reason.
245+
/// Returns:
246+
/// - Ok(n) where n is number of bytes written to destination.
247+
/// - if n < destination.len(), then the last (destination.len() - n) bytes are not written by this function
248+
/// and are assumed to be already zero-initialized. Callers may write to those bytes starting at index `n`.
249+
/// - else n == destination.len().
250+
/// Notes:
251+
/// - Merge consists of:
252+
/// - bytes covered by diff segments are written from diffset.
253+
/// - unmodified regions are copied directly from original.
254+
/// - In shrink case, extra trailing bytes from original are ignored.
255+
/// - In expansion case, any remaining bytes beyond both the diff coverage
256+
/// and original.len() stay unwritten and are assumed to be zero-initialized.
257+
///
240258
pub fn merge_diff_copy(
241259
destination: &mut [u8],
242260
original: &[u8],
243261
diffset: &DiffSet<'_>,
244-
) -> Result<(), ProgramError> {
245-
if destination.len() != original.len() {
262+
) -> Result<usize, ProgramError> {
263+
if destination.len() != diffset.changed_len() {
246264
return Err(DlpError::MergeDiffError.into());
247265
}
266+
248267
let mut write_index = 0;
249268
for item in diffset.iter() {
250269
let (diff_segment, OffsetInData { start, end }) = item?;
270+
251271
if write_index < start {
272+
if start > original.len() {
273+
return Err(DlpError::InvalidDiff.into());
274+
}
252275
// copy the unchanged bytes
253276
destination[write_index..start].copy_from_slice(&original[write_index..start]);
254277
}
278+
255279
destination[start..end].copy_from_slice(diff_segment);
256280
write_index = end;
257281
}
258-
if write_index < original.len() {
259-
destination[write_index..].copy_from_slice(&original[write_index..]);
260-
}
261-
Ok(())
282+
283+
// Ensure we have overwritten all bytes in destination, otherwise "construction" of destination
284+
// will be considered incomplete.
285+
let num_bytes_written = match write_index.cmp(&destination.len()) {
286+
Ordering::Equal => {
287+
// It means the destination is fully constructed.
288+
// Nothing to do here.
289+
290+
// It is possible that destination.len() <= original.len() i.e destination might have shrunk
291+
// in which case we do not care about those bytes of original which are not part of
292+
// destination anymore.
293+
write_index
294+
}
295+
Ordering::Less => {
296+
// destination is NOT fully constructed yet. Few bytes in the destination are still unwritten.
297+
// Let's say the number of these unwritten bytes is: N.
298+
//
299+
// Now how do we construct these N unwritten bytes? We have already processed the
300+
// diffset, so now where could the values for these N bytes come from?
301+
//
302+
// There are 3 scenarios:
303+
// - All N bytes must be copied from remaining region of the original:
304+
// - that means, destination.len() <= original.len()
305+
// - and the destination might have shrunk, in which case we do not care about
306+
// the extra bytes in the original: they're discarded.
307+
// - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are
308+
// "assumed" to be already zero-initialized.
309+
// - that means, destination.len() > original.len()
310+
// - write_index + (N-M) == original.len()
311+
// - and the destination has expanded.
312+
// - None of these N bytes come from original. It's basically a special case of
313+
// the second scenario: when M = N i.e all N bytes stay unwritten.
314+
// - that means, destination.len() > original.len()
315+
// - and also, write_index == original.len().
316+
// - the destination has expanded just like the above case.
317+
// - all N bytes are "assumed" to be already zero-initialized (by the caller)
318+
319+
if destination.len() <= original.len() {
320+
// case: all n bytes come from original
321+
let dest_len = destination.len();
322+
destination[write_index..].copy_from_slice(&original[write_index..dest_len]);
323+
dest_len
324+
} else if write_index < original.len() {
325+
// case: some bytes come from original and the rest are "assumed" to be
326+
// zero-initialized (by the caller).
327+
destination[write_index..original.len()].copy_from_slice(&original[write_index..]);
328+
original.len()
329+
} else {
330+
// case: all N bytes are "assumed" to be zero-initialized (by the caller).
331+
write_index
332+
}
333+
}
334+
Ordering::Greater => {
335+
// It is an impossible scenario. Even if the diff is corrupt, or the lengths of destinatiare are same
336+
// or different, we'll not encounter this case. It only implies logic error.
337+
return Err(DlpError::InfallibleError.into());
338+
}
339+
};
340+
341+
Ok(num_bytes_written)
262342
}
263343

264344
// private function that does the actual work.
@@ -297,6 +377,58 @@ mod tests {
297377
);
298378
}
299379

380+
fn get_example_expected_diff(
381+
changed_len: usize,
382+
// additional_changes must apply after index 78 (index-in-data) !!
383+
additional_changes: Vec<(u32, &[u8])>,
384+
) -> Vec<u8> {
385+
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
386+
387+
let mut expected_diff = vec![];
388+
389+
// changed_len (u32)
390+
expected_diff.extend_from_slice(&(changed_len as u32).to_le_bytes());
391+
392+
if additional_changes.is_empty() {
393+
// 2 (u32)
394+
expected_diff.extend_from_slice(&2u32.to_le_bytes());
395+
} else {
396+
expected_diff
397+
.extend_from_slice(&(2u32 + additional_changes.len() as u32).to_le_bytes());
398+
}
399+
400+
// -- offsets
401+
402+
// 0 11 (each u32)
403+
expected_diff.extend_from_slice(&0u32.to_le_bytes());
404+
expected_diff.extend_from_slice(&11u32.to_le_bytes());
405+
406+
// 4 71 (each u32)
407+
expected_diff.extend_from_slice(&4u32.to_le_bytes());
408+
expected_diff.extend_from_slice(&71u32.to_le_bytes());
409+
410+
let mut offset_in_diff = 12u32;
411+
for (offset_in_data, diff) in additional_changes.iter() {
412+
expected_diff.extend_from_slice(&offset_in_diff.to_le_bytes());
413+
expected_diff.extend_from_slice(&offset_in_data.to_le_bytes());
414+
offset_in_diff += diff.len() as u32;
415+
}
416+
417+
// -- segments --
418+
419+
// 11 12 13 14 (each u8)
420+
expected_diff.extend_from_slice(&0x01020304u32.to_le_bytes());
421+
// 71 72 ... 78 (each u8)
422+
expected_diff.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
423+
424+
// append diff from additional_changes
425+
for (_, diff) in additional_changes.iter() {
426+
expected_diff.extend_from_slice(diff);
427+
}
428+
429+
expected_diff
430+
}
431+
300432
#[test]
301433
fn test_using_example_data() {
302434
let original = [0; 100];
@@ -311,42 +443,99 @@ mod tests {
311443

312444
let actual_diff = compute_diff(&original, &changed);
313445
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
314-
let expected_diff = {
315-
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
446+
let expected_diff = get_example_expected_diff(changed.len(), vec![]);
316447

317-
let mut serialized = vec![];
448+
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
449+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
318450

319-
// 100 (u32)
320-
serialized.extend_from_slice(&(changed.len() as u32).to_le_bytes());
451+
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
321452

322-
// 2 (u32)
323-
serialized.extend_from_slice(&2u32.to_le_bytes());
453+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
454+
455+
let expected_changed = {
456+
let mut destination = vec![255; original.len()];
457+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
458+
destination
459+
};
460+
461+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
462+
}
324463

325-
// 0 11 (each u32)
326-
serialized.extend_from_slice(&0u32.to_le_bytes());
327-
serialized.extend_from_slice(&11u32.to_le_bytes());
464+
#[test]
465+
fn test_shrunk_account_data() {
466+
// Note that changed_len cannot be lower than 79 because the last "changed" index is
467+
// 78 in the diff.
468+
const CHANGED_LEN: usize = 80;
328469

329-
// 4 71 (each u32)
330-
serialized.extend_from_slice(&4u32.to_le_bytes());
331-
serialized.extend_from_slice(&71u32.to_le_bytes());
470+
let original = vec![0; 100];
471+
let changed = {
472+
let mut copy = original.clone();
473+
copy.truncate(CHANGED_LEN);
332474

333-
// 11 12 13 14 (each u8)
334-
serialized.extend_from_slice(&0x01020304u32.to_le_bytes());
335-
// 71 72 ... 78 (each u8)
336-
serialized.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
337-
serialized
475+
// | 11 | 12 | 13 | 14 |
476+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
477+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
478+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
479+
copy
338480
};
339481

482+
let actual_diff = compute_diff(&original, &changed);
483+
484+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
485+
486+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![]);
487+
340488
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
341489
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
342490

343-
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
491+
let expected_changed = {
492+
let mut destination = vec![255; CHANGED_LEN];
493+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
494+
destination
495+
};
344496

345497
assert_eq!(changed.as_slice(), expected_changed.as_slice());
498+
}
499+
500+
#[test]
501+
fn test_expanded_account_data() {
502+
const CHANGED_LEN: usize = 120;
503+
504+
let original = vec![0; 100];
505+
let changed = {
506+
let mut copy = original.clone();
507+
copy.resize(CHANGED_LEN, 0); // new bytes are zero-initialized
508+
509+
// | 11 | 12 | 13 | 14 |
510+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
511+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
512+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
513+
copy
514+
};
515+
516+
let actual_diff = compute_diff(&original, &changed);
517+
518+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
519+
520+
// When an account expands, the extra bytes at the end become part of the diff, even if
521+
// all of them are zeroes, that is why (100, &[0; 32]) is passed as additional_changes to
522+
// the following function.
523+
//
524+
// TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are
525+
// part of the expansion.
526+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![(100, &[0; 20])]);
527+
528+
assert_eq!(actual_diff.len(), 4 + 4 + (8 + 8) + (4 + 8) + (4 + 4 + 20));
529+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
346530

347531
let expected_changed = {
348-
let mut destination = vec![255; original.len()];
349-
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
532+
let mut destination = vec![255; CHANGED_LEN];
533+
let written = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
534+
535+
// TODO (snawaz): written == 120, is because currently the expanded bytes are part of the diff.
536+
// Once compute_diff is optimized further, written must be 100.
537+
assert_eq!(written, 120);
538+
350539
destination
351540
};
352541

src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub enum DlpError {
4141
InvalidDiffAlignment = 16,
4242
#[error("MergeDiff precondition did not meet")]
4343
MergeDiffError = 17,
44+
#[error("An infallible error is encountered possibly due to logic error")]
45+
InfallibleError = 18,
4446
}
4547

4648
impl From<DlpError> for ProgramError {

0 commit comments

Comments
 (0)