Skip to content

Commit 9789a44

Browse files
committed
bug fix
1 parent e9777f2 commit 9789a44

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/paddlefleet/models/gpt/gpt_embedding.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,16 @@ def forward(
116116
).contiguous()
117117

118118
preproc_output = {
119-
# "input_ids": input_ids.contiguous(),
119+
"input_ids": input_ids.contiguous(),
120120
"hidden_states": decoder_input,
121121
"rotary_pos_emb": rotary_pos_emb,
122122
"rotary_pos_cos": rotary_pos_cos,
123123
"rotary_pos_sin": rotary_pos_sin,
124-
"embedding_weight": self.embedding_weight,
125-
"position_embedding_weight": self.position_embedding_weight,
126-
}
124+
# "embedding_weight": self.embedding_weight,
125+
# "position_embedding_weight": self.position_embedding_weight,
126+
} # pass these two weights will cause error in backward
127127

128-
# preproc_output = {**dict_args, **preproc_output}
128+
preproc_output = {**dict_args, **preproc_output}
129129

130130
for key in list(preproc_output.keys()):
131131
if preproc_output[key] is None:

0 commit comments

Comments
 (0)