@@ -232,33 +232,113 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result<Vec<u8>
232232 } )
233233}
234234
235- /// This function constructs destination by merging original with diff such that destination
236- /// becomes the changed version of the original.
235+ /// Constructs destination by applying the diff to original, such that destination becomes the
236+ /// post-diff state of the original.
237237///
238238/// Precondition:
239- /// - destination.len() == original.len()
239+ /// - destination.len() == diffset.changed_len()
240+ /// - original.len() may differ from destination.len() to allow Solana
241+ /// account resizing (shrink or expand).
242+ /// Assumption:
243+ /// - destination is assumed to be zero-initialized. That automatically holds true for freshly
244+ /// allocated Solana account data. The function does NOT validate this assumption, for performance reasons.
245+ /// Returns:
246+ /// - Ok(n) where n is number of bytes written to destination.
247+ /// - if n < destination.len(), then the last (destination.len() - n) bytes are not written by this function
248+ /// and are assumed to be already zero-initialized. Callers may write to those bytes starting at index `n`.
249+ /// - else n == destination.len().
250+ /// Notes:
251+ /// - Merge consists of:
252+ /// - bytes covered by diff segments are written from diffset.
253+ /// - unmodified regions are copied directly from original.
254+ /// - In shrink case, extra trailing bytes from original are ignored.
255+ /// - In expansion case, any remaining bytes beyond both the diff coverage
256+ /// and original.len() stay unwritten and are assumed to be zero-initialized.
257+ ///
240258pub fn merge_diff_copy (
241259 destination : & mut [ u8 ] ,
242260 original : & [ u8 ] ,
243261 diffset : & DiffSet < ' _ > ,
244- ) -> Result < ( ) , ProgramError > {
245- if destination. len ( ) != original . len ( ) {
262+ ) -> Result < usize , ProgramError > {
263+ if destination. len ( ) != diffset . changed_len ( ) {
246264 return Err ( DlpError :: MergeDiffError . into ( ) ) ;
247265 }
266+
248267 let mut write_index = 0 ;
249268 for item in diffset. iter ( ) {
250269 let ( diff_segment, OffsetInData { start, end } ) = item?;
270+
251271 if write_index < start {
272+ if start > original. len ( ) {
273+ return Err ( DlpError :: InvalidDiff . into ( ) ) ;
274+ }
252275 // copy the unchanged bytes
253276 destination[ write_index..start] . copy_from_slice ( & original[ write_index..start] ) ;
254277 }
278+
255279 destination[ start..end] . copy_from_slice ( diff_segment) ;
256280 write_index = end;
257281 }
258- if write_index < original. len ( ) {
259- destination[ write_index..] . copy_from_slice ( & original[ write_index..] ) ;
260- }
261- Ok ( ( ) )
282+
283+ // Ensure we have overwritten all bytes in destination, otherwise "construction" of destination
284+ // will be considered incomplete.
285+ let num_bytes_written = match write_index. cmp ( & destination. len ( ) ) {
286+ Ordering :: Equal => {
287+ // It means the destination is fully constructed.
288+ // Nothing to do here.
289+
290+ // It is possible that destination.len() <= original.len(), i.e. the destination might have shrunk,
291+ // in which case we do not care about those bytes of original which are not part of
292+ // destination anymore.
293+ write_index
294+ }
295+ Ordering :: Less => {
296+ // destination is NOT fully constructed yet. Some bytes in the destination are still unwritten.
297+ // Let's say the number of these unwritten bytes is: N.
298+ //
299+ // Now how do we construct these N unwritten bytes? We have already processed the
300+ // diffset, so now where could the values for these N bytes come from?
301+ //
302+ // There are 3 scenarios:
303+ // - All N bytes must be copied from remaining region of the original:
304+ // - that means, destination.len() <= original.len()
305+ // - and the destination might have shrunk, in which case we do not care about
306+ // the extra bytes in the original: they're discarded.
307+ // - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are
308+ // "assumed" to be already zero-initialized.
309+ // - that means, destination.len() > original.len()
310+ // - write_index + (N-M) == original.len()
311+ // - and the destination has expanded.
312+ // - None of these N bytes come from original. It's basically a special case of
313+ // the second scenario: when M = N i.e all N bytes stay unwritten.
314+ // - that means, destination.len() > original.len()
315+ // - and also, write_index == original.len().
316+ // - the destination has expanded just like the above case.
317+ // - all N bytes are "assumed" to be already zero-initialized (by the caller)
318+
319+ if destination. len ( ) <= original. len ( ) {
320+ // case: all n bytes come from original
321+ let dest_len = destination. len ( ) ;
322+ destination[ write_index..] . copy_from_slice ( & original[ write_index..dest_len] ) ;
323+ dest_len
324+ } else if write_index < original. len ( ) {
325+ // case: some bytes come from original and the rest are "assumed" to be
326+ // zero-initialized (by the caller).
327+ destination[ write_index..original. len ( ) ] . copy_from_slice ( & original[ write_index..] ) ;
328+ original. len ( )
329+ } else {
330+ // case: all N bytes are "assumed" to be zero-initialized (by the caller).
331+ write_index
332+ }
333+ }
334+ Ordering :: Greater => {
335+ // It is an impossible scenario. Even if the diff is corrupt, or the lengths of destination and
336+ // original are the same or different, we'll not encounter this case. It only implies a logic error.
337+ return Err ( DlpError :: InfallibleError . into ( ) ) ;
338+ }
339+ } ;
340+
341+ Ok ( num_bytes_written)
262342}
263343
264344// private function that does the actual work.
@@ -297,6 +377,58 @@ mod tests {
297377 ) ;
298378 }
299379
380+ fn get_example_expected_diff (
381+ changed_len : usize ,
382+ // additional_changes must apply after index 78 (index-in-data) !!
383+ additional_changes : Vec < ( u32 , & [ u8 ] ) > ,
384+ ) -> Vec < u8 > {
385+ // expected (with no additional_changes): | changed_len | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
386+
387+ let mut expected_diff = vec ! [ ] ;
388+
389+ // changed_len (u32)
390+ expected_diff. extend_from_slice ( & ( changed_len as u32 ) . to_le_bytes ( ) ) ;
391+
392+ if additional_changes. is_empty ( ) {
393+ // 2 (u32)
394+ expected_diff. extend_from_slice ( & 2u32 . to_le_bytes ( ) ) ;
395+ } else {
396+ expected_diff
397+ . extend_from_slice ( & ( 2u32 + additional_changes. len ( ) as u32 ) . to_le_bytes ( ) ) ;
398+ }
399+
400+ // -- offsets
401+
402+ // 0 11 (each u32)
403+ expected_diff. extend_from_slice ( & 0u32 . to_le_bytes ( ) ) ;
404+ expected_diff. extend_from_slice ( & 11u32 . to_le_bytes ( ) ) ;
405+
406+ // 4 71 (each u32)
407+ expected_diff. extend_from_slice ( & 4u32 . to_le_bytes ( ) ) ;
408+ expected_diff. extend_from_slice ( & 71u32 . to_le_bytes ( ) ) ;
409+
410+ let mut offset_in_diff = 12u32 ;
411+ for ( offset_in_data, diff) in additional_changes. iter ( ) {
412+ expected_diff. extend_from_slice ( & offset_in_diff. to_le_bytes ( ) ) ;
413+ expected_diff. extend_from_slice ( & offset_in_data. to_le_bytes ( ) ) ;
414+ offset_in_diff += diff. len ( ) as u32 ;
415+ }
416+
417+ // -- segments --
418+
419+ // 11 12 13 14 (each u8)
420+ expected_diff. extend_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
421+ // 71 72 ... 78 (each u8)
422+ expected_diff. extend_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
423+
424+ // append diff from additional_changes
425+ for ( _, diff) in additional_changes. iter ( ) {
426+ expected_diff. extend_from_slice ( diff) ;
427+ }
428+
429+ expected_diff
430+ }
431+
300432 #[ test]
301433 fn test_using_example_data ( ) {
302434 let original = [ 0 ; 100 ] ;
@@ -311,42 +443,99 @@ mod tests {
311443
312444 let actual_diff = compute_diff ( & original, & changed) ;
313445 let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
314- let expected_diff = {
315- // expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
446+ let expected_diff = get_example_expected_diff ( changed. len ( ) , vec ! [ ] ) ;
316447
317- let mut serialized = vec ! [ ] ;
448+ assert_eq ! ( actual_diff. len( ) , 4 + 4 + 8 + 8 + ( 4 + 8 ) ) ;
449+ assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
318450
319- // 100 (u32)
320- serialized. extend_from_slice ( & ( changed. len ( ) as u32 ) . to_le_bytes ( ) ) ;
451+ let expected_changed = apply_diff_copy ( & original, & actual_diffset) . unwrap ( ) ;
321452
322- // 2 (u32)
323- serialized. extend_from_slice ( & 2u32 . to_le_bytes ( ) ) ;
453+ assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
454+
455+ let expected_changed = {
456+ let mut destination = vec ! [ 255 ; original. len( ) ] ;
457+ merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
458+ destination
459+ } ;
460+
461+ assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
462+ }
324463
325- // 0 11 (each u32)
326- serialized. extend_from_slice ( & 0u32 . to_le_bytes ( ) ) ;
327- serialized. extend_from_slice ( & 11u32 . to_le_bytes ( ) ) ;
464+ #[ test]
465+ fn test_shrunk_account_data ( ) {
466+ // Note that changed_len cannot be lower than 79 because the last "changed" index is
467+ // 78 in the diff.
468+ const CHANGED_LEN : usize = 80 ;
328469
329- // 4 71 (each u32)
330- serialized. extend_from_slice ( & 4u32 . to_le_bytes ( ) ) ;
331- serialized. extend_from_slice ( & 71u32 . to_le_bytes ( ) ) ;
470+ let original = vec ! [ 0 ; 100 ] ;
471+ let changed = {
472+ let mut copy = original. clone ( ) ;
473+ copy. truncate ( CHANGED_LEN ) ;
332474
333- // 11 12 13 14 (each u8)
334- serialized . extend_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
335- // 71 72 ... 78 (each u8)
336- serialized . extend_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
337- serialized
475+ // | 11 | 12 | 13 | 14 |
476+ copy [ 11 ..= 14 ] . copy_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
477+ // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
478+ copy [ 71 ..= 78 ] . copy_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
479+ copy
338480 } ;
339481
482+ let actual_diff = compute_diff ( & original, & changed) ;
483+
484+ let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
485+
486+ let expected_diff = get_example_expected_diff ( CHANGED_LEN , vec ! [ ] ) ;
487+
340488 assert_eq ! ( actual_diff. len( ) , 4 + 4 + 8 + 8 + ( 4 + 8 ) ) ;
341489 assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
342490
343- let expected_changed = apply_diff_copy ( & original, & actual_diffset) . unwrap ( ) ;
491+ let expected_changed = {
492+ let mut destination = vec ! [ 255 ; CHANGED_LEN ] ;
493+ merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
494+ destination
495+ } ;
344496
345497 assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
498+ }
499+
500+ #[ test]
501+ fn test_expanded_account_data ( ) {
502+ const CHANGED_LEN : usize = 120 ;
503+
504+ let original = vec ! [ 0 ; 100 ] ;
505+ let changed = {
506+ let mut copy = original. clone ( ) ;
507+ copy. resize ( CHANGED_LEN , 0 ) ; // new bytes are zero-initialized
508+
509+ // | 11 | 12 | 13 | 14 |
510+ copy[ 11 ..=14 ] . copy_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
511+ // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
512+ copy[ 71 ..=78 ] . copy_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
513+ copy
514+ } ;
515+
516+ let actual_diff = compute_diff ( & original, & changed) ;
517+
518+ let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
519+
520+ // When an account expands, the extra bytes at the end become part of the diff, even if
521+ // all of them are zeroes, that is why (100, &[0; 20]) is passed as additional_changes to
522+ // the following function.
523+ //
524+ // TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are
525+ // part of the expansion.
526+ let expected_diff = get_example_expected_diff ( CHANGED_LEN , vec ! [ ( 100 , & [ 0 ; 20 ] ) ] ) ;
527+
528+ assert_eq ! ( actual_diff. len( ) , 4 + 4 + ( 8 + 8 ) + ( 4 + 8 ) + ( 4 + 4 + 20 ) ) ;
529+ assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
346530
347531 let expected_changed = {
348- let mut destination = vec ! [ 255 ; original. len( ) ] ;
349- merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
532+ let mut destination = vec ! [ 255 ; CHANGED_LEN ] ;
533+ let written = merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
534+
535+ // TODO (snawaz): written == 120 because the expanded bytes are currently part of the diff.
536+ // Once compute_diff is optimized further, written must be 100.
537+ assert_eq ! ( written, 120 ) ;
538+
350539 destination
351540 } ;
352541
0 commit comments