Skip to content

Commit 1277358

Browse files
committed
Addition of stride in API of conv
1 parent 00acae2 commit 1277358

File tree

6 files changed

+47
-10
lines changed

6 files changed

+47
-10
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module nf_conv1d_layer
1515
integer :: channels
1616
integer :: kernel_size
1717
integer :: filters
18+
integer :: stride
1819

1920
real, allocatable :: biases(:) ! size(filters)
2021
real, allocatable :: kernel(:,:,:) ! filters x channels x window
@@ -39,12 +40,13 @@ module nf_conv1d_layer
3940
end type conv1d_layer
4041

4142
interface conv1d_layer
42-
module function conv1d_layer_cons(filters, kernel_size, activation) &
43+
module function conv1d_layer_cons(filters, kernel_size, activation, stride) &
4344
result(res)
4445
!! `conv1d_layer` constructor function
4546
integer, intent(in) :: filters
4647
integer, intent(in) :: kernel_size
4748
class(activation_function), intent(in) :: activation
49+
integer, intent(in) :: stride
4850
type(conv1d_layer) :: res
4951
end function conv1d_layer_cons
5052
end interface conv1d_layer

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
contains
99

10-
module function conv1d_layer_cons(filters, kernel_size, activation) result(res)
10+
module function conv1d_layer_cons(filters, kernel_size, activation, stride) result(res)
1111
integer, intent(in) :: filters
1212
integer, intent(in) :: kernel_size
1313
class(activation_function), intent(in) :: activation
14+
integer, intent(in) :: stride
1415
type(conv1d_layer) :: res
1516

1617
res % kernel_size = kernel_size
1718
res % filters = filters
1819
res % activation_name = activation % get_name()
20+
res % stride = stride
1921
allocate( res % activation, source = activation )
2022
end function conv1d_layer_cons
2123

src/nf/nf_conv2d_layer.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module nf_conv2d_layer
1616
integer :: channels
1717
integer :: kernel_size
1818
integer :: filters
19+
integer :: stride(2)
1920

2021
real, allocatable :: biases(:) ! size(filters)
2122
real, allocatable :: kernel(:,:,:,:) ! filters x channels x window x window
@@ -40,12 +41,13 @@ module nf_conv2d_layer
4041
end type conv2d_layer
4142

4243
interface conv2d_layer
43-
module function conv2d_layer_cons(filters, kernel_size, activation) &
44+
module function conv2d_layer_cons(filters, kernel_size, activation, stride) &
4445
result(res)
4546
!! `conv2d_layer` constructor function
4647
integer, intent(in) :: filters
4748
integer, intent(in) :: kernel_size
4849
class(activation_function), intent(in) :: activation
50+
integer, intent(in) :: stride(:)
4951
type(conv2d_layer) :: res
5052
end function conv2d_layer_cons
5153
end interface conv2d_layer

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77

88
contains
99

10-
module function conv2d_layer_cons(filters, kernel_size, activation) result(res)
10+
module function conv2d_layer_cons(filters, kernel_size, activation, stride) result(res)
1111
implicit none
1212
integer, intent(in) :: filters
1313
integer, intent(in) :: kernel_size
1414
class(activation_function), intent(in) :: activation
15+
integer, intent(in) :: stride(:)
1516
type(conv2d_layer) :: res
1617

1718
res % kernel_size = kernel_size
1819
res % filters = filters
1920
res % activation_name = activation % get_name()
21+
res % stride = stride
2022
allocate( res % activation, source = activation )
2123

2224
end function conv2d_layer_cons

src/nf/nf_layer_constructors.f90

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ end function input3d
9494

9595
interface conv
9696

97-
module function conv1d(filters, kernel_width, activation) result(res)
97+
module function conv1d(filters, kernel_width, activation, stride) result(res)
9898
!! 1-d convolutional layer constructor.
9999
!!
100100
!! This layer is for building 1-d convolutional network.
@@ -117,11 +117,13 @@ module function conv1d(filters, kernel_width, activation) result(res)
117117
!! Width of the convolution window, commonly 3 or 5
118118
class(activation_function), intent(in), optional :: activation
119119
!! Activation function (default sigmoid)
120+
integer, intent(in), optional :: stride
121+
!! Stride length of the convolution
120122
type(layer) :: res
121123
!! Resulting layer instance
122124
end function conv1d
123125

