Skip to content

Commit fc5a50d

Browse files
committed
feat: 添加Scale算子
1 parent 120d30f commit fc5a50d

File tree

7 files changed

+382
-0
lines changed

7 files changed

+382
-0
lines changed

operators/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub mod random_sample;
1818
pub mod rearrange;
1919
pub mod rms_norm;
2020
pub mod rope;
21+
pub mod scale;
2122
pub mod swiglu;
2223

2324
pub use common::*;

operators/src/scale/args.rs

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
use crate::{
2+
get_static, rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr,
3+
Hardware, MutPtr, SchemeError, TensorLayout,
4+
};
5+
use digit_layout::DigitLayout;
6+
use itertools::izip;
7+
use std::{
8+
cmp::Ordering,
9+
ptr::{null, null_mut},
10+
};
11+
12+
#[derive(Clone)]
13+
pub struct Args<H: Hardware> {
14+
pub c_layout: TensorLayout,
15+
pub c_base: MutPtr<H>,
16+
pub a_layout: TensorLayout,
17+
pub a_base: ConstPtr<H>,
18+
pub scale: f32,
19+
}
20+
21+
impl<H: Hardware> Args<H> {
22+
pub fn new_null(
23+
c_layout: TensorLayout,
24+
a_layout: TensorLayout,
25+
b_layout: TensorLayout,
26+
) -> Self {
27+
Self {
28+
c_layout,
29+
c_base: null_mut(),
30+
a_layout,
31+
a_base: null(),
32+
scale: 1.0,
33+
}
34+
}
35+
}
36+
37+
#[derive(Clone, Debug)]
38+
pub(super) struct Scheme(DigitLayout, Box<[isize]>);
39+
40+
impl Scheme {
41+
pub fn new<H: Hardware>(args: &Args<H>) -> Result<Self, SchemeError> {
42+
let Args {
43+
c_layout: c,
44+
a_layout: a,
45+
..
46+
} = args;
47+
// # 检查基本属性
48+
let dt = type_distinct(&[c.dt(), a.dt()])?;
49+
let ndim = c.ndim();
50+
if a.ndim() != ndim {
51+
return Err(rank_mismatch(format!(
52+
"c.ndim = {}, a.ndim = {}",
53+
c.ndim(),
54+
a.ndim(),
55+
)));
56+
}
57+
// # 输入形状
58+
#[derive(Clone, PartialEq, Eq, Debug)]
59+
struct Dim {
60+
d: usize,
61+
c: isize,
62+
a: isize,
63+
}
64+
let mut dims = Vec::with_capacity(ndim);
65+
for (&d, &da, &sc, &sa) in izip!(c.shape(), a.shape(), c.strides(), a.strides(),) {
66+
get_static! {
67+
d da
68+
sc sa
69+
}
70+
if da != d {
71+
return Err(shape_mismatch(format!(
72+
"c: {:?}, a: {:?}",
73+
c.shape(),
74+
a.shape(),
75+
)));
76+
}
77+
// 剔除初始的 1 长维度
78+
if d != 1 {
79+
if sc == 0 {
80+
return Err(shape_not_support("Reducing is not allowed for scale"));
81+
}
82+
dims.push(Dim { d, c: sc, a: sa })
83+
}
84+
}
85+
// # 排序
86+
dims.sort_unstable_by(|dim0, dim1| {
87+
let &Dim {
88+
d: d0,
89+
c: c0,
90+
a: a0,
91+
} = dim0;
92+
let &Dim {
93+
d: d1,
94+
c: c1,
95+
a: a1,
96+
} = dim1;
97+
use Ordering::Equal as Eq;
98+
match c0.abs().cmp(&c1.abs()) {
99+
Eq => match a0.abs().cmp(&a1.abs()) {
100+
ord => ord.reverse(),
101+
},
102+
ord => ord.reverse(),
103+
}
104+
});
105+
// # 合并连续维度
106+
let mut ndim = dims.len();
107+
for i in (1..dims.len()).rev() {
108+
let (head, tail) = dims.split_at_mut(i);
109+
let f = &mut head[i - 1]; // f for front
110+
let b = &mut tail[0]; // b for back
111+
let d = b.d as isize;
112+
if b.c * d == f.c && b.a * d == f.a {
113+
*f = Dim { d: b.d * f.d, ..*b };
114+
*b = Dim { d: 1, c: 0, a: 0 };
115+
ndim -= 1
116+
}
117+
}
118+
// # 合并空间
119+
let mut layout = vec![0isize; 1 + ndim * 4].into_boxed_slice();
120+
{
121+
let (idx, tail) = layout.split_at_mut(1 + ndim);
122+
let (c_, tail) = tail.split_at_mut(ndim);
123+
let (a_, b_) = tail.split_at_mut(ndim);
124+
for (Dim { d, c, a }, idx, c_, a_) in
125+
izip!(dims.into_iter().filter(|d| d.d != 1), &mut *idx, c_, a_)
126+
{
127+
*idx = d as _;
128+
*c_ = c;
129+
*a_ = a;
130+
}
131+
idx[ndim] = 1;
132+
for i in (1..=ndim).rev() {
133+
idx[i - 1] *= idx[i];
134+
}
135+
}
136+
Ok(Self(dt, layout))
137+
}
138+
139+
#[inline]
140+
pub const fn dt(&self) -> DigitLayout {
141+
self.0
142+
}
143+
144+
/// 执行方案维数。
145+
#[inline]
146+
pub fn ndim(&self) -> usize {
147+
(self.1.len() - 1) / 4
148+
}
149+
150+
/// 读写单元数量。
151+
#[inline]
152+
pub fn count(&self) -> usize {
153+
self.1[0] as _
154+
}
155+
156+
/// 索引步长。
157+
#[inline]
158+
pub fn idx_strides(&self) -> &[isize] {
159+
let ndim = self.ndim();
160+
&self.1[1..][..ndim]
161+
}
162+
163+
#[inline]
164+
pub fn c_strides(&self) -> &[isize] {
165+
let ndim = self.ndim();
166+
&self.1[1 + ndim..][..ndim]
167+
}
168+
169+
#[inline]
170+
pub fn a_strides(&self) -> &[isize] {
171+
let ndim = self.ndim();
172+
&self.1[1 + ndim * 2..][..ndim]
173+
}
174+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use super::{args::Scheme, Args, Scale};
2+
use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError};
3+
use digit_layout::types as ty;
4+
use half::f16;
5+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
6+
7+
pub struct Operator;
8+
9+
impl Scale<Cpu> for Operator {}
10+
11+
impl crate::Operator for Operator {
12+
type Hardware = Cpu;
13+
type TopoNode = Cpu;
14+
type Args = Args<Cpu>;
15+
16+
#[inline]
17+
fn new(_node: &Self::TopoNode) -> Self {
18+
Self
19+
}
20+
#[inline]
21+
fn scheme(
22+
&mut self,
23+
_args: &Self::Args,
24+
_max_workspace_size: usize,
25+
) -> Result<usize, SchemeError> {
26+
Ok(0)
27+
}
28+
29+
fn launch<QA>(
30+
&self,
31+
args: &Self::Args,
32+
_workspace: &mut [ByteOf<Self::Hardware>],
33+
_queue_alloc: &QA,
34+
) -> Result<(), LaunchError>
35+
where
36+
QA: QueueAlloc<Hardware = Self::Hardware>,
37+
{
38+
let scheme = Scheme::new(args)?;
39+
let c = args.c_base as isize;
40+
let a = args.a_base as isize;
41+
let s = args.scale;
42+
let idx_strides = scheme.idx_strides();
43+
let c_strides = scheme.c_strides();
44+
let a_strides = scheme.a_strides();
45+
(0..scheme.count() as isize)
46+
.into_par_iter()
47+
.for_each(|mut rem| {
48+
let mut c = c;
49+
let mut a = a;
50+
for (i, &s) in idx_strides.iter().enumerate() {
51+
let k = rem / s;
52+
c += k * c_strides[i];
53+
a += k * a_strides[i];
54+
rem %= s;
55+
}
56+
match scheme.dt() {
57+
ty::F16 => mul::<f16>(c, a, f16::from_f32(s)),
58+
ty::F32 => mul::<f32>(c, a, s),
59+
ty::F64 => mul::<f64>(c, a, s as f64),
60+
_ => todo!(),
61+
}
62+
});
63+
Ok(())
64+
}
65+
}
66+
67+
/// Writes `*a * s` into `*c`, where `c` and `a` are element addresses
/// smuggled through `isize` (so they can be moved into a rayon closure).
///
/// Soundness relies on the caller passing valid, properly aligned `T`
/// addresses, with `c` writable and not written concurrently by another
/// task for the same element.
fn mul<T: std::ops::Mul<Output = T>>(c: isize, a: isize, s: T) {
    // SAFETY: the caller guarantees both addresses point at live,
    // aligned `T` values (see the doc comment above).
    unsafe { (c as *mut T).write((a as *const T).read() * s) }
}

