Skip to content

Commit 0e56883

Browse files
committed
More perf improvements for incremental sorting
1 parent 8fb4711 commit 0e56883

File tree

2 files changed

+85
-58
lines changed
  • necsim/impls/no-std/src/parallelisation/independent/monolithic/reporter
  • rustcoalescence/algorithms/cuda/src/parallelisation

2 files changed

+85
-58
lines changed

necsim/impls/no-std/src/parallelisation/independent/monolithic/reporter/live.rs

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use alloc::{vec::Vec, collections::VecDeque};
1+
use alloc::{collections::VecDeque, vec::Vec};
22
use core::{fmt, marker::PhantomData, ops::ControlFlow};
33
use necsim_core_bond::NonNegativeF64;
44

@@ -115,7 +115,11 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
115115
|| (n >= 3 && self.runs[n - 3].len <= self.runs[n - 2].len + self.runs[n - 1].len)
116116
|| (n >= 4 && self.runs[n - 4].len <= self.runs[n - 3].len + self.runs[n - 2].len))
117117
{
118-
if n >= 3 && self.runs[n - 3].len < self.runs[n - 1].len { Some(n - 3) } else { Some(n - 2) }
118+
if n >= 3 && self.runs[n - 3].len < self.runs[n - 1].len {
119+
Some(n - 3)
120+
} else {
121+
Some(n - 2)
122+
}
119123
} else {
120124
None
121125
}
@@ -157,30 +161,36 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
157161
let v = v.as_mut_ptr();
158162
let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };
159163

160-
// The merge process first copies the shorter run into `buf`. Then it traces the newly copied
161-
// run and the longer run forwards (or backwards), comparing their next unconsumed elements and
162-
// copying the lesser (or greater) one into `v`.
164+
// The merge process first copies the shorter run into `buf`. Then it traces the
165+
// newly copied run and the longer run forwards (or backwards),
166+
// comparing their next unconsumed elements and copying the lesser (or
167+
// greater) one into `v`.
163168
//
164-
// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
165-
// consumed first, then we must copy whatever is left of the shorter run into the remaining
166-
// hole in `v`.
169+
// As soon as the shorter run is fully consumed, the process is done. If the
170+
// longer run gets consumed first, then we must copy whatever is left of
171+
// the shorter run into the remaining hole in `v`.
167172
//
168-
// Intermediate state of the process is always tracked by `hole`, which serves two purposes:
169-
// 1. Protects integrity of `v` from panics in `is_less`.
173+
// Intermediate state of the process is always tracked by `hole`, which serves
174+
// two purposes: 1. Protects integrity of `v` from panics in `is_less`.
170175
// 2. Fills the remaining hole in `v` if the longer run gets consumed first.
171176
//
172177
// Panic safety:
173178
//
174-
// If `is_less` panics at any point during the process, `hole` will get dropped and fill the
175-
// hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
176-
// object it initially held exactly once.
179+
// If `is_less` panics at any point during the process, `hole` will get dropped
180+
// and fill the hole in `v` with the unconsumed range in `buf`, thus
181+
// ensuring that `v` still holds every object it initially held exactly
182+
// once.
177183
let mut hole;
178184

179185
if mid <= len - mid {
180186
// The left run is shorter.
181187
unsafe {
182188
core::ptr::copy_nonoverlapping(v, buf, mid);
183-
hole = MergeHole { start: buf, end: buf.add(mid), dest: v };
189+
hole = MergeHole {
190+
start: buf,
191+
end: buf.add(mid),
192+
dest: v,
193+
};
184194
}
185195

186196
// Initially, these pointers point to the beginnings of their arrays.
@@ -204,7 +214,11 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
204214
// The right run is shorter.
205215
unsafe {
206216
core::ptr::copy_nonoverlapping(v_mid, buf, len - mid);
207-
hole = MergeHole { start: buf, end: buf.add(len - mid), dest: v_mid };
217+
hole = MergeHole {
218+
start: buf,
219+
end: buf.add(len - mid),
220+
dest: v_mid,
221+
};
208222
}
209223

210224
// Initially, these pointers point past the ends of their arrays.
@@ -225,39 +239,24 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
225239
}
226240
}
227241
}
228-
// Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
229-
// it will now be copied into the hole in `v`.
242+
// Finally, `hole` gets dropped. If the shorter run was not fully
243+
// consumed, whatever remains of it will now be copied into the
244+
// hole in `v`.
230245
}
231246

