2323
2424contains
2525
26- module function conv1d (filters , kernel_width , activation ) result(res)
26+ module function conv1d (filters , kernel_width , activation , stride ) result(res)
2727 integer , intent (in ) :: filters
2828 integer , intent (in ) :: kernel_width
2929 class(activation_function), intent (in ), optional :: activation
30+ integer , intent (in ), optional :: stride
3031 type (layer) :: res
3132
33+ integer :: stride_tmp
3234 class(activation_function), allocatable :: activation_tmp
3335
36+ if (stride < 1 ) &
37+ error stop ' stride must be >= 1 in a conv1d layer'
38+
3439 res % name = ' conv1d'
3540
3641 if (present (activation)) then
@@ -41,20 +46,28 @@ module function conv1d(filters, kernel_width, activation) result(res)
4146
4247 res % activation = activation_tmp % get_name()
4348
49+ if (present (stride)) then
50+ stride_tmp = stride
51+ else
52+ stride_tmp = 1
53+ endif
54+
4455 allocate ( &
4556 res % p, &
46- source= conv1d_layer(filters, kernel_width, activation_tmp) &
57+ source= conv1d_layer(filters, kernel_width, activation_tmp, stride_tmp ) &
4758 )
4859
4960 end function conv1d
5061
51- module function conv2d (filters , kernel_width , kernel_height , activation ) result(res)
62+ module function conv2d (filters , kernel_width , kernel_height , activation , stride ) result(res)
5263 integer , intent (in ) :: filters
5364 integer , intent (in ) :: kernel_width
5465 integer , intent (in ) :: kernel_height
5566 class(activation_function), intent (in ), optional :: activation
67+ integer , intent (in ), optional :: stride(:)
5668 type (layer) :: res
5769
70+ integer :: stride_tmp(2 )
5871 class(activation_function), allocatable :: activation_tmp
5972
6073 ! Enforce kernel_width == kernel_height for now;
@@ -63,6 +76,12 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
6376 if (kernel_width /= kernel_height) &
6477 error stop ' kernel_width must equal kernel_height in a conv2d layer'
6578
79+ if (size (stride) /= 2 ) &
80+ error stop ' size of stride must be equal to 2 in a conv2d layer'
81+
82+ if (stride(1 ) < 1 .or. stride(2 ) < 1 ) &
83+ error stop ' stride must be >= 1 in a conv2d layer'
84+
6685 res % name = ' conv2d'
6786
6887 if (present (activation)) then
@@ -73,9 +92,15 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
7392
7493 res % activation = activation_tmp % get_name()
7594
95+ if (present (stride)) then
96+ stride_tmp = stride
97+ else
98+ stride_tmp = [1 , 1 ]
99+ endif
100+
76101 allocate ( &
77102 res % p, &
78- source= conv2d_layer(filters, kernel_width, activation_tmp) &
103+ source= conv2d_layer(filters, kernel_width, activation_tmp, stride ) &
79104 )
80105
81106 end function conv2d
0 commit comments