Skip to content

Commit f54647c

Browse files
authored
fix: support tensor labels in DataCollatorWithFlattening (#42620)
* fix: support tensor labels in DataCollatorWithFlattening
  - Add tensor to list conversion in DataCollatorWithFlattening
  - Convert input_ids and labels to list if they are tensors
  - Add tests for both tensor and list labels
  - Fixes #42599
* style: fix whitespace linting errors
* style: apply ruff format to test file
1 parent 1b8ccf1 commit f54647c

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/transformers/data/data_collator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1413,9 +1413,17 @@ def __call__(self, features, return_tensors=None, separator_id=None):
14131413
max_length = 0
14141414
for seq_idx, sample in enumerate(features):
14151415
input_ids = sample["input_ids"]
1416+
# Convert to list if tensor
1417+
if hasattr(input_ids, "tolist"):
1418+
input_ids = input_ids.tolist()
14161419
batch["input_ids"] += input_ids
1420+
14171421
if is_labels_provided:
1418-
batch["labels"] += [separator_id] + sample["labels"][1:]
1422+
labels = sample["labels"]
1423+
# Convert to list if tensor
1424+
if hasattr(labels, "tolist"):
1425+
labels = labels.tolist()
1426+
batch["labels"] += [separator_id] + labels[1:]
14191427
else:
14201428
batch["labels"] += [separator_id] + input_ids[1:]
14211429
if self.return_position_ids:

tests/trainer/test_data_collator.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,3 +1965,55 @@ def test__whole_word_mask(self):
19651965
).astype(bool)
19661966

19671967
np.testing.assert_array_equal(output_mask, expected_mask)
1968+
1969+
1970+
class DataCollatorWithFlatteningTest(unittest.TestCase):
    """Tests for DataCollatorWithFlattening with tensor and list inputs.

    Both tests feed two samples (4 tokens and 3 tokens) and check not only
    that the collator no longer raises on tensor inputs (issue #42599), but
    also that the flattened values are correct, so a conversion that
    silently produced wrong contents would be caught.
    """

    # Passed explicitly to the collator so the expected labels below do not
    # depend on the collator's configured default separator.
    SEPARATOR_ID = -100

    INPUT_IDS = ([1, 2, 3, 4], [5, 6, 7])
    LABELS = ([10, 11, 12, 13], [14, 15, 16])

    # Samples are concatenated into a single row of 4 + 3 = 7 tokens.
    EXPECTED_INPUT_IDS = [[1, 2, 3, 4, 5, 6, 7]]
    # The first label of every sample is replaced by the separator id.
    EXPECTED_LABELS = [[SEPARATOR_ID, 11, 12, 13, SEPARATOR_ID, 15, 16]]
    # Position ids restart at 0 at the beginning of every sample.
    EXPECTED_POSITION_IDS = [[0, 1, 2, 3, 0, 1, 2]]

    def _assert_flattened_batch(self, batch):
        """Check keys, shapes and exact values of a flattened batch."""
        self.assertIsInstance(batch, dict)
        for key in ("input_ids", "labels", "position_ids"):
            self.assertIn(key, batch)
            self.assertEqual(batch[key].shape, (1, 7))

        self.assertEqual(batch["input_ids"].tolist(), self.EXPECTED_INPUT_IDS)
        self.assertEqual(batch["labels"].tolist(), self.EXPECTED_LABELS)
        self.assertEqual(batch["position_ids"].tolist(), self.EXPECTED_POSITION_IDS)

    def test_flattening_with_tensor_labels(self):
        """Test that DataCollatorWithFlattening supports tensor labels (fixes issue #42599)."""
        features = [
            {"input_ids": torch.tensor(ids), "labels": torch.tensor(labels)}
            for ids, labels in zip(self.INPUT_IDS, self.LABELS)
        ]
        collator = DataCollatorWithFlattening(return_tensors="pt")

        # This should not raise TypeError anymore
        batch = collator(features, separator_id=self.SEPARATOR_ID)

        self._assert_flattened_batch(batch)

    def test_flattening_with_list_labels(self):
        """Test that DataCollatorWithFlattening still works with list labels."""
        # input_ids stay tensors while labels are plain lists, so the two
        # fields exercise the tensor and list code paths in the same sample.
        features = [
            {"input_ids": torch.tensor(ids), "labels": list(labels)}
            for ids, labels in zip(self.INPUT_IDS, self.LABELS)
        ]
        collator = DataCollatorWithFlattening(return_tensors="pt")
        batch = collator(features, separator_id=self.SEPARATOR_ID)

        # Verify it still works with lists and yields the same values
        self._assert_flattened_batch(batch)

0 commit comments

Comments
 (0)