Skip to content

Commit 529d3a2

Browse files
authored
Fix Glm4vModelTest::test_eager_matches_fa2_generate (#40947)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent a2ac4de commit 529d3a2

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tests/models/glm4v/test_modeling_glm4v.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def prepare_config_and_inputs_for_common(self):
159159

160160
inputs_dict = {
161161
"pixel_values": pixel_values,
162-
"image_grid_thw": torch.tensor([[1, patches_per_side, patches_per_side]] * self.batch_size),
162+
"image_grid_thw": torch.tensor(
163+
[[1, patches_per_side, patches_per_side]] * self.batch_size, device=torch_device
164+
),
163165
"input_ids": input_ids,
164166
"attention_mask": attention_mask,
165167
}

tests/models/glm4v_moe/test_modeling_glm4v_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def prepare_config_and_inputs_for_common(self):
170170

171171
inputs_dict = {
172172
"pixel_values": pixel_values,
173-
"image_grid_thw": torch.tensor([[1, patches_per_side, patches_per_side]] * self.batch_size),
173+
"image_grid_thw": torch.tensor(
174+
[[1, patches_per_side, patches_per_side]] * self.batch_size, device=torch_device
175+
),
174176
"input_ids": input_ids,
175177
"attention_mask": attention_mask,
176178
}

0 commit comments

Comments
 (0)