Gradient accumulation fix in cross entropy loss #21386
Sohaib-Ahmed21 wants to merge 11 commits into Lightning-AI:master from
Conversation
@SkafteNicki @justusschock kindly review this PR when you have some time so I can proceed with it, thanks!
4152f03 to aee774e
Codecov Report: ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21386 +/- ##
=========================================
- Coverage 87% 79% -8%
=========================================
Files 270 267 -3
Lines 23973 23916 -57
=========================================
- Hits 20748 18802 -1946
- Misses 3225 5114 +1889
dd4c876 to b71147c
@SkafteNicki @lantiga @justusschock @deependujha all checks are passing now and the PR is ready for review. Kindly review it, thanks!
Hi @Sohaib-Ahmed21, thanks for digging into this. I think this is a case where Lightning should avoid encoding domain assumptions and instead expose a more general mechanism. Assumptions around a `labels` key and token-level losses are domain-specific. A cleaner approach would be to let the user return the normalization factor from `training_step`:
```python
def training_step(self, batch, batch_idx):
    ...
    return {"loss": loss, "normalize": valid_tokens}
```

In the case of token-level cross entropy, `normalize` would be the number of valid tokens in the microbatch. This keeps responsibilities clearly separated: the user decides what the normalization factor means, and Lightning only applies it consistently across the accumulation window.
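For concreteness, a minimal sketch (not part of this PR; the module internals, batch keys, and the `-100` ignore index are assumptions) of what a user's `training_step` could look like under this mechanism for token-level cross entropy:

```python
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class TokenLevelLM(LightningModule):
    # Assumes `self.model` maps input_ids -> logits of shape (B, T, vocab)
    # and that ignored/padded positions in `labels` are marked with -100.
    def training_step(self, batch, batch_idx):
        logits = self.model(batch["input_ids"])
        labels = batch["labels"]
        # Summed (not averaged) loss over valid tokens only.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        valid_tokens = (labels != -100).sum()
        # The loop would divide by the total `normalize` across the window.
        return {"loss": loss, "normalize": valid_tokens}
```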
Important: Documentation and examples can still recommend returning `normalize` for token-level losses, without Lightning hard-coding that assumption. What do you think @SkafteNicki @justusschock @bhimrazy @Sohaib-Ahmed21?
Thanks for the detailed feedback @deependujha. The same thoughts were shared by @SkafteNicki here. I followed the HF implementation, but the scope-creep concern for pytorch-lightning that you mention above makes sense. In light of the comments from you and @SkafteNicki, I think I should revert my logic to something along those lines.
why
Sorry, the name is not descriptive; I'll use a more definitive parameter name to denote the total count of valid tokens in the global batch. Can you kindly confirm the approach so I can revert? Thanks!
This approach makes minimal assumptions and addresses the gradient accumulation bug in the cross-entropy loss without requiring significant code restructuring. It would be good to align on this approach. Any thoughts @bhimrazy?
9bb4b7f to 8a06cff
…-Ahmed21/pytorch-lightning into bugfix/20350_grad_acc_fix
I've reverted the PR to align on the discussed approach. @deependujha @justusschock @SkafteNicki please review it. |
Yeah, I like the idea @deependujha. It would also be good to get a lead's review of the approach and to catch any edge cases we might be missing.
```python
and isinstance(training_step_output, Mapping)
and training_step_output.get("normalize") is not None
):
    normalize = training_step_output["normalize"]
```
It will probably need some validations here (e.g., checks for 0 and type checks).
cc: @deependujha
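Something along these lines could cover it; a rough sketch only (the helper name and the exact set of checks are just a suggestion, not code from this PR):

```python
import torch


def _validate_normalize(value):
    """Validate the value returned under the "normalize" key."""
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            raise ValueError("`normalize` must be a scalar tensor")
        value = value.item()
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"`normalize` must be an int or float, got {type(value)}")
    if value <= 0:
        raise ValueError("`normalize` must be > 0 to avoid division by zero")
    return value
```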
@SkafteNicki @tchaton @Borda please review this PR, thanks!
Someone please review this.
Review of this PR has been pending for a long time; @SkafteNicki, could someone please take a look?
What does this PR do?
This PR resolves the issue where loss normalization during gradient accumulation is incorrect for cross-entropy loss.
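A minimal numeric illustration of the problem (hypothetical per-token losses, not taken from this PR):

```python
# Two microbatches in one accumulation window: A has 10 valid tokens, B has 90.
loss_a = [4.0] * 10   # summed loss 40.0
loss_b = [1.0] * 90   # summed loss 90.0

# Averaging each microbatch's loss and then averaging across the window
# weights both microbatches equally, regardless of how many tokens they hold:
naive = (sum(loss_a) / len(loss_a) + sum(loss_b) / len(loss_b)) / 2   # 2.5

# Normalizing by the total number of valid tokens in the window gives the
# true token-level mean:
correct = (sum(loss_a) + sum(loss_b)) / (len(loss_a) + len(loss_b))   # 1.3

print(naive, correct)
```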
Key Changes:
Implemented a peekable iterator that inspects all microbatches in the global batch before the first forward pass of the accumulation window, to determine the total number of valid tokens.
Added support for a `labels` key in each batch, similar to the `labels` field used in Hugging Face Transformers.
Each microbatch's loss is now divided by the total valid-token count of the global batch, ensuring correct scaling during gradient accumulation (see the sketch after this list).
Documentation updates will follow to guide users on required batch structure and loss settings.
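As a rough sketch of the normalization described above (simplified; not the actual loop code, and it assumes `training_step` returns a summed loss and that batches carry a `labels` tensor with `-100` for ignored positions):

```python
def accumulate_window(microbatches, module, ignore_index=-100):
    # Peek at every microbatch first to count valid tokens in the global batch.
    total_valid = sum(
        (mb["labels"] != ignore_index).sum().item() for mb in microbatches
    )
    for i, mb in enumerate(microbatches):
        out = module.training_step(mb, i)
        # Scale each microbatch's summed loss by the global valid-token count,
        # so the accumulated gradient matches training on the full batch.
        (out["loss"] / total_valid).backward()
```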
Fixes #20350
Before submitting
Yes, will do so.
I think it is not needed yet.
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21386.org.readthedocs.build/en/21386/