|
1 | | -use super::{args::Scheme, Args, Rearrange}; |
| 1 | +use super::{args::Scheme, Args, Rearrange}; |
2 | 2 | use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; |
3 | 3 | use digit_layout::types; |
4 | 4 | use infini_op::{infiniop, AsRaw, Descriptor, Handle}; |
5 | | -use std::sync::Arc; |
| 5 | +use std::{ |
| 6 | + slice::{from_raw_parts, from_raw_parts_mut}, |
| 7 | + sync::Arc, |
| 8 | +}; |
6 | 9 |
|
7 | 10 | pub struct Operator(Arc<Handle>); |
8 | 11 |
|
@@ -39,14 +42,25 @@ impl crate::Operator for Operator { |
39 | 42 | use std::iter::once; |
40 | 43 |
|
41 | 44 | let scheme = Scheme::new(args)?; |
| 45 | + if scheme.ndim() == 0 { |
| 46 | + let unit = scheme.unit(); |
| 47 | + let dst = unsafe { from_raw_parts_mut(args.dst_base, unit) }; |
| 48 | + let src = unsafe { from_raw_parts(args.src_base, unit) }; |
| 49 | + queue_alloc.queue().memcpy_d2d(dst, src); |
| 50 | + return Ok(()); |
| 51 | + } |
| 52 | + |
| 53 | + let scheme = scheme.distribute_unit((0..=5).rev().map(|n| 32 * (1 << n))); |
| 54 | + let unit = scheme.unit(); |
| 55 | + |
42 | 56 | let dst = infini_op::Tensor::new( |
43 | 57 | types::U8, |
44 | | - scheme.shape().chain(once(scheme.unit())), |
| 58 | + scheme.shape().chain(once(unit)), |
45 | 59 | scheme.dst_strides().iter().cloned().chain(once(1)), |
46 | 60 | ); |
47 | 61 | let src = infini_op::Tensor::new( |
48 | 62 | types::U8, |
49 | | - scheme.shape().chain(once(scheme.unit())), |
| 63 | + scheme.shape().chain(once(unit)), |
50 | 64 | scheme.src_strides().iter().cloned().chain(once(1)), |
51 | 65 | ); |
52 | 66 |
|
|
0 commit comments