@@ -2,8 +2,8 @@ use super::{args::Scheme, Add, Args};
 use crate::{
     cuda::{dt_name, Gpu, Handle, ModuleBox},
     shape_not_support, strides_not_support,
-    utils::{gcd, type_distinct},
-    ByteOf, LaunchError, QueueAlloc, SchemeDiversity, SchemeError,
+    utils::gcd,
+    ByteOf, LaunchError, QueueAlloc, SchemeDiversity,
 };
 use digit_layout::DigitLayout;
 use lru::LruCache;
@@ -32,20 +32,6 @@ impl crate::Operator for Operator {
         }
     }

-    #[inline]
-    fn scheme(
-        &mut self,
-        args: &Self::Args,
-        _max_workspace_size: usize,
-    ) -> Result<usize, SchemeError> {
-        let dt = type_distinct(&[args.c_layout.dt(), args.a_layout.dt(), args.b_layout.dt()])?;
-        self.schemes
-            .lock()
-            .unwrap()
-            .get_or_insert(dt, || compile(&self.handle, dt));
-        Ok(0)
-    }
-
     fn launch<QA>(
         &self,
         args: &Self::Args,
@@ -60,20 +46,20 @@ impl crate::Operator for Operator {
         let count = scheme.count();

         let &[1] = scheme.idx_strides() else {
-            return Err(shape_not_support("").into());
+            return Err(shape_not_support(""));
         };
         let &[sc] = scheme.c_strides() else {
-            return Err(shape_not_support("").into());
+            return Err(shape_not_support(""));
         };
         let &[sa] = scheme.a_strides() else {
-            return Err(shape_not_support("").into());
+            return Err(shape_not_support(""));
         };
         let &[sb] = scheme.b_strides() else {
-            return Err(shape_not_support("").into());
+            return Err(shape_not_support(""));
         };
         let unit = dt.nbytes() as isize;
         if sc != unit || sa != unit || sb != unit {
-            return Err(strides_not_support("").into());
+            return Err(strides_not_support(""));
         }

         let block_dims = gcd(count, self.max_threads_block);
@@ -124,25 +110,12 @@ extern "C" __global__ void add(
 #[cfg(test)]
 mod test {
     use super::{Args, Gpu, Operator};
-    use crate::{dyn_, Hardware, Operator as _, TensorLayout};
+    use crate::{Hardware, Operator as _, TensorLayout};
     use digit_layout::{
         types::{F16, F64},
         DigitLayout,
     };
-    use std::ptr::null;

-    fn dyn_args<H: Hardware>(dt: DigitLayout) -> Args<H> {
-        use std::ptr::null_mut;
-        let layout = TensorLayout::new_dyn(dt, &[dyn_(); 2], &[dyn_(); 2]);
-        Args {
-            c_layout: layout.clone(),
-            c_base: null_mut(),
-            a_layout: layout.clone(),
-            a_base: null(),
-            b_layout: layout.clone(),
-            b_base: null(),
-        }
-    }
     fn args<H: Hardware>(
         dt: DigitLayout,
         n: usize,
@@ -178,10 +151,8 @@ mod test {
             return;
         };

-        let mut cpu_op = RefOp::new(&Cpu);
-        let mut gpu_op = Operator::new(&gpu);
-        cpu_op.scheme(&dyn_args(F64), 0).unwrap();
-        gpu_op.scheme(&dyn_args(F16), 0).unwrap();
+        let cpu_op = RefOp::new(&Cpu);
+        let gpu_op = Operator::new(&gpu);

         let n = 1;
         let d = 768;