Skip to content

Commit c7869d5

Browse files
committed
Start with bitonic sort implementation [wip]
1 parent 7d6d5da commit c7869d5

File tree

13 files changed

+287
-23
lines changed

13 files changed

+287
-23
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

necsim/core/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ contracts = "0.6.3"
2020
serde = { version = "1.0", default-features = false, features = ["derive"] }
2121

2222
[target.'cfg(target_os = "cuda")'.dependencies]
23-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive"], optional = true }
23+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive"], optional = true }
2424

2525
[target.'cfg(not(target_os = "cuda"))'.dependencies]
26-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive", "host"], optional = true }
26+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive", "host"], optional = true }

necsim/core/src/event.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ impl PartialEq for PackedEvent {
255255
}
256256

257257
impl Ord for PackedEvent {
258+
#[inline]
258259
fn cmp(&self, other: &Self) -> Ordering {
259260
// Order `Event`s in lexicographical order:
260261
// (1) event_time /=\

necsim/impls/cuda/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ contracts = "0.6.3"
1515
serde = { version = "1.0", default-features = false, features = ["derive"] }
1616

1717
[target.'cfg(target_os = "cuda")'.dependencies]
18-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive"] }
18+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive"] }
1919

2020
[target.'cfg(not(target_os = "cuda"))'.dependencies]
21-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive", "host"] }
21+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive", "host"] }

necsim/impls/cuda/src/event_buffer.rs

Lines changed: 239 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,48 @@ pub trait AlignedToU64: sealed::AlignedToU64 {}
5959
impl<T: sealed::AlignedToU64> AlignedToU64 for T {}
6060

6161
/// Selects the event representation stored in an event buffer and the
/// number of slots used by the shared-memory sorting step.
pub trait EventType {
62-
type Event: ~const rust_cuda::const_type_layout::TypeGraphLayout
62+
type Event: 'static
63+
+ ~const rust_cuda::const_type_layout::TypeGraphLayout
6364
+ rust_cuda::safety::StackOnly
6465
+ Into<TypedEvent>
6566
+ Into<PackedEvent>
6667
+ Ord
6768
+ Clone
6869
+ AlignedToU64;
70+
71+
/// Length of the `ThreadBlockShared` mask and event arrays allocated by
/// `bitonic_sort_events_shared_step`, which asserts that this value
/// equals twice the thread block size.
const SHARED_LIMIT: usize;
6972
}
7073

7174
// Fallback `EventType` implementation built from specialisable `default`
// items: events are stored as `PackedEvent` and `SHARED_LIMIT` is 0.
// NOTE(review): `bitonic_sort_events_shared_step` asserts
// `block_dim.size() * 2 == SHARED_LIMIT`, which presumably cannot hold for
// a limit of 0 -- confirm this fallback is never selected for that step.
impl<ReportSpeciation: Boolean, ReportDispersal: Boolean> EventType
7275
for EventBuffer<ReportSpeciation, ReportDispersal>
7376
{
7477
default type Event = PackedEvent;
78+
79+
default const SHARED_LIMIT: usize = 0;
7580
}
7681

7782
// Neither speciation nor dispersal is reported: events are `PackedEvent`s.
impl EventType for EventBuffer<False, False> {
7883
type Event = PackedEvent;
84+
85+
// Number of events that fit into `48*1024` bytes, rounded down to a
// multiple of 32.
// NOTE(review): presumably 48 KiB is the per-block shared memory budget --
// TODO confirm. The shared bitonic step halves `SHARED_LIMIT` down to 1 and
// uses `stride - 1` as a bit mask, which looks like it needs a power of
// two; a multiple of 32 does not guarantee that -- verify.
const SHARED_LIMIT: usize = ((48*1024 / core::mem::size_of::<Self::Event>()) / 32) * 32;
7986
}
8087

8188
// Only dispersal is reported: events are `PackedEvent`s.
impl EventType for EventBuffer<False, True> {
8289
type Event = PackedEvent;
90+
91+
// Number of events that fit into `48*1024` bytes, rounded down to a
// multiple of 32.
// NOTE(review): not guaranteed to be a power of two -- confirm that the
// shared bitonic step does not rely on one.
const SHARED_LIMIT: usize = ((48*1024 / core::mem::size_of::<Self::Event>()) / 32) * 32;
8392
}
8493

8594
// Only speciation is reported: events are stored as `SpeciationEvent`s.
impl EventType for EventBuffer<True, False> {
8695
type Event = SpeciationEvent;
96+
97+
// Number of events that fit into `48*1024` bytes, rounded down to a
// multiple of 32.
// NOTE(review): not guaranteed to be a power of two -- confirm that the
// shared bitonic step does not rely on one.
const SHARED_LIMIT: usize = ((48*1024 / core::mem::size_of::<Self::Event>()) / 32) * 32;
8798
}
8899

89100
// Both speciation and dispersal are reported: events are `PackedEvent`s.
impl EventType for EventBuffer<True, True> {
90101
type Event = PackedEvent;
102+
103+
// Number of events that fit into `48*1024` bytes, rounded down to a
// multiple of 32.
// NOTE(review): not guaranteed to be a power of two -- confirm that the
// shared bitonic step does not rely on one.
const SHARED_LIMIT: usize = ((48*1024 / core::mem::size_of::<Self::Event>()) / 32) * 32;
91104
}
92105

93106
impl<ReportSpeciation: Boolean, ReportDispersal: Boolean> fmt::Debug
@@ -212,19 +225,238 @@ impl<ReportSpeciation: Boolean, ReportDispersal: Boolean>
212225
impl<ReportSpeciation: Boolean, ReportDispersal: Boolean>
213226
EventBuffer<ReportSpeciation, ReportDispersal>
214227
{
228+
/// Bitonic sort combined merge step for shared memory, based on
229+
/// <https://github.com/NVIDIA/cuda-samples/blob/81992093d2b8c33cab22dbf6852c070c330f1715/Samples/2_Concepts_and_Techniques/sortingNetworks/bitonicSort.cu#L179-L220>
230+
///
/// Each thread loads two (mask, event) slots, `SHARED_LIMIT / 2` apart,
/// from the global buffers into `ThreadBlockShared` arrays. The block then
/// runs compare-exchange rounds with strides starting at `SHARED_LIMIT / 2`
/// and halving down to `1`, calling `_syncthreads()` before each round, and
/// finally writes both slots back to the global buffers.
/// A slot whose mask is `false` compares as greater than a slot that holds
/// an event; two such slots compare as equal.
///
231+
/// # Safety
232+
///
233+
/// All CUDA threads must call this method with the same size argument.
234+
/// Only one call per kernel launch is safe without further synchronisation.
235+
///
236+
/// # Panics
237+
///
238+
/// Panics if twice the thread block size does not equal
/// `<Self as EventType>::SHARED_LIMIT`.
239+
pub unsafe fn bitonic_sort_events_shared_step(&mut self, size: usize) where [(); <Self as EventType>::SHARED_LIMIT]: {
240+
use core::cmp::Ordering;
241+
// Dimensions of the calling thread block.
242+
let block_dim = rust_cuda::device::utils::block_dim();
243+
// Every thread handles two slots, so one block must cover exactly
// SHARED_LIMIT slots.
244+
rust_cuda::assert_eq!(block_dim.size() * 2, <Self as EventType>::SHARED_LIMIT);
245+
// Ids of this block within the grid and of this thread within the block,
// as computed by `as_id`.
246+
let block_idx = rust_cuda::device::utils::block_idx().as_id(&rust_cuda::device::utils::grid_dim());
247+
let thread_idx = rust_cuda::device::utils::thread_idx().as_id(&block_dim);
248+
// Global index of this thread's first slot; its second slot lies
// SHARED_LIMIT / 2 further on.
249+
let idx = block_idx * <Self as EventType>::SHARED_LIMIT + thread_idx;
250+
// Uninitialised shared scratch arrays for the masks and the events, both
// accessed through raw element pointers below.
// NOTE(review): together they need
// SHARED_LIMIT * (1 + size_of::<MaybeSome<Event>>()) bytes -- confirm this
// fits the shared memory budget SHARED_LIMIT was derived from.
251+
let shared_mask: rust_cuda::device::ThreadBlockShared<
252+
[bool; <Self as EventType>::SHARED_LIMIT]
253+
> = rust_cuda::device::ThreadBlockShared::new_uninit();
254+
let shared_mask_array: *mut bool = shared_mask.get().cast();
255+
let shared_buffer: rust_cuda::device::ThreadBlockShared<
256+
[MaybeSome<<Self as EventType>::Event>; <Self as EventType>::SHARED_LIMIT]
257+
> = rust_cuda::device::ThreadBlockShared::new_uninit();
258+
let shared_buffer_array: *mut MaybeSome<<Self as EventType>::Event> = shared_buffer.get().cast();
259+
// Load this thread's two slots. Indices past the end of the global
// buffers are loaded as empty slots (`false` / `MaybeSome::None`).
260+
*shared_mask_array.add(thread_idx) = match self.event_mask.alias_unchecked().get(idx) {
261+
None => false,
262+
Some(mask) => *mask.read(),
263+
};
264+
*shared_buffer_array.add(thread_idx) = match self.event_buffer.alias_unchecked().get(idx) {
265+
None => MaybeSome::None,
266+
Some(event) => event.as_uninit().assume_init_read(),
267+
};
268+
*shared_mask_array.add(thread_idx + (<Self as EventType>::SHARED_LIMIT / 2)) = match self.event_mask.alias_unchecked().get(idx + (<Self as EventType>::SHARED_LIMIT / 2)) {
269+
None => false,
270+
Some(mask) => *mask.read(),
271+
};
272+
*shared_buffer_array.add(thread_idx + (<Self as EventType>::SHARED_LIMIT / 2)) = match self.event_buffer.alias_unchecked().get(idx + (<Self as EventType>::SHARED_LIMIT / 2)) {
273+
None => MaybeSome::None,
274+
Some(event) => event.as_uninit().assume_init_read(),
275+
};
276+
// Comparator index of this thread, masked with
// `len.next_power_of_two() / 2 - 1`; its `size / 2` bit selects the
// direction of the merge.
277+
let pos = (block_idx * block_dim.size() + thread_idx) & ((self.event_mask.alias_unchecked().len().next_power_of_two() / 2) - 1);
// `Greater` swaps pairs where a > b (ascending order), `Less` swaps pairs
// where a < b (descending order); see the `cmp == dir` test below.
278+
let dir = if (pos & (size / 2)) == 0 {
279+
Ordering::Greater
280+
} else {
281+
Ordering::Less
282+
};
283+
// NOTE(review): `stride - 1` is used as a low-bit mask below, which
// presumably requires SHARED_LIMIT to be a power of two -- confirm.
284+
let mut stride = <Self as EventType>::SHARED_LIMIT >> 1;
285+
286+
while stride > 0 {
// Block-wide barrier before every round, including the first one that
// follows the initial load.
287+
::core::arch::nvptx::_syncthreads();
288+
// Indices of the two shared slots this thread compares in this round,
// `stride` apart.
289+
let pos_a = 2 * thread_idx - (thread_idx & (stride - 1));
290+
let pos_b = pos_a + stride;
291+
292+
let mask_a: bool = *shared_mask_array.add(pos_a);
293+
let mask_b: bool = *shared_mask_array.add(pos_b);
294+
// Order the two slots: an empty slot counts as greater than a filled
// one, and events are only compared when both slots are filled.
295+
let cmp = match (mask_a, mask_b) {
296+
(false, false) => Ordering::Equal,
297+
(false, true) => Ordering::Greater,
298+
(true, false) => Ordering::Less,
299+
(true, true) => {
300+
// Safety: both masks indicate that the two events exist
301+
let event_a: &<Self as EventType>::Event = unsafe {
302+
(*shared_buffer_array.add(pos_a)).assume_some_ref()
303+
};
304+
let event_b: &<Self as EventType>::Event = unsafe {
305+
(*shared_buffer_array.add(pos_b)).assume_some_ref()
306+
};
307+
308+
event_a.cmp(event_b)
309+
},
310+
};
311+
// The pair is out of order for this direction: exchange both the masks
// and the events.
312+
if cmp == dir {
313+
*shared_mask_array.add(pos_a) = mask_b;
314+
*shared_mask_array.add(pos_b) = mask_a;
315+
316+
let ptr_a: *mut u64 = shared_buffer_array.add(pos_a).cast();
317+
let ptr_b: *mut u64 = shared_buffer_array.add(pos_b).cast();
318+
319+
// Manual swap implementation that can be unrolled without local memory
320+
// Safety: AlignedToU64 guarantees that both events are aligned to u64
321+
// and can be copied as multiples of u64
322+
for i in 0..(core::mem::size_of::<<Self as EventType>::Event>() / 8) {
323+
let swap = *ptr_a.add(i);
324+
*ptr_a.add(i) = *ptr_b.add(i);
325+
*ptr_b.add(i) = swap;
326+
}
327+
}
328+
329+
stride >>= 1;
330+
}
331+
// Barrier after the last round, so the write-back below follows all
// exchanges.
332+
::core::arch::nvptx::_syncthreads();
333+
// Store this thread's two slots back; indices outside the global buffers
// are skipped.
334+
if let Some(mask) = self.event_mask.alias_mut_unchecked().get_mut(idx) {
335+
mask.write(*shared_mask_array.add(thread_idx));
336+
}
337+
if let Some(event) = self.event_buffer.alias_mut_unchecked().get_mut(idx) {
338+
event.write(core::ptr::read(shared_buffer_array.add(thread_idx)));
339+
}
340+
if let Some(mask) = self.event_mask.alias_mut_unchecked().get_mut(idx + (<Self as EventType>::SHARED_LIMIT / 2)) {
341+
mask.write(*shared_mask_array.add(thread_idx + (<Self as EventType>::SHARED_LIMIT / 2)));
342+
}
343+
if let Some(event) = self.event_buffer.alias_mut_unchecked().get_mut(idx + (<Self as EventType>::SHARED_LIMIT / 2)) {
344+
event.write(core::ptr::read(shared_buffer_array.add(thread_idx + (<Self as EventType>::SHARED_LIMIT / 2))));
345+
}
346+
}
347+
348+
/// Bitonic sort single merge step for global memory, based on
349+
/// <https://github.com/NVIDIA/cuda-samples/blob/81992093d2b8c33cab22dbf6852c070c330f1715/Samples/2_Concepts_and_Techniques/sortingNetworks/bitonicSort.cu#L154-L177>
350+
///
/// Each thread compares the two slots at `pos_a` and `pos_a + stride` of
/// the global mask and event buffers and exchanges them if they are out of
/// order for the direction selected by `size`. Pairs that do not lie fully
/// inside the buffer are left untouched. A slot whose mask is `false`
/// compares as greater than a slot that holds an event.
///
/// NOTE(review): the code computes `stride - 1`, so `stride` is expected to
/// be at least 1 (presumably a power of two) -- confirm against the caller.
///
351+
/// # Safety
352+
///
353+
/// All CUDA threads must call this method with the same size and stride arguments.
354+
/// Only one call per kernel launch is safe without further synchronisation.
355+
pub unsafe fn bitonic_sort_events_step(&mut self, size: usize, stride: usize) {
356+
use core::cmp::Ordering;
357+
// Index of the calling thread as reported by
// `rust_cuda::device::utils::index()`.
358+
let idx = rust_cuda::device::utils::index();
359+
// Comparator index, masked with `len.next_power_of_two() / 2 - 1`; its
// `size / 2` bit selects the direction of the merge.
360+
let pos = idx & ((self.event_mask.alias_unchecked().len().next_power_of_two() / 2) - 1);
361+
// `Greater` swaps pairs where a > b (ascending order), `Less` swaps pairs
// where a < b (descending order); see the `cmp == dir` test below.
362+
let dir = if (pos & (size / 2)) == 0 {
363+
Ordering::Greater
364+
} else {
365+
Ordering::Less
366+
};
367+
// Positions of the two slots compared by this thread, `stride` apart.
368+
let pos_a = 2 * idx - (idx & (stride - 1));
369+
let pos_b = pos_a + stride;
370+
// Only act when both positions are inside the mask buffer.
// NOTE(review): the unchecked event accesses below assume that
// `event_buffer` is at least as long as `event_mask` -- TODO confirm.
371+
if (pos_a < self.event_mask.alias_unchecked().len())
372+
&& (pos_b < self.event_mask.alias_unchecked().len())
373+
{
374+
let mask_a: bool = *self
375+
.event_mask
376+
.alias_unchecked()
377+
.get_unchecked(pos_a)
378+
.read();
379+
let mask_b: bool = *self
380+
.event_mask
381+
.alias_unchecked()
382+
.get_unchecked(pos_b)
383+
.read();
384+
// Order the two slots: an empty slot counts as greater than a filled
// one, and events are only compared when both slots are filled.
385+
let cmp = match (mask_a, mask_b) {
386+
(false, false) => Ordering::Equal,
387+
(false, true) => Ordering::Greater,
388+
(true, false) => Ordering::Less,
389+
(true, true) => {
390+
// Safety: both masks indicate that the two events exist
391+
let event_a: &<Self as EventType>::Event = unsafe {
392+
self.event_buffer
393+
.alias_unchecked()
394+
.get_unchecked(pos_a)
395+
.as_uninit()
396+
.assume_init_ref()
397+
.assume_some_ref()
398+
};
399+
let event_b: &<Self as EventType>::Event = unsafe {
400+
self.event_buffer
401+
.alias_unchecked()
402+
.get_unchecked(pos_b)
403+
.as_uninit()
404+
.assume_init_ref()
405+
.assume_some_ref()
406+
};
407+
408+
event_a.cmp(event_b)
409+
},
410+
};
411+
// The pair is out of order for this direction: exchange both the masks
// and the events.
412+
if cmp == dir {
413+
self.event_mask
414+
.alias_mut_unchecked()
415+
.get_unchecked_mut(pos_a)
416+
.write(mask_b);
417+
self.event_mask
418+
.alias_mut_unchecked()
419+
.get_unchecked_mut(pos_b)
420+
.write(mask_a);
421+
422+
let ptr_a: *mut u64 = self
423+
.event_buffer
424+
.alias_mut_unchecked()
425+
.as_mut_ptr()
426+
.add(pos_a)
427+
.cast();
428+
let ptr_b: *mut u64 = self
429+
.event_buffer
430+
.alias_mut_unchecked()
431+
.as_mut_ptr()
432+
.add(pos_b)
433+
.cast();
434+
435+
// Manual swap implementation that can be unrolled without local memory
436+
// Safety: AlignedToU64 guarantees that both events are aligned to u64
437+
// and can be copied as multiples of u64
438+
for i in 0..(core::mem::size_of::<<Self as EventType>::Event>() / 8) {
439+
let swap = *ptr_a.add(i);
440+
*ptr_a.add(i) = *ptr_b.add(i);
441+
*ptr_b.add(i) = swap;
442+
}
443+
}
444+
}
445+
}
446+
215447
#[allow(clippy::too_many_lines)]
448+
/// Odd-Even sort single merge step for global memory, based on
449+
/// <https://github.com/NVIDIA/cuda-samples/blob/81992093d2b8c33cab22dbf6852c070c330f1715/Samples/2_Concepts_and_Techniques/sortingNetworks/oddEvenMergeSort.cu#L95-L137>
450+
///
216451
/// # Safety
217452
///
218-
/// All CUDA threads must call this method with the same size, stride, and
219-
/// direction arguments. Only one call per kernel launch is safe without
220-
/// further synchronisation.
221-
pub unsafe fn sort_events_step(&mut self, size: usize, stride: usize) {
453+
/// All CUDA threads must call this method with the same size and stride arguments.
454+
/// Only one call per kernel launch is safe without further synchronisation.
455+
pub unsafe fn odd_even_sort_events_step(&mut self, size: usize, stride: usize) {
222456
use core::cmp::Ordering;
223457

224458
let idx = rust_cuda::device::utils::index();
225459

226-
// Odd-Even merge based on
227-
// https://github.com/NVIDIA/cuda-samples/blob/81992093d2b8c33cab22dbf6852c070c330f1715/Samples/2_Concepts_and_Techniques/sortingNetworks/oddEvenMergeSort.cu#L95-L137
228460
let pos = 2 * idx - (idx & (stride - 1));
229461

230462
let (pos_a, pos_b) = if stride < (size / 2) {

necsim/impls/cuda/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#![feature(const_refs_to_cell)]
88
#![feature(generic_const_exprs)]
99
#![cfg_attr(target_os = "cuda", feature(asm_experimental_arch))]
10+
#![cfg_attr(target_os = "cuda", feature(stdsimd))]
1011
#![allow(incomplete_features)]
1112
#![feature(specialization)]
1213

necsim/impls/cuda/src/utils.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use rust_cuda::safety::StackOnly;
88
pub struct MaybeSome<T: StackOnly>(MaybeUninit<T>);
99

1010
impl<T: StackOnly> MaybeSome<T> {
11-
#[cfg(not(target_os = "cuda"))]
1211
#[allow(non_upper_case_globals)]
1312
pub(crate) const None: Self = Self(MaybeUninit::uninit());
1413

necsim/impls/no-std/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ rand_core = "0.6"
3131
rand_distr = { version = "0.4", default-features = false, features = [] }
3232

3333
[target.'cfg(target_os = "cuda")'.dependencies]
34-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive"], optional = true }
34+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive"], optional = true }
3535

3636
[target.'cfg(not(target_os = "cuda"))'.dependencies]
37-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["derive", "host"], optional = true }
37+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["derive", "host"], optional = true }

rustcoalescence/algorithms/cuda/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ thiserror = "1.0"
2323
serde = { version = "1.0", features = ["derive"] }
2424
serde_state = "0.4"
2525
serde_derive_state = "0.4"
26-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["host"] }
26+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["host"] }

rustcoalescence/algorithms/cuda/cpu-kernel/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ necsim-impls-no-std = { path = "../../../../necsim/impls/no-std", features = ["c
1414
necsim-impls-cuda = { path = "../../../../necsim/impls/cuda" }
1515
rustcoalescence-algorithms-cuda-gpu-kernel = { path = "../gpu-kernel" }
1616

17-
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "d5dfd114", features = ["host"] }
17+
rust-cuda = { git = "https://github.com/juntyr/rust-cuda", rev = "c56d7f8f", features = ["host"] }

0 commit comments

Comments
 (0)