@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196
196
197
197
function (c:: Conv )(x:: AbstractArray )
198
198
_conv_size_check (c, x)
199
- σ = NNlib. fast_act (c. σ, x)
200
199
cdims = conv_dims (c, x)
201
200
xT = _match_eltype (c, x)
202
- σ .( conv (xT, c. weight, cdims) .+ conv_reshape_bias (c))
201
+ NNlib . bias_act! (c . σ, conv (xT, c. weight, cdims), conv_reshape_bias (c))
203
202
end
204
203
205
204
_channels_in (l:: Conv ) = size (l. weight, ndims (l. weight)- 1 ) * l. groups
@@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
350
349
351
350
function (c:: ConvTranspose )(x:: AbstractArray )
352
351
_conv_size_check (c, x)
353
- σ = NNlib. fast_act (c. σ, x)
354
352
cdims = conv_transpose_dims (c, x)
355
353
xT = _match_eltype (c, x)
356
- σ .( ∇conv_data (xT, c. weight, cdims) .+ conv_reshape_bias (c))
354
+ NNlib . bias_act! (c . σ, ∇conv_data (xT, c. weight, cdims), conv_reshape_bias (c))
357
355
end
358
356
359
357
function Base. show (io:: IO , l:: ConvTranspose )
@@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
493
491
494
492
function (c:: CrossCor )(x:: AbstractArray )
495
493
_conv_size_check (c, x)
496
- σ = NNlib. fast_act (c. σ, x)
497
494
cdims = crosscor_dims (c, x)
498
495
xT = _match_eltype (c, x)
499
- σ .( crosscor (xT, c. weight, cdims) .+ conv_reshape_bias (c))
496
+ NNlib . bias_act! (c . σ, crosscor (xT, c. weight, cdims), conv_reshape_bias (c))
500
497
end
501
498
502
499
function Base. show (io:: IO , l:: CrossCor )
0 commit comments