Commit 304c6a1

Enable fx tracing for Mistral (huggingface#30209)
* tracing for mistral
* typo
* fix copies
1 parent 98717cb commit 304c6a1

File tree

7 files changed: +9, -6 lines changed

7 files changed

+9
-6
lines changed

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 0 additions & 3 deletions

@@ -868,9 +868,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
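The guard removed here (and from the Qwen2MoE copy below) is data-dependent control flow: the length of `top_x` produced by `torch.where` is only known once the routing weights have been computed, so a symbolic tracer cannot decide which branch to record. A minimal sketch of the constraint, using vanilla `torch.fx` rather than transformers' `HFTracer`, with a hypothetical toy module that only mirrors the removed guard:

```python
import torch
from torch import fx, nn


class Gate(nn.Module):
    """Hypothetical toy module mirroring the removed guard."""

    def forward(self, mask: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        top_x = torch.nonzero(mask).squeeze(-1)
        # `top_x.shape[0]` depends on the *values* in `mask`, so this branch is
        # data-dependent control flow and cannot be baked into a static graph.
        if top_x.shape[0] == 0:
            return hidden_states
        return hidden_states[top_x]


try:
    fx.symbolic_trace(Gate())
except Exception as err:
    # torch.fx refuses: symbolically traced variables cannot be used
    # as inputs to control flow.
    print(type(err).__name__, err)
```

Dropping the early `continue` keeps the loop body branch-free; indexing with an empty `top_x` is numerically a no-op, so eager results should be unchanged.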

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

Lines changed: 0 additions & 3 deletions

@@ -840,9 +840,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)

src/transformers/utils/fx.py

Lines changed: 5 additions & 0 deletions

@@ -141,12 +141,16 @@ def _generate_supported_model_class_names(
     "marian",
     "mbart",
     "megatron-bert",
+    "mistral",
+    "mixtral",
     "mobilebert",
     "mt5",
     "nezha",
     "opt",
     "pegasus",
     "plbart",
+    "qwen2",
+    "qwen2_moe",
     "resnet",
     "roberta",
     "segformer",
@@ -758,6 +762,7 @@ class HFTracer(Tracer):
         "tensor",
         "clamp",
         "finfo",
+        "tril",
     ]
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
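With "mistral", "mixtral", "qwen2", and "qwen2_moe" registered above, `transformers.utils.fx.symbolic_trace` accepts these architectures. A minimal sketch of what that enables, using an arbitrary tiny config (the values are illustrative, not taken from the repository) so nothing is downloaded:

```python
from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny, purely illustrative config; eager attention keeps the traced graph simple.
config = MistralConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attn_implementation="eager",
)
model = MistralForCausalLM(config)

# The tracer now recognises the Mistral model classes.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(traced.graph)
```

The added "tril" entry additionally lets the tracer patch `torch.tril`, which the sliding-window causal-mask construction appears to rely on during tracing.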

tests/models/mistral/test_modeling_mistral.py

Lines changed: 1 addition & 0 deletions

@@ -303,6 +303,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(

tests/models/mixtral/test_modeling_mixtral.py

Lines changed: 1 addition & 0 deletions

@@ -302,6 +302,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 1 addition & 0 deletions

@@ -313,6 +313,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(

tests/models/qwen2_moe/test_modeling_qwen2_moe.py

Lines changed: 1 addition & 0 deletions

@@ -342,6 +342,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
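Setting `fx_compatible = True` opts each of these models into the shared FX tests from `ModelTesterMixin`, which roughly trace the model and check that the traced graph reproduces eager outputs. A minimal sketch of that kind of check, with an arbitrary small Mixtral config (illustrative values only) and defensive handling of the traced module's output container:

```python
import torch
from transformers import MixtralConfig, MixtralForCausalLM
from transformers.utils.fx import symbolic_trace

# Small illustrative config; eager attention keeps the traced graph simple.
config = MixtralConfig(
    vocab_size=512,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    attn_implementation="eager",
)
model = MixtralForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
with torch.no_grad():
    eager_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    traced_out = traced(input_ids=input_ids, attention_mask=attention_mask)
    # The traced module may return a plain dict rather than a ModelOutput.
    traced_logits = traced_out["logits"] if isinstance(traced_out, dict) else traced_out.logits

torch.testing.assert_close(eager_logits, traced_logits)
```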

0 commit comments
