11use core:: fmt;
22
3+ use const_type_layout:: TypeGraphLayout ;
34#[ cfg( not( target_os = "cuda" ) ) ]
45use rust_cuda:: deps:: rustacuda:: {
56 error:: CudaResult ,
67 function:: { BlockSize , GridSize } ,
78} ;
89
9- use rust_cuda:: utils:: {
10- aliasing:: SplitSliceOverCudaThreadsDynamicStride , exchange:: buffer:: CudaExchangeBuffer ,
10+ use rust_cuda:: {
11+ lend:: RustToCudaProxy ,
12+ safety:: { PortableBitSemantics , SafeMutableAliasing , StackOnly } ,
13+ utils:: {
14+ aliasing:: SplitSliceOverCudaThreadsDynamicStride ,
15+ exchange:: buffer:: { CudaExchangeBuffer , CudaExchangeItem } ,
16+ } ,
1117} ;
1218
1319use necsim_core:: {
@@ -27,8 +33,13 @@ use super::utils::MaybeSome;
2733#[ derive( rust_cuda:: lend:: LendRustToCuda ) ]
2834#[ cuda( free = "ReportSpeciation" , free = "ReportDispersal" ) ]
2935pub struct EventBuffer < ReportSpeciation : Boolean , ReportDispersal : Boolean > {
36+ #[ cfg( not( target_os = "cuda" ) ) ]
3037 #[ cuda( embed) ]
3138 event_mask : SplitSliceOverCudaThreadsDynamicStride < CudaExchangeBuffer < bool , true , true > > ,
39+ #[ cfg( target_os = "cuda" ) ]
40+ #[ cuda( embed = "SplitSliceOverCudaThreadsDynamicStride<CudaExchangeBuffer<bool, true, true>>" ) ]
41+ event_mask : CudaExchangeSlice < CudaExchangeItem < bool , true , true > > ,
42+ #[ cfg( not( target_os = "cuda" ) ) ]
3243 #[ cuda( embed) ]
3344 event_buffer : SplitSliceOverCudaThreadsDynamicStride <
3445 CudaExchangeBuffer <
@@ -37,8 +48,41 @@ pub struct EventBuffer<ReportSpeciation: Boolean, ReportDispersal: Boolean> {
3748 true ,
3849 > ,
3950 > ,
40- max_events : usize ,
41- event_counter : usize ,
51+ #[ cfg( target_os = "cuda" ) ]
52+ #[ cuda( embed = "SplitSliceOverCudaThreadsDynamicStride<
53+ CudaExchangeBuffer<
54+ MaybeSome<<EventBuffer<ReportSpeciation, ReportDispersal> as EventType>::Event>,
55+ false,
56+ true,
57+ >,
58+ >" ) ]
59+ event_buffer : CudaExchangeSlice <
60+ CudaExchangeItem <
61+ MaybeSome < <EventBuffer < ReportSpeciation , ReportDispersal > as EventType >:: Event > ,
62+ false ,
63+ true ,
64+ > ,
65+ > ,
66+ }
67+
68+ // Safety:
69+ // - no mutable aliasing occurs since all parts implement SafeMutableAliasing
70+ // - dropping does not trigger (de)alloc since EventBuffer doesn't impl Drop and
71+ // all parts implement SafeMutableAliasing
72+ // - EventBuffer has no shallow mutable state
73+ unsafe impl < ReportSpeciation : Boolean , ReportDispersal : Boolean > SafeMutableAliasing
74+ for EventBuffer < ReportSpeciation , ReportDispersal >
75+ where
76+ SplitSliceOverCudaThreadsDynamicStride < CudaExchangeBuffer < bool , true , true > > :
77+ SafeMutableAliasing ,
78+ SplitSliceOverCudaThreadsDynamicStride <
79+ CudaExchangeBuffer <
80+ MaybeSome < <EventBuffer < ReportSpeciation , ReportDispersal > as EventType >:: Event > ,
81+ false ,
82+ true ,
83+ > ,
84+ > : SafeMutableAliasing ,
85+ {
4286}
4387
4488pub trait EventType {
@@ -78,10 +122,7 @@ impl<ReportSpeciation: Boolean, ReportDispersal: Boolean> fmt::Debug
78122 for EventBuffer < ReportSpeciation , ReportDispersal >
79123{
80124 fn fmt ( & self , fmt : & mut fmt:: Formatter ) -> fmt:: Result {
81- fmt. debug_struct ( "EventBuffer" )
82- . field ( "max_events" , & self . max_events )
83- . field ( "event_counter" , & self . event_counter )
84- . finish_non_exhaustive ( )
125+ fmt. debug_struct ( "EventBuffer" ) . finish_non_exhaustive ( )
85126 }
86127}
87128
@@ -122,8 +163,6 @@ impl<ReportSpeciation: Boolean, ReportDispersal: Boolean>
122163 CudaExchangeBuffer :: from_vec ( event_buffer) ?,
123164 max_events,
124165 ) ,
125- max_events,
126- event_counter : 0_usize ,
127166 } )
128167 }
129168
@@ -148,9 +187,26 @@ impl<ReportSpeciation: Boolean, ReportDispersal: Boolean>
148187 mask. write ( false ) ;
149188 }
150189 }
190+ }
151191
152- pub fn max_events_per_individual ( & self ) -> usize {
153- self . max_events
#[cfg(target_os = "cuda")]
impl<ReportSpeciation, ReportDispersal> EventBuffer<ReportSpeciation, ReportDispersal>
where
    ReportSpeciation: Boolean,
    ReportDispersal: Boolean,
{
    /// Records `event` in the next free slot of the buffer.
    ///
    /// The first remaining mask entry is set to `true`, the first remaining
    /// buffer entry receives `MaybeSome::Some(event.into())`, and both slices
    /// are then advanced past the slot that was just written.
    ///
    /// If either slice has no entry left, nothing is written and both slices
    /// end up empty, because `core::mem::take` has already replaced them with
    /// their (empty) default.
    fn report_event(&mut self, event: impl Into<<Self as EventType>::Event>) {
        // Move the remaining slices out of `self` so that they can be split
        // without holding a borrow of `self`.
        let remaining_masks = core::mem::take(&mut self.event_mask.0);
        let remaining_slots = core::mem::take(&mut self.event_buffer.0);

        let ([mask, masks_tail @ ..], [slot, slots_tail @ ..]) =
            (remaining_masks, remaining_slots)
        else {
            // No free slot is left: the event is dropped.
            return;
        };

        mask.write(true);
        slot.write(MaybeSome::Some(event.into()));

        // Continue with whatever comes after the slot that was just filled.
        self.event_mask.0 = masks_tail;
        self.event_buffer.0 = slots_tail;
    }
}
156212
@@ -169,19 +225,11 @@ impl<ReportSpeciation: Boolean, ReportDispersal: Boolean> Reporter
169225impl Reporter for EventBuffer < False , True > {
170226 impl_report ! (
171227 #[ debug_requires(
172- self . event_counter < self . max_events ,
228+ ! self . event_buffer . 0 . is_empty ( ) ,
173229 "does not report extraneous dispersal events"
174230 ) ]
175231 dispersal( & mut self , event: Used ) {
176- if let Some ( mask) = self . event_mask. get_mut( self . event_counter) {
177- mask. write( true ) ;
178-
179- unsafe {
180- self . event_buffer. get_unchecked_mut( self . event_counter)
181- } . write( MaybeSome :: Some ( event. clone( ) . into( ) ) ) ;
182- }
183-
184- self . event_counter += 1 ;
232+ self . report_event( event. clone( ) ) ;
185233 }
186234 ) ;
187235}
@@ -190,19 +238,14 @@ impl Reporter for EventBuffer<False, True> {
190238impl Reporter for EventBuffer < True , False > {
191239 impl_report ! (
192240 #[ debug_requires(
193- self . event_counter == 0 ,
241+ ! self . event_buffer . 0 . is_empty ( ) ,
194242 "does not report extraneous speciation events"
195243 ) ]
196244 speciation( & mut self , event: Used ) {
197- if let Some ( mask) = self . event_mask. get_mut( 0 ) {
198- mask. write( true ) ;
245+ self . report_event( event. clone( ) ) ;
199246
200- unsafe {
201- self . event_buffer. get_unchecked_mut( 0 )
202- } . write( MaybeSome :: Some ( event. clone( ) ) ) ;
203- }
204-
205- self . event_counter = self . max_events;
247+ self . event_mask. 0 = & mut [ ] ;
248+ self . event_buffer. 0 = & mut [ ] ;
206249 }
207250 ) ;
208251}
@@ -211,37 +254,57 @@ impl Reporter for EventBuffer<True, False> {
211254impl Reporter for EventBuffer < True , True > {
212255 impl_report ! (
213256 #[ debug_requires(
214- self . event_counter < self . max_events ,
257+ ! self . event_buffer . 0 . is_empty ( ) ,
215258 "does not report extraneous speciation events"
216259 ) ]
217260 speciation( & mut self , event: Used ) {
218- if let Some ( mask) = self . event_mask. get_mut( self . event_counter) {
219- mask. write( true ) ;
261+ self . report_event( event. clone( ) ) ;
220262
221- unsafe {
222- self . event_buffer. get_unchecked_mut( self . event_counter)
223- } . write( MaybeSome :: Some ( event. clone( ) . into( ) ) ) ;
224- }
225-
226- self . event_counter = self . max_events;
263+ self . event_mask. 0 = & mut [ ] ;
264+ self . event_buffer. 0 = & mut [ ] ;
227265 }
228266 ) ;
229267
230268 impl_report ! (
231269 #[ debug_requires(
232- self . event_counter < self . max_events ,
270+ ! self . event_buffer . 0 . is_empty ( ) ,
233271 "does not report extraneous dispersal events"
234272 ) ]
235273 dispersal( & mut self , event: Used ) {
236- if let Some ( mask) = self . event_mask. get_mut( self . event_counter) {
237- mask. write( true ) ;
238-
239- unsafe {
240- self . event_buffer. get_unchecked_mut( self . event_counter)
241- } . write( MaybeSome :: Some ( event. clone( ) . into( ) ) ) ;
242- }
243-
244- self . event_counter += 1 ;
274+ self . report_event( event. clone( ) ) ;
245275 }
246276 ) ;
247277}
278+
279+ // FIXME: find a less hacky hack
280+ struct CudaExchangeSlice < T : ' static + StackOnly + PortableBitSemantics + TypeGraphLayout > (
281+ & ' static mut [ T ] ,
282+ ) ;
283+
284+ impl <
285+ T : ' static + StackOnly + PortableBitSemantics + TypeGraphLayout ,
286+ const M2D : bool ,
287+ const M2H : bool ,
288+ > RustToCudaProxy < CudaExchangeSlice < CudaExchangeItem < T , M2D , M2H > > >
289+ for SplitSliceOverCudaThreadsDynamicStride < CudaExchangeBuffer < T , M2D , M2H > >
290+ {
291+ fn from_ref ( _val : & CudaExchangeSlice < CudaExchangeItem < T , M2D , M2H > > ) -> & Self {
292+ unsafe { unreachable_cuda_event_buffer_hack ( ) }
293+ }
294+
295+ fn from_mut ( _val : & mut CudaExchangeSlice < CudaExchangeItem < T , M2D , M2H > > ) -> & mut Self {
296+ unsafe { unreachable_cuda_event_buffer_hack ( ) }
297+ }
298+
299+ fn into ( mut self ) -> CudaExchangeSlice < CudaExchangeItem < T , M2D , M2H > > {
300+ let slice: & mut [ CudaExchangeItem < T , M2D , M2H > ] = & mut self ;
301+
302+ let slice = unsafe { core:: slice:: from_raw_parts_mut ( slice. as_mut_ptr ( ) , slice. len ( ) ) } ;
303+
304+ CudaExchangeSlice ( slice)
305+ }
306+ }
307+
// Diverging external function called by the unsupported `from_ref` and
// `from_mut` proxy conversions.
//
// NOTE(review): no definition is visible in this file; presumably the symbol
// is left undefined on purpose so that a call which is not optimised away is
// caught at link time -- confirm.
extern "C" {
    fn unreachable_cuda_event_buffer_hack() -> !;
}
0 commit comments