@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
3939 real , allocatable :: k_input(:, :)
4040 real , allocatable :: v_input(:, :)
4141 real , allocatable :: o_input(:, :)
42+
43+ ! temporary storages for forward and backward passes
44+ real , allocatable :: normalized_attention(:, :, :)
45+ real , allocatable :: q_or_dq(:, :, :)
46+ real , allocatable :: k_or_dk(:, :, :)
47+ real , allocatable :: v_or_dv(:, :, :)
48+ real , allocatable :: d_output(:, :, :)
49+ real , allocatable :: v_heads(:, :, :)
50+ real , allocatable :: k_heads(:, :, :)
51+ real , allocatable :: q_heads(:, :, :)
52+ real , allocatable :: d_sdpa(:, :)
53+ real , allocatable :: jacobian(:, :)
54+ real , allocatable :: d_normalize(:, :, :)
4255 contains
4356
4457 procedure :: common_backward
4558 procedure :: common_forward
59+ procedure :: sdpa_forward
60+ procedure :: sdpa_backward
4661 procedure :: get_num_params
4762 procedure :: get_params
4863 procedure :: get_gradients
@@ -68,25 +83,38 @@ end function multihead_attention_layer_cons
6883
6984 interface
7085
71- pure module subroutine common_backward(self, input, gradient)
86+ pure module subroutine common_backward(self, input, gradient, attention_mask )
7287 ! ! General backprop for MultiHead Attention mechanism
7388 ! ! Might be used for both Self and Cross Attention
7489 ! ! Self Attention: sum output gradients
7590 ! ! Cross Attention: use them separately
7691 class(multihead_attention_layer), intent (in out ) :: self
7792 real , intent (in ) :: input(:, :)
7893 real , intent (in ) :: gradient(:, :)
94+ real , optional , intent (in ) :: attention_mask(:, :)
7995 end subroutine common_backward
8096
81- pure module subroutine common_forward(self, query, key, value)
97+ pure module subroutine common_forward(self, query, key, value, attention_mask )
8298 ! ! General forward propagation for MultiHead Attention Mechanism
8399 ! ! Might be used for both Self and Cross Attention
84100 ! ! Self Attention: pass the same value thrice
85101 ! ! Cross Attention: pass three values for your query, key and value
86102 class(multihead_attention_layer), intent (in out ) :: self
87103 real , intent (in ) :: query(:, :), key(:, :), value(:, :)
104+ real , optional , intent (in ) :: attention_mask(:, :)
88105 end subroutine common_forward
89106
107+ pure module subroutine sdpa_forward(self, attention_mask)
108+ class(multihead_attention_layer), intent (in out ) :: self
109+ real , intent (in ), optional :: attention_mask(:, :)
110+ end subroutine sdpa_forward
111+
112+ pure module subroutine sdpa_backward(self, gradient, attention_mask)
113+ class(multihead_attention_layer), intent (in out ) :: self
114+ real , intent (in ) :: gradient(:, :)
115+ real , intent (in ), optional :: attention_mask(:, :)
116+ end subroutine sdpa_backward
117+
90118 pure module subroutine init(self, input_shape)
91119 ! ! Initialize the layer data structures.
92120 ! !
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
119147 ! ! Output dims: sequence_length, sequence_length, n_heads
120148 class(multihead_attention_layer), intent (in out ) :: self
121149 ! ! (sequence_length, sequence_length)
122- real , optional , intent (in ) :: attention_mask(:, :, : )
150+ real , optional , intent (in ) :: attention_mask(:, :)
123151 ! ! (sequence_length, sequence_length)
124152 end subroutine normalize_attention_matrix
125153
@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
143171 end function get_num_params
144172
145173 module function get_params (self ) result(params)
146- class(multihead_attention_layer), intent (in ), target :: self
174+ class(multihead_attention_layer), intent (in ) :: self
147175 real , allocatable :: params(:)
148176 end function get_params
149177
150178 module function get_gradients (self ) result(gradients)
151- class(multihead_attention_layer), intent (in ), target :: self
179+ class(multihead_attention_layer), intent (in ) :: self
152180 real , allocatable :: gradients(:)
153181 end function get_gradients
154182
155183 module subroutine set_params (self , params )
156184 class(multihead_attention_layer), intent (in out ) :: self
157- real , intent (in ), target :: params(:)
185+ real , intent (in ) :: params(:)
158186 end subroutine set_params
159187
160188 module subroutine init_base (self , input_shape )
0 commit comments