
Gradient accumulation fix in cross entropy loss #21386

Open
Sohaib-Ahmed21 wants to merge 11 commits into Lightning-AI:master from Sohaib-Ahmed21:bugfix/20350_grad_acc_fix

Conversation

Contributor

@Sohaib-Ahmed21 Sohaib-Ahmed21 commented Nov 27, 2025

What does this PR do?

This PR resolves the issue where loss normalization during gradient accumulation is incorrect for cross entropy loss.

Key Changes:

  • Implemented a peekable iterator that inspects all microbatches in the global batch before the first forward pass of the accumulation window, to determine the total number of valid tokens.

  • Added support for a labels key in each batch, similar to the labels field used in Hugging Face Transformers.

  • Each microbatch’s loss is now divided by the total valid-token count of the global batch, ensuring correct scaling during gradient accumulation.

  • Documentation updates will follow to guide users on required batch structure and loss settings.

Fixes #20350
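
For context, a minimal sketch of the scaling this PR describes (illustrative only; the helper name, the HF-style ignore index -100, and the tensor shapes are assumptions, not the PR's actual code). Taking the mean loss per microbatch and then averaging across microbatches weights tokens unevenly whenever microbatches contain different numbers of valid tokens; summing per-token losses and dividing by the global valid-token count treats every token equally:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # HF-style ignore index (an assumption for this sketch)

def global_token_normalized_loss(logits_list, labels_list):
    # First pass over all microbatches: count valid (non-ignored) tokens.
    # This is the information a peekable iterator makes available up front.
    total_valid = sum(int((labels != IGNORE_INDEX).sum()) for labels in labels_list)
    total_loss = torch.zeros(())
    for logits, labels in zip(logits_list, labels_list):
        # Sum (not mean) the per-token losses within each microbatch...
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),  # (batch * seq, vocab)
            labels.view(-1),                   # (batch * seq,)
            ignore_index=IGNORE_INDEX,
            reduction="sum",
        )
        # ...then divide by the global valid-token count, so every token
        # contributes equally no matter which microbatch it landed in.
        total_loss = total_loss + loss / total_valid
    return total_loss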

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • yes
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
    Yes, will do so.
  • Did you write any new necessary tests? (not for typos and docs)
    I don't think they're needed yet.
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21386.org.readthedocs.build/en/21386/

@github-actions github-actions Bot added the pl Generic label for PyTorch Lightning package label Nov 27, 2025
Comment threads (outdated) on src/lightning/pytorch/loops/training_epoch_loop.py (×3)
Contributor Author

Sohaib-Ahmed21 commented Dec 14, 2025

@SkafteNicki @justusschock kindly review this PR when you have some time so I can proceed with it, thanks!

@github-actions github-actions Bot added the docs Documentation related label Jan 7, 2026
@Sohaib-Ahmed21 Sohaib-Ahmed21 marked this pull request as ready for review January 7, 2026 17:21
@Sohaib-Ahmed21 Sohaib-Ahmed21 force-pushed the bugfix/20350_grad_acc_fix branch 2 times, most recently from 4152f03 to aee774e on January 7, 2026 19:49

codecov Bot commented Jan 17, 2026

Codecov Report

❌ Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 79%. Comparing base (0e20e15) to head (d83a6e3).
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (0e20e15) and HEAD (d83a6e3).

HEAD has 160 fewer uploads than BASE

Flag              BASE (0e20e15)  HEAD (d83a6e3)
cpu               82              42
python            6               3
lightning_fabric  25              0
pytest            40              0
python3.12.7      18              9
python3.12        24              12
lightning         30              15
python3.11        12              6
python3.10        6               3
python3.13        16              9
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21386     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       23973    23916     -57     
=========================================
- Hits        20748    18802   -1946     
- Misses       3225     5114   +1889     

@Sohaib-Ahmed21 Sohaib-Ahmed21 force-pushed the bugfix/20350_grad_acc_fix branch from dd4c876 to b71147c on January 18, 2026 10:29
Contributor Author

Sohaib-Ahmed21 commented Jan 18, 2026

@SkafteNicki @lantiga @justusschock @deependujha all checks are passing now and the PR is ready for review. Kindly review it, thanks!

Collaborator

deependujha commented Jan 21, 2026

Hi @Sohaib-Ahmed21, thanks for digging into this.

I think this is a case where Lightning should avoid encoding domain assumptions and instead expose a more general mechanism.

Assumptions around CrossEntropyLoss, ignore indices like -100, and how “valid tokens” are inferred are very NLP- and HF-specific. Lightning, on the other hand, is intentionally domain-agnostic and used across vision, speech, RL, and custom losses. Once Lightning starts inferring valid_token_count from labels, it implicitly assumes token-level supervision, a specific loss function, and specific masking semantics. That feels brittle and difficult to generalize.


A cleaner approach would be to let the user explicitly control normalization via the training_step output.

For example, the user could return a generic normalize value alongside the loss, and Lightning would simply divide the loss by this value during gradient accumulation.

def training_step(...):
    return {"loss": loss, "normalize": valid_tokens}

In the case of token-level cross entropy, normalize would naturally be the number of valid (non-masked) tokens. In other setups, it could represent batch size, number of sequences, number of samples seen, or any other user-defined notion of “effective batch”.

This keeps responsibilities clearly separated:

  • the user defines loss semantics and masking
  • Lightning handles accumulation and scaling, without needing to know why a given normalization is correct

Important

Documentation and examples can still recommend returning normalize=valid_token_count for HF-style CE with ignore_index=-100, but the framework itself remains free of hardcoded assumptions. This feels more robust, extensible, and aligned with Lightning’s design philosophy than inferring normalization internally.
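
To make the proposal concrete, a user's training_step under this scheme might look roughly as follows (a hedged sketch: the normalize key is the mechanism proposed above, not an existing Lightning API, and the model call and batch layout are assumptions):

import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    labels = batch["labels"]
    logits = self.model(batch["input_ids"])  # assumed shape: (batch, seq, vocab)
    # Sum-reduced CE so the loop can rescale the loss by the supplied count.
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    # Number of supervised (non-masked) tokens in this microbatch.
    valid_tokens = (labels != -100).sum()
    return {"loss": loss, "normalize": valid_tokens}

In a vision or RL setup, valid_tokens would simply be replaced by whatever "effective batch" count the user's loss semantics call for.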

What do you think @SkafteNicki @justusschock @bhimrazy @Sohaib-Ahmed21?

@Sohaib-Ahmed21
Contributor Author

Thanks for the detailed feedback @deependujha.

@SkafteNicki shared the same thoughts here. I followed the HF implementation, but the scope-creep concern you raise for pytorch-lightning above makes sense.

In light of the comments from you and @SkafteNicki, I should revert my logic to return something like num_ignore_val in the training_step output and document it accordingly. Thoughts, @deependujha @SkafteNicki? Thanks!

@deependujha
Collaborator

why num_ignore_val?

@Sohaib-Ahmed21
Contributor Author

why num_ignore_val?

Sorry, the name is not descriptive; I'll use a clearer parameter name to denote the total count of valid tokens in the global batch.

Can you kindly confirm the approach so I can revert? Thanks!

@deependujha
Collaborator

deependujha commented Jan 23, 2026

This approach makes minimal assumptions and addresses the gradient accumulation bug in the cross-entropy loss without requiring significant code restructuring. It would be good to align on this approach.

thoughts @bhimrazy ?

@Sohaib-Ahmed21 Sohaib-Ahmed21 force-pushed the bugfix/20350_grad_acc_fix branch from 9bb4b7f to 8a06cff on January 24, 2026 17:09
@Sohaib-Ahmed21
Contributor Author

Sohaib-Ahmed21 commented Jan 29, 2026

I've reverted the PR to align on the discussed approach. @deependujha @justusschock @SkafteNicki please review it.

@bhimrazy
Collaborator

bhimrazy commented Feb 4, 2026

This approach makes minimal assumptions and addresses the gradient accumulation bug in the cross-entropy loss without requiring significant code restructuring. It would be good to align on this approach.

thoughts @bhimrazy ?

Yeah, I like the idea @deependujha.

Also, it would be good to get a lead's review on the approach and to catch any edge cases we might be missing.

    and isinstance(training_step_output, Mapping)
    and training_step_output.get("normalize") is not None
):
    normalize = training_step_output["normalize"]
Collaborator

@bhimrazy bhimrazy Feb 4, 2026


It will probably need some validation here (e.g., checks for 0 and type checks).

cc: @deependujha
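
For illustration, the guards this comment asks for might look something like the following (a sketch against the snippet above, not the PR's actual code; the exact checks and messages are assumptions):

import torch

normalize = training_step_output["normalize"]
# Type check: accept plain numbers or scalar tensors only.
if isinstance(normalize, torch.Tensor):
    normalize = normalize.item()
if not isinstance(normalize, (int, float)):
    raise TypeError(f"`normalize` must be a number or a scalar tensor, got {type(normalize)!r}")
# Zero check: dividing the loss by 0 would produce inf/NaN gradients.
if normalize <= 0:
    raise ValueError(f"`normalize` must be positive, got {normalize}")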

@Sohaib-Ahmed21
Contributor Author

Sohaib-Ahmed21 commented Feb 4, 2026

Also, it would be good to get a lead's review on the approach and to catch any edge cases we might be missing.

@SkafteNicki @tchaton @Borda please review this PR, thanks!

@Sohaib-Ahmed21
Contributor Author

Someone please review this.

@Sohaib-Ahmed21 Sohaib-Ahmed21 requested a review from bhimrazy April 23, 2026 19:44
@Sohaib-Ahmed21
Contributor Author

Review on this PR has been pending for a long time; someone please review. @SkafteNicki


Labels

docs (Documentation related), pl (Generic label for PyTorch Lightning package)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gradient accumulation calculation may be incorrect

4 participants