@@ -149,14 +149,9 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
149149 ne_reshape_3d (ctx0, ne_view_1d (ctx0, QKVcur, N * n_embd, 1  * N * n_embd * ne_element_size (QKVcur)),
150150 n_embd / n_head, n_head, N),
151151 n_past, n_rot, 0 , 0 );
152-  if  (!run_mha_reordered) {
153-  Vcur = ne_transpose (
154-  ctx0, ne_reshape_2d (ctx0, ne_view_1d (ctx0, QKVcur, N * n_embd, 2  * N * n_embd * ne_element_size (QKVcur)),
155-  n_embd, N));
156-  } else  {
157-  Vcur = ne_reshape_3d (ctx0, ne_view_1d (ctx0, QKVcur, N * n_embd, 2  * N * n_embd * ne_element_size (QKVcur)),
158-  n_embd / n_head, n_head, N);
159-  }
152+  Vcur = ne_transpose (
153+  ctx0, ne_reshape_2d (ctx0, ne_view_1d (ctx0, QKVcur, N * n_embd, 2  * N * n_embd * ne_element_size (QKVcur)),
154+  n_embd, N));
160155
161156 } else  {
162157 Qcur = ne_rope_inplace (
@@ -165,13 +160,7 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
165160 Kcur = ne_rope_inplace (
166161 ctx0, ne_reshape_3d (ctx0, ne_mul_mat (ctx0, model.layers [il].attn [1 ], cur), n_embd / n_head, n_head, N),
167162 n_past, n_rot, 0 , 0 );
168-  if  (!run_mha_reordered) {
169-  Vcur = ne_transpose (ctx0, ne_reshape_2d (ctx0, ne_mul_mat (ctx0, model.layers [il].attn [2 ], cur), n_embd, N));
170-  } else  {
171-  Vcur = ne_rope_inplace (
172-  ctx0, ne_reshape_3d (ctx0, ne_mul_mat (ctx0, model.layers [il].attn [2 ], cur), n_embd / n_head, n_head, N),
173-  n_past, n_rot, 0 , 0 );
174-  }
163+  Vcur = ne_transpose (ctx0, ne_reshape_2d (ctx0, ne_mul_mat (ctx0, model.layers [il].attn [2 ], cur), n_embd, N));
175164 }
176165 ne_set_name (Qcur, " Qcur"  );
177166 ne_set_name (Kcur, " Kcur"  );
@@ -258,7 +247,9 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
258247 head_size, n_ctx, n_head, //  ne
259248 0 , 0 , //  nb (jblas managed)
260249 il * v_size); //  offset
261-  ne_build_forward_expand (&gf, ne_flash_attn_update_v (ctx0, v_cache, Vcur, n_past));
250+  //  jblas always views V as (D, n_head, seq)
251+  const  auto  Vcur_plain = ne_reshape_3d (ctx0, ne_view_1d (ctx0, Vcur, n_embd * N, 0 ), n_embd / n_head, n_head, N);
252+  ne_build_forward_expand (&gf, ne_flash_attn_update_v (ctx0, v_cache, Vcur_plain, n_past));
262253 }
263254
264255 struct  ne_tensor * Q = ne_permute (ctx0, Qcur, 0 , 2 , 1 , 3 );
0 commit comments