
Commit 9e4199e

Gemma3 fixes (#41572)
* Multiple device error fix
* FA2 equivalence fix
* Move the train fwd in cfg test
* Style
* Added comment
* Made the comment more clear
1 parent 4c8d293 commit 9e4199e

4 files changed (+29, -7 lines changed)

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 1 deletion
@@ -798,7 +798,7 @@ def create_causal_mask_mapping(
         is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
         new_image_start = is_image & ~is_previous_image
         image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
-        image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+        image_group_ids = torch.where(is_image, image_group_ids, -1)
         mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
             token_type_ids.to(cache_position.device), image_group_ids
         )
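
For context on the "Multiple device error fix", here is a minimal standalone sketch (toy tensors, not the commit's code) of why the scalar fill helps: `torch.full_like(token_type_ids, -1)` allocates the fill tensor on `token_type_ids`' device, so `torch.where` can hit a device-mismatch error when `image_group_ids` lives on a different device (e.g. under model parallelism), whereas a plain Python `-1` has no device and is broadcast wherever the other operands live.

import torch
from torch import nn

# Toy stand-ins for the real tensors (values and shapes are illustrative only).
token_type_ids = torch.tensor([[0, 1, 1, 0]])
is_image = token_type_ids == 1
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
new_image_start = is_image & ~is_previous_image
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1

# Old: torch.full_like(token_type_ids, -1) pins the fill value to token_type_ids' device,
# which may differ from image_group_ids' device in multi-device setups.
# New: a Python scalar needs no device, so torch.where broadcasts it safely.
image_group_ids = torch.where(is_image, image_group_ids, -1)
print(image_group_ids)  # -> [[-1, 0, 0, -1]]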

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 1 addition & 1 deletion
@@ -764,7 +764,7 @@ def create_causal_mask_mapping(
         is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
         new_image_start = is_image & ~is_previous_image
         image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
-        image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+        image_group_ids = torch.where(is_image, image_group_ids, -1)
         mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
             token_type_ids.to(cache_position.device), image_group_ids
         )

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import pytest
 from parameterized import parameterized
+from pytest import mark
 
 from transformers import (
     AutoModelForCausalLM,
@@ -33,9 +34,11 @@
     is_flash_attn_2_available,
     require_deterministic_for_xpu,
     require_flash_attn,
+    require_flash_attn_3,
     require_read_token,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
     require_torch_large_accelerator,
     slow,
     torch_device,
@@ -342,6 +345,20 @@ def test_automodelforcausallm(self):
             for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
             self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_2", test_fwd_in_train=False)
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attn_3_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_3", test_fwd_in_train=False)
+
 
     @slow
     @require_torch_accelerator
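
As a rough sketch of the pattern these new tests follow (class and model names below such as `CommonMixin`, `TinyModel`, and `Gemma3StyleTests` are illustrative, not the library's API): the model-specific test class reuses the shared helper and opts out of the train-mode forward pass via the new `test_fwd_in_train` flag.

import torch
from torch import nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for a model built from a config.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

class CommonMixin:
    # Mirrors the shape of flash_attn_from_config: train mode is the default,
    # eval mode is the opt-out for models whose training forward needs extra inputs.
    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
        fa_model = TinyModel()
        fa_model = fa_model.train() if test_fwd_in_train else fa_model.eval()
        _ = fa_model(torch.randn(2, 8))

class Gemma3StyleTests(CommonMixin):
    def test_flash_attn_2_from_config(self):
        # Gemma3's train-mode forward needs different inputs, so run in eval mode.
        self.flash_attn_from_config("flash_attention_2", test_fwd_in_train=False)

Gemma3StyleTests().test_flash_attn_2_from_config()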

tests/test_modeling_common.py

Lines changed: 10 additions & 5 deletions
@@ -2976,7 +2976,7 @@ def test_model_is_small(self):
 
     def flash_attn_inference_equivalence(
         self, attn_implementation: str, padding_side: str, atol: float = 4e-2, rtol: float = 4e-2
-    ):
+    ) -> None:
         r"""
         Tests the equivalence between the eager and flash attention implementations.
         This test is only for inference and runs with `dtype=torch.bfloat16`.
@@ -3114,9 +3114,6 @@ def flash_attn_inference_equivalence(
         torch.testing.assert_close(logits_1_eager, logits_1_fa, atol=atol, rtol=rtol)
         if padding_side == "left":
             torch.testing.assert_close(logits_2_eager[1:], logits_2_fa[1:], atol=atol, rtol=rtol)
-            # Check it can run in training mode
-            model.train()
-            _ = model(**second_inputs)
         else:
             torch.testing.assert_close(logits_2_eager[:-1], logits_2_fa[:-1], atol=atol, rtol=rtol)
 
@@ -3651,7 +3648,7 @@ def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(s
 
         assert not loss.isnan().any()
 
-    def flash_attn_from_config(self, attn_implementation: str):
+    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
         r"""
         Tests if the model can be loaded with `attn_implementation` from the config and if the
         weights are not randomly initialized.
@@ -3669,6 +3666,14 @@ def flash_attn_from_config(self, attn_implementation: str):
             config, attn_implementation=attn_implementation, dtype=torch.bfloat16
         ).to(torch_device)
 
+        # By default, we perform the forward pass in train mode, because it's more strict than eval mode. If the
+        # forward pass is successful in train mode, it will also be successful in eval mode. But since some models
+        # (e.g. Gemma3) need different inputs in train mode, we have the option to test the forward pass in eval mode.
+        if test_fwd_in_train:
+            fa_model = fa_model.train()
+        else:
+            fa_model = fa_model.eval()
+
         dummy_input = inputs_dict[fa_model.main_input_name]
         if dummy_input.dtype in [torch.float32, torch.float16]:
             dummy_input = dummy_input.to(torch.bfloat16)
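
To illustrate the reasoning in the added comment with a tiny standalone sketch (generic `torch.nn` modules, not the test suite's code): train mode keeps training-only branches such as dropout active, so a forward pass that succeeds in train mode will also succeed in eval mode, while the reverse is not guaranteed; hence train mode is the stricter default and eval mode is the opt-out for models like Gemma3 that need different inputs when training.

import torch
from torch import nn

# Dropout stands in for any training-only code path.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

model.train()
out_train = model(x)  # dropout active: the training-only branch is exercised

model.eval()
out_eval = model(x)   # dropout is a no-op in eval mode

print(out_train.shape, out_eval.shape)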
