@@ -20,22 +20,18 @@ module nf_cross_attention_layer
   end type cross_attention_layer
 
   interface cross_attention_layer
-    module function cross_attention_layer_cons(n_heads) result(res)
-      !! This function returns the `cross_attention_layer` instance.
-      integer, intent(in) :: sequence_length, model_dimension, n_heads
-      type(cross_attention_layer) :: res
-    end function cross_attention_layer_cons
+    module procedure cross_attention_layer_cons
   end interface cross_attention_layer
 
 contains
-  module function cross_attention_layer_cons(n_heads) result(res)
+  function cross_attention_layer_cons(n_heads) result(res)
     !! This function returns the `cross_attention_layer` instance.
     integer, intent(in) :: n_heads
     type(cross_attention_layer) :: res
     res % n_heads = n_heads
   end function cross_attention_layer_cons
 
-  pure module subroutine backward(self, input, gradient)
+  pure subroutine backward(self, input, gradient)
     !! Cross Attention Back propagation
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
@@ -46,7 +42,7 @@ pure module subroutine backward(self, input, gradient)
     self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
   end subroutine backward
 
-  pure module subroutine forward(self, input)
+  pure subroutine forward(self, input)
     !! Cross Attention Forward propagation
     !! Input Shape (kind, sequence_length, model_dimension)
     !! where kind is 1 for Query and 2 for Key-Value
@@ -56,7 +52,7 @@ pure module subroutine forward(self, input)
     call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
   end subroutine forward
 
-  module subroutine init(self, input_shape)
+  subroutine init(self, input_shape)
     class(cross_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
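For context, a minimal usage sketch of the refactored constructor and forward pass; it is an illustration under stated assumptions, not part of the change. The refactor replaces the duplicated interface body, whose `sequence_length` and `model_dimension` dummy arguments did not match the actual signature, with a `module procedure` statement so the constructor is declared once. Note also that because key and value both receive `input(2, :, :)` in `forward`, their gradients are summed into `gradient(2, :, :)` in `backward`. The sketch assumes `init` and `forward` are type-bound procedures and that `input_shape` means [sequence_length, model_dimension]; neither is confirmed by this diff.

program cross_attention_example
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none

  type(cross_attention_layer) :: attn
  ! Shape convention from the forward doc comment:
  ! (kind, sequence_length, model_dimension), kind 1 = Query, kind 2 = Key-Value.
  real :: input(2, 3, 4)

  ! The generic interface now resolves here via the module procedure statement.
  attn = cross_attention_layer(n_heads=2)

  ! Assumption: input_shape is [sequence_length, model_dimension];
  ! model_dimension (4) is chosen divisible by n_heads (2).
  call attn % init([3, 4])

  call random_number(input)
  call attn % forward(input)
end program cross_attention_example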