
Conversation

@mahdilamb
Contributor

@mahdilamb commented May 18, 2024

  • Enable using CutMix/MixUp with pre-encoded (one-hot) labels; see the usage sketch below

Todo:

  • update test
  • check for already encoded inputs
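
A minimal sketch of the usage this PR is aiming for (shapes and values are illustrative, and the exact handling of pre-encoded labels is still being worked out in this PR):

import torch
from torchvision.transforms import v2 as transforms

batch_size, num_classes = 8, 10
imgs = torch.rand(batch_size, 3, 224, 224)

# Index-based labels: the transform one-hot-encodes them internally.
int_labels = torch.randint(0, num_classes, size=(batch_size,))
imgs_out, labels_out = transforms.MixUp(num_classes=num_classes)(imgs, int_labels)

# Pre-encoded (one-hot) labels: passed through as-is instead of being re-encoded.
one_hot_labels = torch.nn.functional.one_hot(int_labels, num_classes=num_classes)
imgs_out, labels_out = transforms.MixUp(num_classes=num_classes)(imgs, one_hot_labels)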

cc @vfdev-5

@pytorch-bot
Copy link

pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8427

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures

As of commit 218fc58 with merge base 778ce48:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@NicolasHug
Member

Thanks for the PR @mahdilamb .

Supporting labels that are already one-hot-encoded sounds OK to me, but instead of adding a new parameter, it seems we could just check the shape of the labels and only call one_hot when ndim != 2?
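
A rough sketch of that shape-based check (hypothetical helper name, not the actual implementation):

import torch
from torch.nn.functional import one_hot

def _maybe_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # 2D labels are assumed to already be one-hot / probability encoded, so pass them through.
    if labels.ndim == 2:
        return labels
    return one_hot(labels, num_classes=num_classes)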

We would also need to add a few tests here

class TestCutMixMixUp:

@mahdilamb
Contributor Author

Hi @NicolasHug, that makes sense to me. I'll get that moving.

@mahdilamb
Contributor Author

Hi @NicolasHug, that's updated as requested... it breaks Compose though!

@NicolasHug
Member

Hi @mahdilamb - I've made a few changes to the PR locally, but I cannot push them to update the PR: you created it from your main branch, so I don't have the permissions.

Would you mind closing this one and opening a new PR from a dev branch (i.e. do git checkout -b my_branch before committing)?

Alternatively you could also apply this diff to the current PR:

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 190b590c89..07235333af 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2169,29 +2169,30 @@ class TestAdjustBrightness:
 
 class TestCutMixMixUp:
     class DummyDataset:
-        def __init__(self, size, num_classes, encode_labels: bool):
+        def __init__(self, size, num_classes, one_hot_labels):
             self.size = size
             self.num_classes = num_classes
-            self.encode_labels = encode_labels
+            self.one_hot_labels = one_hot_labels
             assert size < num_classes
 
         def __getitem__(self, idx):
             img = torch.rand(3, 100, 100)
-            label = torch.tensor(idx)  # This ensures all labels in a batch are unique and makes testing easier
-            if self.encode_labels:
-                label = torch.nn.functional.one_hot(label, num_classes=self.num_classes)
+            label = idx  # This ensures all labels in a batch are unique and makes testing easier
+            if self.one_hot_labels:
+                label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
             return img, label
 
         def __len__(self):
             return self.size
 
-    @pytest.mark.parametrize(["T", "encode_labels"], [[transforms.CutMix, False], [transforms.MixUp, False], [transforms.CutMix, True], [transforms.MixUp, True]])
-    def test_supported_input_structure(self, T, encode_labels: bool):
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
+    @pytest.mark.parametrize("one_hot_labels", (True, False))
+    def test_supported_input_structure(self, T, one_hot_labels):
         batch_size = 32
         num_classes = 100
 
-        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, encode_labels=encode_labels)
+        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)
 
         cutmix_mixup = T(num_classes=num_classes)
 
@@ -2201,10 +2202,7 @@ class TestCutMixMixUp:
         img, target = next(iter(dl))
         input_img_size = img.shape[-3:]
         assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
-        if encode_labels:
-            assert target.shape == (batch_size, num_classes)
-        else:
-            assert target.shape == (batch_size,)
+        assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
 
         def check_output(img, target):
             assert img.shape == (batch_size, *input_img_size)
@@ -2215,10 +2213,7 @@ class TestCutMixMixUp:
 
         # After Dataloader, as unpacked input
         img, target = next(iter(dl))
-        if encode_labels:
-            assert target.shape == (batch_size, num_classes)
-        else:
-            assert target.shape == (batch_size,)
+        assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
         img, target = cutmix_mixup(img, target)
         check_output(img, target)
 
@@ -2273,7 +2268,7 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="Could not infer where the labels are"):
             cutmix_mixup({"img": imgs, "Nothing_else": 3})
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
+        with pytest.raises(ValueError, match="labels should be index based"):
             # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
             # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
             cutmix_mixup(imgs)
@@ -2281,22 +2276,21 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="When using the default labels_getter"):
             cutmix_mixup(imgs, "not_a_tensor")
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
-
         with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
             cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))
 
         with pytest.raises(ValueError, match="does not match the batch size of the labels"):
             cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            # The purpose of this check is more about documenting the current
-            # behaviour of what happens on a Compose(), rather than actually
-            # asserting the expected behaviour. We may support Compose() in the
-            # future, e.g. for 2 consecutive CutMix?
-            labels = torch.randint(0, num_classes, size=(batch_size,))
-            transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
+        with pytest.raises(ValueError, match="When passing 2D labels"):
+            wrong_num_classes = num_classes + 1
+            T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))
+
+        with pytest.raises(ValueError, match="but got a tensor of shape"):
+            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))
+
+        with pytest.raises(ValueError, match="num_classes must be passed"):
+            T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))
 
     @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index 48daa271ea..1d01012654 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -1,7 +1,7 @@
 import math
 import numbers
 import warnings
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import PIL.Image
 import torch
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
 
 class _BaseMixUpCutMix(Transform):
-    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", labels_encoded: bool = False) -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
         super().__init__()
         self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -150,7 +150,6 @@ class _BaseMixUpCutMix(Transform):
         self.num_classes = num_classes
 
         self._labels_getter = _parse_labels_getter(labels_getter)
-        self._labels_encoded = labels_encoded
 
     def forward(self, *inputs):
         inputs = inputs if len(inputs) > 1 else inputs[0]
@@ -163,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
         labels = self._labels_getter(inputs)
         if not isinstance(labels, torch.Tensor):
             raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
-        elif not 0 < labels.ndim <= 2 or (labels.ndim == 2 and labels.shape[1] != self.num_classes):
+        if labels.ndim not in (1, 2):
             raise ValueError(
-                f"labels tensor should be of shape (batch_size,) or (batch_size,num_classes) " f"but got shape {labels.shape} instead."
+                f"labels should be index based with shape (batch_size,) "
+                f"or probability based with shape (batch_size, num_classes), "
+                f"but got a tensor of shape {labels.shape} instead."
             )
+        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
+            raise ValueError(
+                f"When passing 2D labels, "
+                f"the number of elements in last dimension must match num_classes: "
+                f"{labels.shape[-1]} != {self.num_classes}. "
+                f"You can Leave num_classes to None."
+            )
+        if labels.ndim == 1 and self.num_classes is None:
+            raise ValueError("num_classes must be passed if the labels are index-based (1D)")
 
         params = {
             "labels": labels,
@@ -225,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
 
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
@@ -273,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
 
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
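
For reference, the validation behaviour introduced by this diff boils down to something like the following sketch (it mirrors the new test cases above and uses the same error messages):

import pytest
import torch
from torchvision.transforms import v2 as transforms

batch_size, num_classes = 4, 10
imgs = torch.rand(batch_size, 3, 32, 32)

# 1D (index-based) labels now require num_classes so they can be one-hot-encoded.
with pytest.raises(ValueError, match="num_classes must be passed"):
    transforms.MixUp(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))

# 2D labels must match num_classes when it is given (alternatively, num_classes can be left as None).
with pytest.raises(ValueError, match="When passing 2D labels"):
    transforms.MixUp(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, num_classes + 1)))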
@mahdilamb
Contributor Author

Hi @NicolasHug, that's the diff applied!

Hope you have a great weekend.

Mahdi

@NicolasHug
Member

Thank you @mahdilamb

Before I can merge, do you mind fixing this one linting issue:

torchvision/transforms/v2/_augment.py:213: error: Argument "num_classes" to "one_hot" has incompatible type "Optional[int]"; expected "int"  [arg-type]
        label = one_hot(label, num_classes=self.num_classes)
                                            ^~~~~~~~~~~~~~~~
Found 1 error in 1 file (checked 235 source files)

I think adding a simple # type: ignore[arg-type] comment will be enough - mypy just isn't understanding that self.num_classes can't be None at that point, so we should just silence it.
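
Concretely, the suggested silencing would look roughly like this (an excerpt, with the surrounding _augment.py code abbreviated):

# mypy cannot see that the earlier 1D-labels check guarantees num_classes is not None here.
label = one_hot(label, num_classes=self.num_classes)  # type: ignore[arg-type]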

Thanks!

@mahdilamb
Contributor Author

@NicolasHug, I've made the change, but if it fails I'll look into it properly. I've also added you as a collaborator on the fork so you can mess about!

@NicolasHug changed the title from "Enable pre-encoded mixup" to "Enable one-hot-encoded labels in MixUp and CutMix" on May 28, 2024
@NicolasHug merged commit c585a51 into pytorch:main on May 28, 2024
@NicolasHug
Member

Thank you @mahdilamb !

facebook-github-bot pushed a commit that referenced this pull request Jun 7, 2024
Summary:
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>

Reviewed By: vmoens

Differential Revision: D58283866

fbshipit-source-id: 32b0b2ade02b3a81d167f64a3743c2bf62049308