124-
module function conv2d(filters, kernel_width, kernel_height, activation) result(res)
126+
module function conv2d(filters, kernel_width, kernel_height, activation, stride) result(res)
125127
!! 2-d convolutional layer constructor.
126128
!!
127129
!! This layer is for building 2-d convolutional network.
@@ -147,6 +149,8 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
147149
!! Height of the convolution window, commonly 3 or 5
148150
class(activation_function), intent(in), optional :: activation
149151
!! Activation function (default sigmoid)
152+
integer, intent(in), optional :: stride(:)
153+
!! Stride length of the convolution
150154
type(layer) :: res
151155
!! Resulting layer instance
152156
end function conv2d

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,19 @@
2323

2424
contains
2525

26-
module function conv1d(filters, kernel_width, activation) result(res)
26+
module function conv1d(filters, kernel_width, activation, stride) result(res)
2727
integer, intent(in) :: filters
2828
integer, intent(in) :: kernel_width
2929
class(activation_function), intent(in), optional :: activation
30+
integer, intent(in), optional :: stride
3031
type(layer) :: res
3132

33+
integer :: stride_tmp
3234
class(activation_function), allocatable :: activation_tmp
3335

36+
if (stride < 1) &
37+
error stop 'stride must be >= 1 in a conv1d layer'
38+
3439
res % name = 'conv1d'
3540

3641
if (present(activation)) then
@@ -41,20 +46,28 @@ module function conv1d(filters, kernel_width, activation) result(res)
4146

4247
res % activation = activation_tmp % get_name()
4348

49+
if (present(stride)) then
50+
stride_tmp = stride
51+
else
52+
stride_tmp = 1
53+
endif
54+
4455
allocate( &
4556
res % p, &
46-
source=conv1d_layer(filters, kernel_width, activation_tmp) &
57+
source=conv1d_layer(filters, kernel_width, activation_tmp, stride_tmp) &
4758
)
4859

4960
end function conv1d
5061

51-
module function conv2d(filters, kernel_width, kernel_height, activation) result(res)
62+
module function conv2d(filters, kernel_width, kernel_height, activation, stride) result(res)
5263
integer, intent(in) :: filters
5364
integer, intent(in) :: kernel_width
5465
integer, intent(in) :: kernel_height
5566
class(activation_function), intent(in), optional :: activation
67+
integer, intent(in), optional :: stride(:)
5668
type(layer) :: res
5769

70+
integer :: stride_tmp(2)
5871
class(activation_function), allocatable :: activation_tmp
5972

6073
! Enforce kernel_width == kernel_height for now;
@@ -63,6 +76,12 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
6376
if (kernel_width /= kernel_height) &
6477
error stop 'kernel_width must equal kernel_height in a conv2d layer'
6578

79+
if (size(stride) /= 2 ) &
80+
error stop 'size of stride must be equal to 2 in a conv2d layer'
81+
82+
if (stride(1) < 1 .or. stride(2) < 1) &
83+
error stop 'stride must be >= 1 in a conv2d layer'
84+
6685
res % name = 'conv2d'
6786

6887
if (present(activation)) then
@@ -73,9 +92,15 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
7392

7493
res % activation = activation_tmp % get_name()
7594

95+
if (present(stride)) then
96+
stride_tmp = stride
97+
else
98+
stride_tmp = [1, 1]
99+
endif
100+
76101
allocate( &
77102
res % p, &
78-
source=conv2d_layer(filters, kernel_width, activation_tmp) &
103+
source=conv2d_layer(filters, kernel_width, activation_tmp, stride) &
79104
)
80105

81106
end function conv2d

0 commit comments

Comments
 (0)