This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit ba29b19

Fix q40 gptj with MHA fusion enabled & remove logits.txt (#285)
Signed-off-by: Ding, Yi1 <yi1.ding@intel.com>
1 parent b7e6cc2 commit ba29b19

File tree

3 files changed: +11 -29 lines changed


intel_extension_for_transformers/llm/runtime/graph/application/main_run.cpp

Lines changed: 0 additions & 3 deletions
@@ -461,12 +461,9 @@ int main(int argc, char** argv) {
 
       std::vector<model_token_data> candidates;
       candidates.reserve(n_vocab);
-      std::ofstream outFile("logits.txt", std::ios::app);
       for (model_token token_id = 0; token_id < n_vocab; token_id++) {
-        outFile << logits[token_id] << " ";
         candidates.emplace_back(model_token_data{token_id, logits[token_id], 0.0f});
       }
-      outFile << "\n";
 
       model_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
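The removed lines were a debugging leftover: they appended every logit of every candidate token to logits.txt on each sampling step, which slows sampling and grows the file without bound. If such a dump is ever needed again, a minimal standalone sketch like the one below (a hypothetical helper, not part of this commit or of the runtime's API) keeps it behind a compile-time switch so normal builds never pay for it:

// Hypothetical sketch, not part of this commit: dump a logits vector to a text
// file only when the build defines DUMP_LOGITS; otherwise it is a no-op.
#include <fstream>
#include <vector>

static void maybe_dump_logits(const std::vector<float>& logits) {
#ifdef DUMP_LOGITS
  std::ofstream out("logits.txt", std::ios::app);  // append one line per sampling step
  for (float v : logits) out << v << " ";
  out << "\n";
#else
  (void)logits;  // dumping disabled in normal builds
#endif
}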

intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp

Lines changed: 4 additions & 10 deletions
@@ -173,15 +173,7 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_token* tok
                          1 * N * n_embd * batch_size * ne_element_size(QKVcur)),
              n_embd / n_head, n_head, N, batch_size),
          n_past, n_rot, 0, 0);
-      if (!run_mha_reordered) {
-        Vcur = ne_view_1d(ctx0, QKVcur, N * n_embd * batch_size, 2 * N * n_embd * batch_size * ne_element_size(QKVcur));
-      } else {
-        Vcur = ne_reshape_4d(
-            ctx0,
-            ne_view_1d(ctx0, QKVcur, N * n_embd * batch_size, 2 * N * n_embd * batch_size * ne_element_size(QKVcur)),
-            n_embd / n_head, n_head, N, batch_size);
-      }
-
+      Vcur = ne_view_1d(ctx0, QKVcur, N * n_embd * batch_size, 2 * N * n_embd * batch_size * ne_element_size(QKVcur));
     } else {
       if (!enable_tp) {
         // printf("\n\n\n work into attention split,\n\n\n");

@@ -291,7 +283,9 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_token* tok
                                           head_size, n_ctx, n_head, batch_size,  // ne
                                           0, 0, v_size,                          // nb (jblas managed)
                                           il * kv_n_ctx_block * v_size);         // offset
-      ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
+      // jblas alway view V as (D, n_head, seq, bs)
+      const auto Vcur_plain = ne_reshape_4d(ctx0, Vcur, n_embd / n_head, n_head, N, batch_size);
+      ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past));
     }
 
     struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
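With MHA fusion enabled (run_mha_reordered), Vcur is now kept as the flat 1-D view of the fused QKV activation and only reshaped to the plain (head_dim, n_head, seq, batch) layout right before ne_flash_attn_update_v, since that is the layout the jblas-managed V cache consumes. The following standalone sketch of the index arithmetic behind that view uses plain C++ with hypothetical helper names, not the ne_* API, and assumes the [Q block | K block | V block] layout implied by the ne_view_1d offsets above:

// Hypothetical illustration: where element V[d, h, t, b] lives inside a fused
// QKV buffer laid out as [Q block | K block | V block], each block holding
// batch * seq * n_embd contiguous values, as the ne_view_1d offsets above assume.
#include <cassert>
#include <cstddef>

static size_t v_offset(size_t d, size_t h, size_t t, size_t b,
                       size_t head_dim, size_t n_head, size_t seq, size_t batch) {
  const size_t n_embd = head_dim * n_head;
  const size_t v_block = 2 * seq * n_embd * batch;  // skip the Q and K blocks
  return v_block + ((b * seq + t) * n_head + h) * head_dim + d;
}

int main() {
  const size_t head_dim = 4, n_head = 2, seq = 3, batch = 2;
  const size_t n_embd = head_dim * n_head;
  // The 4-D (head_dim, n_head, seq, batch) view indexes the same flat element
  // as batch-major, then token, then channel addressing of the V block.
  assert(v_offset(1, 1, 2, 1, head_dim, n_head, seq, batch) ==
         2 * seq * n_embd * batch + 1 * seq * n_embd + 2 * n_embd + 1 * head_dim + 1);
  return 0;
}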

intel_extension_for_transformers/llm/runtime/graph/models/llama/llama.cpp

Lines changed: 7 additions & 16 deletions
@@ -149,14 +149,9 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
           ne_reshape_3d(ctx0, ne_view_1d(ctx0, QKVcur, N * n_embd, 1 * N * n_embd * ne_element_size(QKVcur)),
                         n_embd / n_head, n_head, N),
           n_past, n_rot, 0, 0);
-      if (!run_mha_reordered) {
-        Vcur = ne_transpose(
-            ctx0, ne_reshape_2d(ctx0, ne_view_1d(ctx0, QKVcur, N * n_embd, 2 * N * n_embd * ne_element_size(QKVcur)),
-                                n_embd, N));
-      } else {
-        Vcur = ne_reshape_3d(ctx0, ne_view_1d(ctx0, QKVcur, N * n_embd, 2 * N * n_embd * ne_element_size(QKVcur)),
-                             n_embd / n_head, n_head, N);
-      }
+      Vcur = ne_transpose(
+          ctx0, ne_reshape_2d(ctx0, ne_view_1d(ctx0, QKVcur, N * n_embd, 2 * N * n_embd * ne_element_size(QKVcur)),
+                              n_embd, N));
 
     } else {
       Qcur = ne_rope_inplace(

@@ -165,13 +160,7 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
       Kcur = ne_rope_inplace(
           ctx0, ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), n_embd / n_head, n_head, N),
           n_past, n_rot, 0, 0);
-      if (!run_mha_reordered) {
-        Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), n_embd, N));
-      } else {
-        Vcur = ne_rope_inplace(
-            ctx0, ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), n_embd / n_head, n_head, N),
-            n_past, n_rot, 0, 0);
-      }
+      Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), n_embd, N));
     }
     ne_set_name(Qcur, "Qcur");
     ne_set_name(Kcur, "Kcur");

@@ -258,7 +247,9 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
                                           head_size, n_ctx, n_head,  // ne
                                           0, 0,                      // nb (jblas managed)
                                           il * v_size);              // offset
-      ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
+      // jblas alway view V as (D, n_head, seq)
+      const auto Vcur_plain = ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd * N, 0), n_embd / n_head, n_head, N);
+      ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past));
     }
 
     struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
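The comment in the added lines is the key point: the jblas flash-attention path always consumes V as a plain (head_dim, n_head, seq) tensor, while the non-fused path carries V around transposed. Assuming ne_transpose is a stride-swapping view (as in ggml), re-viewing the same contiguous buffer is enough and no copy is needed. The standalone check below (plain C++, illustrative names, not the ne_* API) shows that the (head_dim, n_head, seq) indexing addresses exactly the row-major (seq, n_embd) activation:

// Hypothetical illustration (not the ne_* API): the (head_dim, n_head, seq)
// view used by the fused path addresses the same contiguous row-major
// (seq, n_embd) activation, element for element, so no copy is required.
#include <cassert>
#include <cstddef>

int main() {
  const size_t head_dim = 4, n_head = 2, seq = 3;
  const size_t n_embd = head_dim * n_head;

  // Flat offset of channel c of token t in the contiguous activation.
  auto flat = [&](size_t t, size_t c) { return t * n_embd + c; };

  // Flat offset of V[d, h, t] in the (head_dim, n_head, seq) view, d fastest.
  auto reordered = [&](size_t d, size_t h, size_t t) {
    return (t * n_head + h) * head_dim + d;
  };

  // Channel c of token t corresponds to head c / head_dim, lane c % head_dim.
  for (size_t t = 0; t < seq; ++t)
    for (size_t c = 0; c < n_embd; ++c)
      assert(reordered(c % head_dim, c / head_dim, t) == flat(t, c));
  return 0;
}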
