@@ -22,15 +22,25 @@ use wgpu::util::DeviceExt;
2222
2323use bytemuck;
2424
// Selects which synchronization variant is tested. The value is uploaded as the
// first u32 of the config buffer (binding 1) and echoed back in the final
// "failures with strategy" report.
25+ // A strategy of 0 is just atomic loads.
26+ // A strategy of 1 replaces the flag load with an atomicOr.
27+ const STRATEGY : u32 = 0 ;
28+
// When true, the shader module is built from the precompiled `shader.spv` via
// SPIR-V passthrough (and SPIRV_SHADER_PASSTHROUGH is added to the requested
// feature mask); when false, `shader.wgsl` is compiled instead.
29+ const USE_SPIRV : bool = false ;
30+
/// Runs the GPU experiment end to end: acquires an adapter and device, builds the
/// compute pipeline from `shader.spv` or `shader.wgsl` (per `USE_SPIRV`), clears the
/// data buffer, dispatches 256 workgroups, then maps the config buffer and prints its
/// two u32 words as "failures with strategy <word0>: <word1>". If the adapter reports
/// TIMESTAMP_QUERY, the elapsed time between the two timestamps is printed as well.
///
/// # Panics
/// Panics if no adapter is available (`request_adapter(..).await.unwrap()`).
2531async fn run ( ) {
2632 let instance = wgpu:: Instance :: new ( wgpu:: Backends :: PRIMARY ) ;
2733 let adapter = instance. request_adapter ( & Default :: default ( ) ) . await . unwrap ( ) ;
2834 let features = adapter. features ( ) ;
    // Optional features this program would like; only those the adapter actually
    // reports are requested, because the mask is ANDed with `features` below.
35+ let mut feature_mask = wgpu:: Features :: TIMESTAMP_QUERY | wgpu:: Features :: CLEAR_COMMANDS ;
36+ if USE_SPIRV {
37+ feature_mask |= wgpu:: Features :: SPIRV_SHADER_PASSTHROUGH ;
38+ }
2939 let ( device, queue) = adapter
3040 . request_device (
3141 & wgpu:: DeviceDescriptor {
3242 label : None ,
33- features : features & wgpu :: Features :: TIMESTAMP_QUERY ,
43+ features : features & feature_mask ,
3444 limits : Default :: default ( ) ,
3545 } ,
3646 None ,
@@ -48,26 +58,33 @@ async fn run() {
4858 } ;
4959
    // Wall-clock timer used only to report shader module creation time.
5060 let start_instant = Instant :: now ( ) ;
51- let cs_module = device. create_shader_module ( & wgpu:: ShaderModuleDescriptor {
52- label : None ,
53- //source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
54- source : wgpu:: ShaderSource :: Wgsl ( include_str ! ( "shader.wgsl" ) . into ( ) ) ,
55- } ) ;
61+ let cs_module = if USE_SPIRV {
    // NOTE(review): `include_bytes!` yields a plain byte array, so 4-byte alignment
    // is not guaranteed; `bytemuck::cast_slice` panics on a misaligned input or a
    // length that is not a multiple of 4 — confirm, or copy into a Vec<u32> first.
62+ let shader_src: & [ u32 ] = bytemuck:: cast_slice ( include_bytes ! ( "shader.spv" ) ) ;
    // NOTE(review): the passthrough path hands the words to the backend as-is;
    // presumably `shader.spv` is a trusted build artifact — TODO confirm.
63+ unsafe {
64+ device. create_shader_module_spirv ( & wgpu:: ShaderModuleDescriptorSpirV {
65+ label : None ,
66+ source : std:: borrow:: Cow :: Owned ( shader_src. into ( ) ) ,
67+ } )
68+ }
69+ } else {
70+ device. create_shader_module ( & wgpu:: ShaderModuleDescriptor {
71+ label : None ,
72+ source : wgpu:: ShaderSource :: Wgsl ( include_str ! ( "shader.wgsl" ) . into ( ) ) ,
73+ } )
74+ } ;
75+
76+
5677 println ! ( "shader compilation {:?}" , start_instant. elapsed( ) ) ;
57- let input_f = & [ 1.0f32 , 2.0f32 ] ;
58- let input : & [ u8 ] = bytemuck:: bytes_of ( input_f) ;
59- let input_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
    // Storage buffer bound at binding 0: 0x80000 bytes (512 KiB), not initialized
    // here; it is cleared by `clear_buffer` in the command encoder before the dispatch.
78+ let data_buf = device. create_buffer ( & wgpu:: BufferDescriptor {
6079 label : None ,
61- contents : input,
62- usage : wgpu:: BufferUsages :: STORAGE
63- | wgpu:: BufferUsages :: COPY_DST
64- | wgpu:: BufferUsages :: COPY_SRC ,
80+ size : 0x80000 ,
81+ usage : wgpu:: BufferUsages :: STORAGE | wgpu:: BufferUsages :: COPY_DST ,
82+ mapped_at_creation : false ,
6583 } ) ;
66- let output_buf = device. create_buffer ( & wgpu:: BufferDescriptor {
    // Storage buffer bound at binding 1, initialized to the two u32 words
    // [STRATEGY, 0]; it is mapped and read back directly after submission.
    // NOTE(review): wgpu documents that MAP_READ may only be combined with COPY_DST
    // unless Features::MAPPABLE_PRIMARY_BUFFERS is enabled, and the feature mask
    // above does not request it — confirm STORAGE | MAP_READ passes validation.
84+ let config_buf = device. create_buffer_init ( & wgpu:: util :: BufferInitDescriptor {
6785 label : None ,
68- size : input. len ( ) as u64 ,
69- usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
70- mapped_at_creation : false ,
86+ contents : bytemuck:: bytes_of ( & [ STRATEGY , 0 ] ) ,
87+ usage : wgpu:: BufferUsages :: STORAGE | wgpu:: BufferUsages :: MAP_READ ,
6988 } ) ;
7289 // This works if the buffer is initialized, otherwise reads all 0, for some reason.
7390 let query_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
@@ -76,62 +93,98 @@ async fn run() {
7693 usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
7794 } ) ;
7895
    // Explicit layout: two read-write storage buffers (bindings 0 and 1), both
    // visible to the compute stage only. This replaces the layout previously
    // obtained from the pipeline via `get_bind_group_layout(0)`.
96+ let bind_group_layout = device. create_bind_group_layout ( & wgpu:: BindGroupLayoutDescriptor {
97+ label : None ,
98+ entries : & [ wgpu:: BindGroupLayoutEntry {
99+ binding : 0 ,
100+ visibility : wgpu:: ShaderStages :: COMPUTE ,
101+ ty : wgpu:: BindingType :: Buffer {
102+ ty : wgpu:: BufferBindingType :: Storage { read_only : false } ,
103+ has_dynamic_offset : false ,
104+ min_binding_size : None ,
105+ } ,
106+ count : None ,
107+ } ,
108+ wgpu:: BindGroupLayoutEntry {
109+ binding : 1 ,
110+ visibility : wgpu:: ShaderStages :: COMPUTE ,
111+ ty : wgpu:: BindingType :: Buffer {
112+ ty : wgpu:: BufferBindingType :: Storage { read_only : false } ,
113+ has_dynamic_offset : false ,
114+ min_binding_size : None ,
115+ } ,
116+ count : None ,
117+ } ] ,
118+ } ) ;
    // Pipeline layout with the single bind group above and no push constants.
119+ let compute_pipeline_layout = device. create_pipeline_layout ( & wgpu:: PipelineLayoutDescriptor {
120+ label : None ,
121+ bind_group_layouts : & [ & bind_group_layout] ,
122+ push_constant_ranges : & [ ] ,
123+ } ) ;
79124 let pipeline = device. create_compute_pipeline ( & wgpu:: ComputePipelineDescriptor {
80125 label : None ,
81- layout : None ,
126+ layout : Some ( & compute_pipeline_layout ) ,
82127 module : & cs_module,
83128 entry_point : "main" ,
84129 } ) ;
85130
86- let bind_group_layout = pipeline. get_bind_group_layout ( 0 ) ;
    // Binds data_buf at binding 0 and config_buf at binding 1, matching the layout.
87131 let bind_group = device. create_bind_group ( & wgpu:: BindGroupDescriptor {
88132 label : None ,
89133 layout : & bind_group_layout,
90- entries : & [ wgpu:: BindGroupEntry {
91- binding : 0 ,
92- resource : input_buf. as_entire_binding ( ) ,
93- } ] ,
134+ entries : & [
135+ wgpu:: BindGroupEntry {
136+ binding : 0 ,
137+ resource : data_buf. as_entire_binding ( ) ,
138+ } ,
139+ wgpu:: BindGroupEntry {
140+ binding : 1 ,
141+ resource : config_buf. as_entire_binding ( ) ,
142+ } ,
143+ ] ,
94144 } ) ;
95145
    // Command recording order: timestamp 0, clear data_buf, compute pass,
    // timestamp 1, resolve both timestamps into query_buf.
96146 let mut encoder = device. create_command_encoder ( & Default :: default ( ) ) ;
97147 if let Some ( query_set) = & query_set {
98148 encoder. write_timestamp ( query_set, 0 ) ;
99149 }
    // Clears the whole of data_buf (offset 0, size None).
    // NOTE(review): this is called unconditionally, but CLEAR_COMMANDS is only
    // requested when the adapter reports it — presumably this fails validation on
    // adapters without the feature; confirm.
150+ encoder. clear_buffer ( & data_buf, 0 , None ) ;
100151 {
101152 let mut cpass = encoder. begin_compute_pass ( & Default :: default ( ) ) ;
102153 cpass. set_pipeline ( & pipeline) ;
103154 cpass. set_bind_group ( 0 , & bind_group, & [ ] ) ;
104- cpass. dispatch ( input_f . len ( ) as u32 , 1 , 1 ) ;
    // Fixed launch size: 256 workgroups along x.
155+ cpass. dispatch ( 256 , 1 , 1 ) ;
105156 }
106157 if let Some ( query_set) = & query_set {
107158 encoder. write_timestamp ( query_set, 1 ) ;
108159 }
109- encoder. copy_buffer_to_buffer ( & input_buf, 0 , & output_buf, 0 , input. len ( ) as u64 ) ;
160+ // encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
110161 if let Some ( query_set) = & query_set {
111162 encoder. resolve_query_set ( query_set, 0 ..2 , & query_buf, 0 ) ;
112163 }
113164 queue. submit ( Some ( encoder. finish ( ) ) ) ;
114165
115- let buf_slice = output_buf . slice ( ..) ;
    // Results are read straight from config_buf; there is no staging copy.
166+ let buf_slice = config_buf . slice ( ..) ;
116167 let buf_future = buf_slice. map_async ( wgpu:: MapMode :: Read ) ;
117168 let query_slice = query_buf. slice ( ..) ;
    // NOTE(review): this future is never awaited; the later `get_mapped_range` on
    // query_slice presumably relies on `device.poll(Maintain::Wait)` having
    // completed the mapping — confirm.
118169 let _query_future = query_slice. map_async ( wgpu:: MapMode :: Read ) ;
119- println ! ( "pre-poll {:?}" , std:: time:: Instant :: now( ) ) ;
120170 device. poll ( wgpu:: Maintain :: Wait ) ;
121- println ! ( "post-poll {:?}" , std:: time:: Instant :: now( ) ) ;
    // A failed mapping is ignored silently: nothing is printed in that case.
122171 if buf_future. await . is_ok ( ) {
123172 let data_raw = & * buf_slice. get_mapped_range ( ) ;
124- let data : & [ f32 ] = bytemuck:: cast_slice ( data_raw) ;
125- println ! ( "data: {:? }" , & * data) ;
    // data[0] is the first config word (initialized to STRATEGY) and data[1] the
    // second (initialized to 0), printed as the failure count.
173+ let data: & [ u32 ] = bytemuck:: cast_slice ( data_raw) ;
174+ println ! ( "failures with strategy {}: { }" , data[ 0 ] , data [ 1 ] ) ;
126175 }
127176 if features. contains ( wgpu:: Features :: TIMESTAMP_QUERY ) {
128177 let ts_period = queue. get_timestamp_period ( ) ;
129178 let ts_data_raw = & * query_slice. get_mapped_range ( ) ;
130- let ts_data : & [ u64 ] = bytemuck:: cast_slice ( ts_data_raw) ;
131- println ! ( "compute shader elapsed: {:?}ms" , ( ts_data[ 1 ] - ts_data[ 0 ] ) as f64 * ts_period as f64 * 1e-6 ) ;
179+ let ts_data: & [ u64 ] = bytemuck:: cast_slice ( ts_data_raw) ;
    // Tick delta times the timestamp period (presumably nanoseconds per tick, per
    // the wgpu docs), scaled by 1e-6 to give milliseconds.
    // NOTE(review): the u64 subtraction is unchecked; it would overflow if the
    // second timestamp were ever smaller than the first.
180+ println ! (
181+ "compute shader elapsed: {:?}ms" ,
182+ ( ts_data[ 1 ] - ts_data[ 0 ] ) as f64 * ts_period as f64 * 1e-6
183+ ) ;
132184 }
133185}
134186
/// Program entry point: initializes `env_logger`, then blocks the current thread
/// on the async `run` function using `pollster`.
135187fn main ( ) {
    // Logger is set up before `run` executes.
188+ env_logger:: init ( ) ;
136189 pollster:: block_on ( run ( ) ) ;
137190}
0 commit comments