operators/src/scale/cuda/mod.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use super::{args::Scheme, Args, Scale};
2+
use crate::{
3+
cuda::{dt_name, Gpu, Handle, ModuleBox},
4+
shape_not_support, strides_not_support,
5+
utils::{gcd, type_distinct},
6+
ByteOf, LaunchError, QueueAlloc, SchemeDiversity, SchemeError,
7+
};
8+
use digit_layout::DigitLayout;
9+
use lru::LruCache;
10+
use std::{
11+
ffi::{c_uint, CString},
12+
sync::{Arc, Mutex},
13+
};
14+
15+
pub struct Operator {}
16+
impl Scale<Gpu> for Operator {}
17+
18+
impl crate::Operator for Operator {
19+
type Hardware = Gpu;
20+
type TopoNode = Gpu;
21+
type Args = Args<Gpu>;
22+
23+
fn new(node: &Self::TopoNode) -> Self {
24+
Self {}
25+
}
26+
27+
#[inline]
28+
fn scheme(
29+
&mut self,
30+
args: &Self::Args,
31+
_max_workspace_size: usize,
32+
) -> Result<usize, SchemeError> {
33+
todo!();
34+
Ok(0)
35+
}
36+
37+
fn launch<QA>(
38+
&self,
39+
_args: &Self::Args,
40+
_workspace: &mut [ByteOf<Self::Hardware>],
41+
_queue_alloc: &QA,
42+
) -> Result<(), LaunchError>
43+
where
44+
QA: QueueAlloc<Hardware = Self::Hardware>,
45+
{
46+
todo!();
47+
Ok(())
48+
}
49+
}

