Fix Gradient Accumulation issue #34191
Merged
Changes from 22 commits (58 commits in total).
Commits (58, all by ArthurZucker):

44301cc quick fix
1456088 3 losses
e57f00c oups
7fa8503 fix
b955ea5 nits
07478e0 check how it scales for special models
1b356ef propagate for conditiona detr
4ef45b0 propagate
61da9b1 propagate
2e3f0f7 propagate
c31a3fb fixes
a8cd107 propagate changes
711c357 update
4888cf3 fixup
4323d85 nits
e5e4bbd f string
239a256 fixes
bd298da more fixes
5dfc51c ?
0a1cd2b nit
64f7e29 arg annoying f string
aa01ae9 nits
8c1d68a grumble
846cf1c update
e7e8a20 nit
622290c refactor
91e28aa fix fetch tests
da649b9 nit
df6472a nit
cf1eb7b Update src/transformers/loss/loss_utils.py
dafd11b Merge branch 'quick-fix-ga' of github.com:huggingface/transformers in…
30f27cd update
d0edfad nit
9bcecc3 fixup
2839b3c make pass
557d225 nits
393e178 port code to more models
aac054d fixup
ce32d5e ntis
4dc49ac arf
d221e58 update
f03b193 update
22b6283 nits
64829e3 update
0b6f425 fix
e6f6f52 update
fa691aa nits
66f6eef fine
fcdf13d agjkfslga.jsdlkgjklas
ece5e01 nits
bb236eb fix fx?
7c2b7ce update
0be4379 update
36d76d7 styel
92979e7 fix imports
a55e440 update
b14c3dd update
dbbc3ce fixup to fix the torch fx?
New file: src/transformers/loss/loss_utils.py (@@ -0,0 +1,98 @@)

```python
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from .models.detr.loss_detr import ForObjectDetectionLoss, ForSegmentationLoss


def DefaultCrossEntropyLoss(logits, labels, **kwargs):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, kwargs["vocab_size"])
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    num_items = kwargs.pop("num_items", None)

    if num_items is not None:
        # Calculate the CrossEntropyLoss manually when using grad accum
        log_probs = nn.functional.log_softmax(shift_logits, dim=-1)
        loss = -log_probs[range(shift_labels.size(0)), shift_labels]
        loss = loss.sum() / num_items
    else:
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

    return loss


def ForSequenceClassificationLoss(logits, labels, pooled_logits, **kwargs):
    config = kwargs["config"]
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(pooled_logits, labels)
    elif config.problem_type == "single_label_classification":
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(pooled_logits.view(-1, num_labels), labels.view(-1))
    elif config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(pooled_logits, labels)
    return loss


def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions):
    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split add a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1).to(start_logits.device)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1).to(end_logits.device)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
    return total_loss


def ForTokenClassification(logits, labels, config, **kwargs):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.view(-1, config.num_labels)
    labels = labels.view(-1)
    logits = logits.float()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    return loss_fct(logits, labels)


LOSS_MAPPING = {
    "ForCausalLM": DefaultCrossEntropyLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
}

LOSS_MAPPING["ForSegmentation"] = ForSegmentationLoss
LOSS_MAPPING["ForObjectDetection"] = ForObjectDetectionLoss
```
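The `num_items` branch in `DefaultCrossEntropyLoss` exists because, under gradient accumulation, averaging each micro-batch's mean loss is not the same as taking the mean over all tokens in the effective batch when micro-batches contain different numbers of label tokens. A standalone sketch of the difference (illustrative only, not the PR's trainer code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 8

# Two micro-batches with different numbers of label tokens, as happens
# with padding or variable-length sequences under gradient accumulation.
logits_a, labels_a = torch.randn(3, vocab_size), torch.randint(0, vocab_size, (3,))
logits_b, labels_b = torch.randn(5, vocab_size), torch.randint(0, vocab_size, (5,))

# Naive: average the per-micro-batch means (what default mean-reduction
# cross-entropy gives if losses are simply averaged across steps).
mean_of_means = (
    F.cross_entropy(logits_a, labels_a) + F.cross_entropy(logits_b, labels_b)
) / 2

# Corrected: sum the per-token losses and divide once by the total token
# count, mirroring the `loss.sum() / num_items` branch above.
num_items = labels_a.numel() + labels_b.numel()
global_mean = (
    F.cross_entropy(logits_a, labels_a, reduction="sum")
    + F.cross_entropy(logits_b, labels_b, reduction="sum")
) / num_items

# The corrected value matches the loss computed over the full batch at once.
full = F.cross_entropy(
    torch.cat([logits_a, logits_b]), torch.cat([labels_a, labels_b])
)
assert torch.allclose(global_mean, full)
```

With equal-sized micro-batches the two quantities coincide, which is why the discrepancy only shows up with padding or variable-length sequences.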
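The clamping trick in `ForQuestionAnsweringLoss` can be verified in isolation: positions outside the sequence are clamped to `ignored_index` (the sequence length), which `CrossEntropyLoss(ignore_index=...)` then drops from the average. A minimal standalone sketch:

```python
import torch
from torch.nn import CrossEntropyLoss

seq_len = 4
torch.manual_seed(0)
start_logits = torch.randn(2, seq_len)
# Second answer position lies outside the model inputs.
start_positions = torch.tensor([1, 99])

# Clamp out-of-range positions to ignored_index, as in the PR's loss.
ignored_index = start_logits.size(1)
clamped = start_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
loss = loss_fct(start_logits, clamped)

# Only the in-range example contributes to the mean loss.
expected = loss_fct(start_logits[:1], start_positions[:1])
assert torch.allclose(loss, expected)
```

The mean reduction averages only over non-ignored targets, so a batch where some examples have out-of-range positions still yields the correct per-example loss for the rest.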