diff --git a/operators/src/conv/args.rs b/operators/src/conv/args.rs index f2bb2cc..1c5cfec 100644 --- a/operators/src/conv/args.rs +++ b/operators/src/conv/args.rs @@ -11,8 +11,8 @@ pub struct Args { pub x_base: ConstPtr, pub w_layout: TensorLayout, pub w_base: ConstPtr, - pub b_layout: TensorLayout, - pub b_base: ConstPtr, + pub b_layout: Option, + pub b_base: Option>, pub strides: [usize; 2], pub dilations: [usize; 2], pub pads: [usize; 4], @@ -50,12 +50,17 @@ impl Args { let &[m, ck, hk, wk] = &*w_layout.shape() else { return Err(rank_error("w", 4, w_layout.ndim())); }; - let &[mb] = &*b_layout.shape() else { - return Err(rank_error("b", 1, b_layout.ndim())); + let (mb, b_layout_dt) = if let Some(b_layout) = b_layout { + let &[mb] = &*b_layout.shape() else { + return Err(rank_error("b", 1, b_layout.ndim())); + }; + (mb, b_layout.dt) + } else { + (m, y_layout.dt) }; Ok(Meta { - dt: type_distinct(&[y_layout.dt, x_layout.dt, w_layout.dt, b_layout.dt])?, + dt: type_distinct(&[y_layout.dt, x_layout.dt, w_layout.dt, b_layout_dt])?, n: dim_distinct(&[n, ny]).expect("n mismatch"), m: dim_distinct(&[m, my, mb]).expect("m mismatch"), c: dim_distinct(&[c, ck]).expect("c mismatch"), diff --git a/operators/src/conv/im2col.rs b/operators/src/conv/im2col.rs index 73f0260..c45037f 100644 --- a/operators/src/conv/im2col.rs +++ b/operators/src/conv/im2col.rs @@ -91,9 +91,6 @@ where let &[mks, cks, hks, wks] = w_layout.strides() else { unreachable!() }; - let &[mbs] = b_layout.strides() else { - unreachable!() - }; // 计算考虑空洞的 kernel size @@ -147,19 +144,24 @@ where let b_dst = TensorLayout { dt, layout: b_dst }; let b_src = TensorLayout { dt, layout: b_src }; - // b 布局广播 - let b = Arr4::new(&[n, m, hy * wy], &[0, mbs, 0], 0); - // 广播 b - self.rearrange.launch( - &rearrange::Args { - dst_layout: c_y.clone(), - dst_base: *y_base, - src_layout: TensorLayout::new(dt, b.shape(), b.strides()), - src_base: *b_base, - }, - workspace, - queue_alloc, - )?; + if let (Some(b_layout), Some(b_base)) = (b_layout, b_base) { + let &[mbs] = b_layout.strides() else { + unreachable!() + }; + // b 布局广播 + let b = Arr4::new(&[n, m, hy * wy], &[0, mbs, 0], 0); + // 广播 b + self.rearrange.launch( + &rearrange::Args { + dst_layout: c_y.clone(), + dst_base: *y_base, + src_layout: TensorLayout::new(dt, b.shape(), b.strides()), + src_base: *b_base, + }, + workspace, + queue_alloc, + )?; + } // 为 im2col 分配工作空间 let b_size = b_shape.iter().product::() * ele; @@ -181,7 +183,7 @@ where &mat_mul::Args { c_layout: c_y.clone(), c_base: *y_base, - beta: 1., + beta: if b_layout.is_some() { 1. } else { 0. }, a_layout: a_w, a_base: *w_base, b_layout: b_x,