Commits (72 total; changes shown from 25)
0040b97
Add deterministic training functionality to PyTorch LLaMA benchmark
Aishwarya-Tonpe Aug 5, 2025
e103dd0
llama: add periodic checksum logging (deterministic-only, log-only); …
Aishwarya-Tonpe Aug 11, 2025
87ff6d6
deterministic training: enable seeding + deterministic algorithms acr…
Aishwarya-Tonpe Aug 11, 2025
8eee235
tests(pytorch): add strict determinism skip guards and detailed docst…
Aishwarya-Tonpe Aug 11, 2025
fe34247
Refactor LLaMA model tests: align strict, soft determinism, and check…
Aishwarya-Tonpe Aug 11, 2025
c374dfe
examples: add deterministic and strict_determinism flags and docs to …
Aishwarya-Tonpe Aug 11, 2025
614f96c
Deterministic fingerprints: replace checksum with Loss+ActMean across…
Aishwarya-Tonpe Aug 12, 2025
689dc44
Deterministic training + reproducible logging: align GPT-2/LLaMA/LSTM…
Aishwarya-Tonpe Aug 16, 2025
33c3f6a
Adding flag: Check-frequency
Aishwarya-Tonpe Aug 18, 2025
f35e98b
Add Check frequency flag to tests
Aishwarya-Tonpe Aug 19, 2025
dd7fcbe
Code refactor: Move enable_determinism to base class, add a consolida…
Aishwarya-Tonpe Aug 20, 2025
d439395
Code refactor: Add a new test folder to remove redundant code, remove…
Aishwarya-Tonpe Aug 20, 2025
da9c85a
Code refactor: Move loss and ActMean logging to base class from indiv…
Aishwarya-Tonpe Aug 20, 2025
2635aad
Code refactor: Move _benchmark() method to base class
Aishwarya-Tonpe Aug 20, 2025
4a21990
Code refactor: Add method _finalize_periodic_logging to base class to…
Aishwarya-Tonpe Aug 20, 2025
ddd3f23
Code cleanup: Remove unnecessary imports
Aishwarya-Tonpe Aug 20, 2025
a9cb452
Code cleanup: Remove unnecessary imports
Aishwarya-Tonpe Aug 20, 2025
52c5516
Code cleanup: Remove unnecessary imports
Aishwarya-Tonpe Aug 20, 2025
6623f59
Code cleanup: Remove unnecessary imports
Aishwarya-Tonpe Aug 20, 2025
8853c21
Testcase addition: Add failure testcase, rename flag
Aishwarya-Tonpe Aug 21, 2025
14be806
Delete extra lines
Aishwarya-Tonpe Aug 21, 2025
8cd1c19
Add Docstrings, align imports, add assertions messages
Aishwarya-Tonpe Aug 26, 2025
99bdc16
Lint Checks
Aishwarya-Tonpe Aug 27, 2025
4bc0445
Lint Checks
Aishwarya-Tonpe Aug 28, 2025
2c8d856
Lint Checks
Aishwarya-Tonpe Aug 28, 2025
d8d9ca0
Failed check: Resolving failed pipeline check for creating temp file …
Aishwarya-Tonpe Aug 28, 2025
8bcd801
Pipeline failure fixes: Fixing Lint failures on test, example and ba…
Aishwarya-Tonpe Aug 28, 2025
315d07f
Pipeline failure fixes: Fixing Lint failures on test, example and ba…
Aishwarya-Tonpe Aug 28, 2025
5ae57f0
Pipeline failure error: GitHub not reflecting change in base file, at…
Aishwarya-Tonpe Aug 28, 2025
c379c5e
Pipeline failure fixes
Aishwarya-Tonpe Aug 28, 2025
3b186cf
Pipeline failure fixes
Aishwarya-Tonpe Aug 29, 2025
64d7b81
Test file lint fixes
Aishwarya-Tonpe Aug 29, 2025
90a6595
Pipeline Error: Mixtral create Model
Aishwarya-Tonpe Aug 29, 2025
055723c
Modifying test parameters for efficiency
Aishwarya-Tonpe Aug 29, 2025
b47688d
Attempting to skip tests for heavy models in CI
Aishwarya-Tonpe Aug 29, 2025
13ad2fe
Attempting to skip tests for heavy models in CI
Aishwarya-Tonpe Aug 29, 2025
2ed5ae0
Skipping tests for CICD
Aishwarya-Tonpe Aug 29, 2025
10ae1a3
Removing unnecessary code
Aishwarya-Tonpe Sep 3, 2025
fb21a9f
Adding Metadata Overriding logic to fetch metadata from the log file …
Aishwarya-Tonpe Sep 4, 2025
f3bb260
Adding Metadata Overriding logic to fetch metadata from the log file …
Aishwarya-Tonpe Sep 4, 2025
172b02b
Lint Fixes
Aishwarya-Tonpe Sep 4, 2025
de326d5
Pipeline failure fix
Aishwarya-Tonpe Sep 4, 2025
6497bf5
Adding test for coverage
Aishwarya-Tonpe Sep 4, 2025
8a8599e
Pipeline failure fix
Aishwarya-Tonpe Sep 4, 2025
a68b4df
Pipeline failure fix
Aishwarya-Tonpe Sep 4, 2025
e59fc61
Adding Info about deterministic training to docs
Aishwarya-Tonpe Sep 15, 2025
7c6120d
Adding Info about deterministic training to docs
Aishwarya-Tonpe Sep 15, 2025
860f0f9
Merge branch 'main' into aishwaryatonpe/deterministic-training
polarG Sep 22, 2025
2892a69
Comments resolve: Add docstrings, Make changes to ensure same lengths…
Aishwarya-Tonpe Oct 1, 2025
0195d98
Comment resolve: Remove process_info, deprecated
Aishwarya-Tonpe Oct 1, 2025
ea6f7fc
Fixing Lint errors
Aishwarya-Tonpe Oct 1, 2025
d8acbf2
Lint checks resolve
Aishwarya-Tonpe Oct 2, 2025
8629e8b
Lint checks resolve
Aishwarya-Tonpe Oct 2, 2025
b15393f
Test case fixes : removing log-path from test-pytorch_determinism_all
Aishwarya-Tonpe Oct 2, 2025
529ab12
Comments removed
Aishwarya-Tonpe Oct 2, 2025
2cb80c0
Merge branch 'main' into aishwaryatonpe/deterministic-training
Aishwarya-Tonpe Oct 2, 2025
54d3449
Fixing test_pytorch_deterministic_all
Aishwarya-Tonpe Oct 2, 2025
e91ec63
Comments address : Removing redundant code
Aishwarya-Tonpe Oct 2, 2025
8fc3d5f
Moving seeding logic to make it centralised to model base
Aishwarya-Tonpe Oct 2, 2025
0848c7a
Moving seeding logic to make it centralised to model base
Aishwarya-Tonpe Oct 2, 2025
42718f0
Merge branch 'main' into aishwaryatonpe/deterministic-training
Aishwarya-Tonpe Oct 8, 2025
615bc94
Comments resolve: removing redundant method, adding loggers
Aishwarya-Tonpe Oct 8, 2025
a2e2e20
Merge branch 'main' into aishwaryatonpe/deterministic-training
Aishwarya-Tonpe Oct 9, 2025
59cfdd1
Resolving merge conflicts
Aishwarya-Tonpe Oct 9, 2025
e893a5a
Merge branch 'main' into aishwaryatonpe/deterministic-training
Aishwarya-Tonpe Oct 23, 2025
d909477
Merge branch 'main' into aishwaryatonpe/deterministic-training
Aishwarya-Tonpe Nov 10, 2025
436890e
Merge branch 'main' of https://github.com/microsoft/superbenchmark in…
Dec 8, 2025
e4d2f5e
Removing check_frequency parameter from is_finished method in train a…
Dec 8, 2025
d0bfd38
Comments resolve : Removing check_frequency assignment to the variable
Dec 8, 2025
197007a
Update superbench/benchmarks/model_benchmarks/pytorch_base.py
Aishwarya-Tonpe Dec 8, 2025
4724815
Update tests/benchmarks/model_benchmarks/test_pytorch_determinism_all.py
Aishwarya-Tonpe Dec 8, 2025
fdc82ad
Update superbench/benchmarks/model_benchmarks/pytorch_base.py
Aishwarya-Tonpe Dec 8, 2025
109 changes: 109 additions & 0 deletions examples/benchmarks/pytorch_deterministic_example.py
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Unified PyTorch deterministic training example for all supported models.

Commands to run:
Generate log:

CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py
--model <model_from_MODEL_CHOICES> --generate-log --log-path ./outputs/determinism_ref.json

CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py
--model bert-large --generate-log --log-path ./outputs/determinism_ref.json



Compare log:

CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py
--model <model_from_MODEL_CHOICES> --compare-log ./outputs/determinism_ref.json


CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py
--model bert-large --compare-log ./outputs/determinism_ref.json
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Framework

MODEL_CHOICES = [
    "bert-large",
    "gpt2-small",
    "llama2-7b",
    "mixtral-8x7b",
    "resnet101",
    "lstm",
]

DEFAULT_PARAMS = {
    "bert-large": "--batch_size 1 --seq_len 128 --num_warmup 1 --num_steps 300 --precision float32 "
    "--model_action train --deterministic --deterministic_seed 42 --check_frequency 20",
    "gpt2-small": "--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 128 --precision float32 "
    "--model_action train --deterministic --deterministic_seed 42 --check_frequency 20",
    "llama2-7b": "--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 512 --precision float32 "
    "--model_action train --deterministic --deterministic_seed 42 --check_frequency 20",
    "mixtral-8x7b": "--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 "
    "--num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02 "
    "--deterministic --deterministic_seed 42 --check_frequency 20",
    "resnet101": "--batch_size 192 --precision float32 --num_warmup 64 --num_steps 512 --sample_count 8192 "
    "--pin_memory --model_action train --deterministic --deterministic_seed 42 --check_frequency 20",
    "lstm": "--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 256 --precision float16 "
    "--model_action train --deterministic --deterministic_seed 42 --check_frequency 20",
}
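
# Note (per this PR's commit history): across all entries, --deterministic
# enables seeding plus PyTorch's deterministic algorithms, --deterministic_seed
# fixes the seed, and --check_frequency controls how often the loss +
# activation-mean (ActMean) fingerprint is logged -- every 20 steps here.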


def main():
    parser = argparse.ArgumentParser(
        description="Unified PyTorch deterministic training example."
    )
    parser.add_argument(
        "--model", type=str, choices=MODEL_CHOICES, required=True, help="Model to run."
    )
    parser.add_argument(
        "--generate-log", action="store_true", help="Enable fingerprint log generation."
    )
    parser.add_argument(
        "--log-path", type=str, default=None, help="Path to save fingerprint log."
    )
    parser.add_argument(
        "--compare-log",
        type=str,
        default=None,
        help="Path to reference fingerprint log for comparison.",
    )
    parser.add_argument(
        "--deterministic-seed",
        type=int,
        default=42,
        help="Seed for deterministic training.",
    )
    args = parser.parse_args()

    parameters = DEFAULT_PARAMS[args.model]
    # Appending the flag again overrides the seed baked into DEFAULT_PARAMS,
    # since argparse keeps the last occurrence of a repeated option.
    if args.deterministic_seed is not None:
        parameters += f" --deterministic_seed {args.deterministic_seed}"
    if args.generate_log:
        parameters += " --generate-log"
    if args.log_path:
        parameters += f" --log-path {args.log_path}"
    if args.compare_log:
        parameters += f" --compare-log {args.compare_log}"

    print(f"Running {args.model} with parameters: {parameters}")
    context = BenchmarkRegistry.create_benchmark_context(
        args.model, parameters=parameters, framework=Framework.PYTORCH
    )
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    print(f"Benchmark finished. Return code: {benchmark.return_code}")
    if hasattr(benchmark, "_model_run_metadata"):
        print("Run metadata:", benchmark._model_run_metadata)
    if hasattr(benchmark, "_model_run_losses"):
        print("Losses:", benchmark._model_run_losses[:5], "...")
    if hasattr(benchmark, "_model_run_periodic"):
        print("Periodic:", benchmark._model_run_periodic)


if __name__ == "__main__":
    main()
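
For reference, the generate-log flow can also be driven from Python directly; this is essentially what main() above reduces to. A minimal sketch reusing the bert-large entry from DEFAULT_PARAMS; the output path is illustrative:

from superbench.benchmarks import BenchmarkRegistry, Framework

# bert-large parameter string from DEFAULT_PARAMS, with the fingerprint-log
# flags appended the same way main() does (log path is illustrative).
params = (
    "--batch_size 1 --seq_len 128 --num_warmup 1 --num_steps 300 --precision float32 "
    "--model_action train --deterministic --deterministic_seed 42 --check_frequency 20 "
    "--generate-log --log-path ./outputs/determinism_ref.json"
)
context = BenchmarkRegistry.create_benchmark_context(
    "bert-large", parameters=params, framework=Framework.PYTORCH
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
print("Return code:", benchmark.return_code)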