operators/src/scale/infini/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use super::{Args, Scale};
2+
use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError};
3+
4+
pub struct Operator;
5+
6+
impl Add<Device> for Operator {}
7+
8+
impl crate::Operator for Operator {
9+
type Hardware = Device;
10+
type TopoNode = Device;
11+
type Args = Args<Device>;
12+
13+
fn new(_node: &Self::TopoNode) -> Self {
14+
todo!()
15+
}
16+
17+
fn scheme(
18+
&mut self,
19+
_args: &Self::Args,
20+
_max_workspace_size: usize,
21+
) -> Result<usize, SchemeError> {
22+
todo!()
23+
}
24+
25+
fn launch<QA>(
26+
&self,
27+
_args: &Self::Args,
28+
_workspace: &mut [ByteOf<Self::Hardware>],
29+
_queue_alloc: &QA,
30+
) -> Result<(), LaunchError>
31+
where
32+
QA: QueueAlloc<Hardware = Self::Hardware>,
33+
{
34+
todo!()
35+
}
36+
}

operators/src/scale/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//! c =scale*a
2+
3+
#[cfg(any(use_cpu, test))]
4+
pub mod common_cpu;
5+
#[cfg(use_cuda)]
6+
pub mod cuda;
7+
#[cfg(use_infini)]
8+
pub mod infini;
9+
#[cfg(use_cl)]
10+
pub mod opencl;
11+
12+
mod args;
13+
pub use args::Args;
14+
15+
crate::op_trait!(Scale);

operators/src/scale/opencl/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use super::{Args, Scale};
2+
use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError};
3+
4+
pub struct Operator;
5+
6+
impl Scale<ClDevice> for Operator {}
7+
8+
impl crate::Operator for Operator {
9+
type Hardware = ClDevice;
10+
type TopoNode = ClDevice;
11+
type Args = Args<ClDevice>;
12+
13+
fn new(_node: &Self::TopoNode) -> Self {
14+
todo!()
15+
}
16+
17+
fn scheme(
18+
&mut self,
19+
_args: &Self::Args,
20+
_max_workspace_size: usize,
21+
) -> Result<usize, SchemeError> {
22+
todo!()
23+
}
24+
25+
fn launch<QA>(
26+
&self,
27+
_args: &Self::Args,
28+
_workspace: &mut [ByteOf<Self::Hardware>],
29+
_queue_alloc: &QA,
30+
) -> Result<(), LaunchError>
31+
where
32+
QA: QueueAlloc<Hardware = Self::Hardware>,
33+
{
34+
todo!()
35+
}
36+
}

0 commit comments

Comments
 (0)