232247
fn sort_slow_events_step(&mut self, force_merge: bool) -> ControlFlow<()> {
233-
let r = loop {
234-
if let Some(r) = self.collapse(force_merge && self.overflow.is_empty()) {
235-
break r;
236-
}
237-
248+
let Some(r) = self.collapse(force_merge && self.overflow.is_empty() && self.run.len == 0) else {
238249
let next_run = match self.overflow.pop_front() {
239250
Some(next_run) => next_run,
240251
None if self.run.len > 0 => core::mem::replace(&mut self.run, Run { start: self.slow_events.len(), len: 0 }),
241252
None => return ControlFlow::Break(()),
242253
};
243254

244-
// let Some(mut next_run) = self.overflow.pop_front() else {
245-
// return ControlFlow::Break(());
246-
// };
247-
248-
// if next_run.len < self.sort_batch_size {
249-
// while next_run.len < self.sort_batch_size {
250-
// let Some(extra_run) = self.overflow.pop_front() else {
251-
// break;
252-
// };
253-
// next_run.len += extra_run.len;
254-
// }
255-
// self.slow_events[next_run.start..next_run.start+next_run.len].sort_unstable();
256-
// }
257-
258255
self.slow_events[next_run.start..next_run.start+next_run.len].sort_unstable();
259256

260257
self.runs.push(next_run);
258+
259+
return ControlFlow::Continue(());
261260
};
262261

263262
let left = self.runs[r];
@@ -266,7 +265,8 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
266265
let min_len = left.len.min(right.len);
267266

268267
if min_len > self.tmp_events.capacity() {
269-
self.tmp_events.reserve(min_len - self.tmp_events.capacity());
268+
self.tmp_events
269+
.reserve(min_len - self.tmp_events.capacity());
270270
}
271271

272272
unsafe {
@@ -278,7 +278,10 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> LiveWaterLevelReporterProxy<
278278
);
279279
}
280280

281-
self.runs[r] = Run { start: left.start, len: left.len + right.len };
281+
self.runs[r] = Run {
282+
start: left.start,
283+
len: left.len + right.len,
284+
};
282285
self.runs.remove(r + 1);
283286

284287
ControlFlow::Continue(())
@@ -319,6 +322,8 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> WaterLevelReporterProxy<'l,
319322
let mut i = 0;
320323

321324
// Report all events below the water level in sorted order
325+
// TODO: Should we detect if no partial sort steps were taken
326+
// and revert to a full unstable sort in that case?
322327
while let ControlFlow::Continue(()) = self.sort_slow_events_step(true) {
323328
if (i % 100) == 0 {
324329
info!("{:?}", self);
@@ -350,14 +355,20 @@ impl<'l, 'p, R: Reporter, P: LocalPartition<'p, R>> WaterLevelReporterProxy<'l,
350355
self.water_level = water_level;
351356

352357
// Move fast events below the new water level into slow events
353-
for event in self.fast_events.drain_filter(|event| event.event_time() < water_level) {
358+
for event in self
359+
.fast_events
360+
.drain_filter(|event| event.event_time() < water_level)
361+
{
354362
let new_run = self.run.len > self.sort_batch_size; // self.slow_events.last().map_or(true, |prev| prev > &event);
355363

356364
if new_run {
357-
let old_run = core::mem::replace(&mut self.run, Run {
358-
start: self.slow_events.len(),
359-
len: 1,
360-
});
365+
let old_run = core::mem::replace(
366+
&mut self.run,
367+
Run {
368+
start: self.slow_events.len(),
369+
len: 1,
370+
},
371+
);
361372
self.overflow.push_back(old_run);
362373
} else {
363374
self.run.len += 1;

rustcoalescence/algorithms/cuda/src/parallelisation/monolithic.rs

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use std::{
22
collections::VecDeque,
33
convert::{TryFrom, TryInto},
44
num::NonZeroU64,
5-
sync::atomic::AtomicU64, ops::ControlFlow,
5+
ops::ControlFlow,
6+
sync::atomic::{AtomicU64, Ordering},
67
};
78

89
use rust_cuda::{
@@ -178,13 +179,23 @@ pub fn simulate<
178179

179180
let event_slice = event_slice.capacity(slow_lineages.len());
180181

181-
let (grid_size, block_size, dedup_cache, step_slice, sort_block_size, sort_mode, sort_batch_size) = config;
182+
let (
183+
grid_size,
184+
block_size,
185+
dedup_cache,
186+
step_slice,
187+
sort_block_size,
188+
sort_mode,
189+
sort_batch_size,
190+
) = config;
182191

183192
let mut proxy = <WaterLevelReporterStrategy as WaterLevelReporterConstructor<
184193
L::IsLive,
185194
P,
186195
L,
187-
>>::WaterLevelReporter::new(event_slice.get(), local_partition, sort_batch_size);
196+
>>::WaterLevelReporter::new(
197+
event_slice.get(), local_partition, sort_batch_size
198+
);
188199

189200
#[allow(clippy::or_fun_call)]
190201
let intial_max_time = slow_lineages
@@ -236,9 +247,8 @@ pub fn simulate<
236247
let cpu_turnover_rate = simulation.turnover_rate().backup();
237248
let cpu_speciation_probability = simulation.speciation_probability().backup();
238249

239-
let kernel_event = rust_cuda::host::CudaDropWrapper::from(rust_cuda::rustacuda::event::Event::new(
240-
rust_cuda::rustacuda::event::EventFlags::DISABLE_TIMING
241-
)?);
250+
let cuda_kernel_cycle = AtomicU64::new(0);
251+
let mut host_kernel_cycle = 0_u64;
242252

243253
HostAndDeviceMutRef::with_new(&mut total_time_max, |total_time_max| -> Result<()> {
244254
HostAndDeviceMutRef::with_new(&mut total_steps_sum, |total_steps_sum| -> Result<()> {
@@ -295,7 +305,7 @@ pub fn simulate<
295305
proxy.advance_water_level(level_time);
296306

297307
// Simulate all slow lineages until they have finished or exceeded the
298-
// new water level
308+
// new water level
299309
while !slow_lineages.is_empty() {
300310
let mut num_tasks = 0_usize;
301311

@@ -318,7 +328,7 @@ pub fn simulate<
318328
}
319329

320330
// Move the task list, event buffer and min speciation sample buffer
321-
// to CUDA
331+
// to CUDA
322332
let mut task_list_cuda = task_list.move_to_device_async(stream)?;
323333

324334
// TODO: Investigate distributing over several streams
@@ -468,11 +478,18 @@ pub fn simulate<
468478

469479
let event_buffer_host = event_buffer_cuda.move_to_host_async(stream)?;
470480

471-
kernel_event.record(stream)?;
481+
host_kernel_cycle = host_kernel_cycle.wrapping_add(1);
472482

473-
while let rust_cuda::rustacuda::event::EventStatus::NotReady = kernel_event.query()? {
474-
if let ControlFlow::Break(()) = proxy.partial_sort_step() {
475-
kernel_event.synchronize()?;
483+
// Note: The stream ensures that the stores from subsequent
484+
// cycles are ordered, so no compare_exchange is needed.
485+
stream.add_callback(Box::new(|_| {
486+
cuda_kernel_cycle.store(host_kernel_cycle, Ordering::SeqCst);
487+
}))?;
488+
489+
// While CUDA runs the simulation kernel, do some incremental
490+
// sorting of the events that have already been produced
491+
while let ControlFlow::Continue(()) = proxy.partial_sort_step() {
492+
if cuda_kernel_cycle.load(Ordering::SeqCst) == host_kernel_cycle {
476493
break;
477494
}
478495
}
@@ -496,9 +513,8 @@ pub fn simulate<
496513
(task.take(), next_event_time.take())
497514
{
498515
if !duplicate_individual {
499-
// Reclassify lineages as either slow (still below
500-
// water) or
501-
// fast
516+
// Reclassify lineages as either slow
517+
// (still below water) or fast
502518
if next_event_time < level_time {
503519
slow_lineages.push_back((task, next_event_time.into()));
504520
} else {

0 commit comments

Comments
 (0)