1616
1717//! A simple application to run a compute shader.
1818
19- mod encode;
20-
2119use std:: time:: Instant ;
2220
2321use wgpu:: util:: DeviceExt ;
2422
25- use encode :: Codable ;
23+ use bytemuck ;
2624
2725async fn run ( ) {
28- let instance = wgpu:: Instance :: new ( wgpu:: BackendBit :: PRIMARY ) ;
26+ let instance = wgpu:: Instance :: new ( wgpu:: Backends :: PRIMARY ) ;
2927 let adapter = instance. request_adapter ( & Default :: default ( ) ) . await . unwrap ( ) ;
3028 let features = adapter. features ( ) ;
3129 let ( device, queue) = adapter
@@ -39,17 +37,11 @@ async fn run() {
3937 )
4038 . await
4139 . unwrap ( ) ;
42- let mut shader_flags = wgpu:: ShaderFlags :: VALIDATION ;
43- if matches ! (
44- adapter. get_info( ) . backend,
45- wgpu:: Backend :: Vulkan | wgpu:: Backend :: Metal | wgpu:: Backend :: Gl
46- ) {
47- shader_flags |= wgpu:: ShaderFlags :: EXPERIMENTAL_TRANSLATION ;
48- }
4940 let query_set = if features. contains ( wgpu:: Features :: TIMESTAMP_QUERY ) {
5041 Some ( device. create_query_set ( & wgpu:: QuerySetDescriptor {
5142 count : 2 ,
5243 ty : wgpu:: QueryType :: Timestamp ,
44+ label : None ,
5345 } ) )
5446 } else {
5547 None
@@ -60,28 +52,28 @@ async fn run() {
6052 label : None ,
6153 //source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
6254 source : wgpu:: ShaderSource :: Wgsl ( include_str ! ( "shader.wgsl" ) . into ( ) ) ,
63- flags : shader_flags,
6455 } ) ;
6556 println ! ( "shader compilation {:?}" , start_instant. elapsed( ) ) ;
66- let input: Vec < u8 > = Codable :: encode_vec ( & [ 1.0f32 , 2.0f32 ] ) ;
57+ let input_f = & [ 1.0f32 , 2.0f32 ] ;
58+ let input : & [ u8 ] = bytemuck:: bytes_of ( input_f) ;
6759 let input_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
6860 label : None ,
69- contents : & input,
70- usage : wgpu:: BufferUsage :: STORAGE
71- | wgpu:: BufferUsage :: COPY_DST
72- | wgpu:: BufferUsage :: COPY_SRC ,
61+ contents : input,
62+ usage : wgpu:: BufferUsages :: STORAGE
63+ | wgpu:: BufferUsages :: COPY_DST
64+ | wgpu:: BufferUsages :: COPY_SRC ,
7365 } ) ;
7466 let output_buf = device. create_buffer ( & wgpu:: BufferDescriptor {
7567 label : None ,
7668 size : input. len ( ) as u64 ,
77- usage : wgpu:: BufferUsage :: MAP_READ | wgpu:: BufferUsage :: COPY_DST ,
69+ usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
7870 mapped_at_creation : false ,
7971 } ) ;
8072 // This works if the buffer is initialized, otherwise reads all 0, for some reason.
8173 let query_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
8274 label : None ,
8375 contents : & [ 0 ; 16 ] ,
84- usage : wgpu:: BufferUsage :: MAP_READ | wgpu:: BufferUsage :: COPY_DST ,
76+ usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
8577 } ) ;
8678
8779 let pipeline = device. create_compute_pipeline ( & wgpu:: ComputePipelineDescriptor {
@@ -109,7 +101,7 @@ async fn run() {
109101 let mut cpass = encoder. begin_compute_pass ( & Default :: default ( ) ) ;
110102 cpass. set_pipeline ( & pipeline) ;
111103 cpass. set_bind_group ( 0 , & bind_group, & [ ] ) ;
112- cpass. dispatch ( 1 , 1 , 1 ) ;
104+ cpass. dispatch ( input_f . len ( ) as u32 , 1 , 1 ) ;
113105 }
114106 if let Some ( query_set) = & query_set {
115107 encoder. write_timestamp ( query_set, 1 ) ;
@@ -128,32 +120,18 @@ async fn run() {
128120 device. poll ( wgpu:: Maintain :: Wait ) ;
129121 println ! ( "post-poll {:?}" , std:: time:: Instant :: now( ) ) ;
130122 if buf_future. await . is_ok ( ) {
131- let data = buf_slice. get_mapped_range ( ) ;
123+ let data_raw = & * buf_slice. get_mapped_range ( ) ;
124+ let data : & [ f32 ] = bytemuck:: cast_slice ( data_raw) ;
132125 println ! ( "data: {:?}" , & * data) ;
133126 }
134127 if features. contains ( wgpu:: Features :: TIMESTAMP_QUERY ) {
135128 let ts_period = queue. get_timestamp_period ( ) ;
136- let ts_data: Vec < u64 > = Codable :: decode_vec ( & * query_slice. get_mapped_range ( ) ) ;
137- let ts_data = ts_data
138- . iter ( )
139- . map ( |ts| * ts as f64 * ts_period as f64 * 1e-6 )
140- . collect :: < Vec < _ > > ( ) ;
141- println ! ( "compute shader elapsed: {:?}ms" , ts_data[ 1 ] - ts_data[ 0 ] ) ;
129+ let ts_data_raw = & * query_slice. get_mapped_range ( ) ;
130+ let ts_data : & [ u64 ] = bytemuck:: cast_slice ( ts_data_raw) ;
131+ println ! ( "compute shader elapsed: {:?}ms" , ( ts_data[ 1 ] - ts_data[ 0 ] ) as f64 * ts_period as f64 * 1e-6 ) ;
142132 }
143133}
144134
145- #[ allow( unused) ]
146- fn bytes_to_u32 ( bytes : & [ u8 ] ) -> Vec < u32 > {
147- bytes
148- . chunks_exact ( 4 )
149- . map ( |b| {
150- let mut bytes = [ 0 ; 4 ] ;
151- bytes. copy_from_slice ( b) ;
152- u32:: from_le_bytes ( bytes)
153- } )
154- . collect ( )
155- }
156-
157135fn main ( ) {
158136 pollster:: block_on ( run ( ) ) ;
159137}
0 commit comments