Skip to content

Commit a645f82

Browse files
authored
Merge pull request #17 from qinyiqun/fix_infini_rearrange
fix: 修复inifni rearrange unit传递不正确的问题
2 parents 8712870 + 08b8bf0 commit a645f82

File tree

1 file changed

+18
-4
lines changed
  • operators/src/rearrange/infini

1 file changed

+18
-4
lines changed

operators/src/rearrange/infini/mod.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use super::{args::Scheme, Args, Rearrange};
1+
use super::{args::Scheme, Args, Rearrange};
22
use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError};
33
use digit_layout::types;
44
use infini_op::{infiniop, AsRaw, Descriptor, Handle};
5-
use std::sync::Arc;
5+
use std::{
6+
slice::{from_raw_parts, from_raw_parts_mut},
7+
sync::Arc,
8+
};
69

710
pub struct Operator(Arc<Handle>);
811

@@ -39,14 +42,25 @@ impl crate::Operator for Operator {
3942
use std::iter::once;
4043

4144
let scheme = Scheme::new(args)?;
45+
if scheme.ndim() == 0 {
46+
let unit = scheme.unit();
47+
let dst = unsafe { from_raw_parts_mut(args.dst_base, unit) };
48+
let src = unsafe { from_raw_parts(args.src_base, unit) };
49+
queue_alloc.queue().memcpy_d2d(dst, src);
50+
return Ok(());
51+
}
52+
53+
let scheme = scheme.distribute_unit((0..=5).rev().map(|n| 32 * (1 << n)));
54+
let unit = scheme.unit();
55+
4256
let dst = infini_op::Tensor::new(
4357
types::U8,
44-
scheme.shape().chain(once(scheme.unit())),
58+
scheme.shape().chain(once(unit)),
4559
scheme.dst_strides().iter().cloned().chain(once(1)),
4660
);
4761
let src = infini_op::Tensor::new(
4862
types::U8,
49-
scheme.shape().chain(once(scheme.unit())),
63+
scheme.shape().chain(once(unit)),
5064
scheme.src_strides().iter().cloned().chain(once(1)),
5165
);
5266

0 commit comments

Comments
 (0)