Benchmark: Model benchmark - deterministic training support #731
base: main
Changes from 61 commits
```diff
@@ -34,6 +34,18 @@ For inference, supported percentiles include
```

**New: Support fp8_hybrid and fp8_e4m3 precision for BERT models.**

**New: SDC Support**
SuperBench now supports SDC detection to ensure reproducibility across runs. This includes fixed seeds and deterministic algorithms. To enable SDC detection, the following flags and environment variables must be set:

- **Flags:**
  - `--deterministic`: Enables deterministic computation.

Member: Please use a verb or noun for the argument, e.g., `enable-determinism`.

  - `--deterministic_seed <seed>`: Sets the seed for reproducibility.

Member: `random-seed`?

  - `--generate_log`: Generates a log file that can be used as a reference for comparison.

  - `--compare_log <path>`: Specifies the path to the reference log for comparison.

Member: Is the comparison necessary? If the loss etc. are among the metrics, they can be compared separately, like throughput is today.

Author: The comparison serves a different purpose than performance metrics. The feature was designed this way to ensure exact equality and, more importantly, to generate a reference log (golden data) once and validate all subsequent runs, across different machines and runs, against it.
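The golden-data workflow the author describes can be sketched with plain JSON and exact equality; `save_fingerprint` and `compare_fingerprint` are hypothetical names for illustration, not SuperBench APIs:

```python
import json
import os
import tempfile


def save_fingerprint(path, losses, metadata):
    # Write the reference ("golden") log once, on a known-good run.
    with open(path, 'w') as f:
        json.dump({'metadata': metadata, 'losses': losses}, f)


def compare_fingerprint(path, losses):
    # Exact equality is the point: any drift across runs or machines
    # indicates non-determinism or silent data corruption.
    with open(path) as f:
        ref = json.load(f)
    return ref['losses'] == losses


path = os.path.join(tempfile.mkdtemp(), 'determinism_ref.json')
save_fingerprint(path, [0.693, 0.611], {'seed': 42})
assert compare_fingerprint(path, [0.693, 0.611])
assert not compare_fingerprint(path, [0.693, 0.612])
```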
- **Environment Variables:**
  - `CUBLAS_WORKSPACE_CONFIG=:4096:8`: Ensures deterministic behavior in cuBLAS.

Comment on lines +46 to +47

Member: I think you should set this in code when the determinism feature is enabled, rather than asking the user to set it separately?

Author: I took this approach because setting `CUBLAS_WORKSPACE_CONFIG` programmatically is challenging: it must be set before CUDA initializes, which happens during PyTorch import and benchmark construction. By the time we parse the `--deterministic` flag, CUDA is already initialized. Setting it before parsing the args would mean it is set all the time, even when it is not necessary.
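The import-order constraint the author describes can be shown in isolation. This is a sketch of the safe ordering under the stated assumption that cuBLAS reads the variable when the first CUDA context is created (around `import torch`), not SuperBench code:

```python
import os

# Assumption: CUBLAS_WORKSPACE_CONFIG must already be in the environment
# before the framework import below triggers CUDA initialization; setting
# it after that point has no effect on cuBLAS determinism.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Only now would the framework be imported (commented out here):
# import torch

assert os.environ['CUBLAS_WORKSPACE_CONFIG'] == ':4096:8'
```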
#### Metrics

| Name | Unit | Description |
examples/benchmarks/pytorch_deterministic_example.py

```diff
@@ -0,0 +1,102 @@
```

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Unified PyTorch deterministic training example for all supported models.

Commands to run:

Generate log:
    CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py \
        --model <model_from_MODEL_CHOICES> --generate-log ./outputs/determinism_ref.json

Compare log:
    CUBLAS_WORKSPACE_CONFIG=:4096:8 python3 examples/benchmarks/pytorch_deterministic_example.py \
        --model <model_from_MODEL_CHOICES> --compare-log ./outputs/determinism_ref.json
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Framework
from superbench.common.utils import logger

MODEL_CHOICES = [
    'bert-large',
    'gpt2-small',
    'llama2-7b',
    'mixtral-8x7b',
    'resnet101',
    'lstm',
]

DEFAULT_PARAMS = {
    'bert-large':
        '--batch_size 1 --seq_len 64 --num_warmup 1 --num_steps 200 --precision float32 '
        '--model_action train --deterministic --deterministic_seed 42 --check_frequency 20',
    'gpt2-small':
        '--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 128 --precision float32 '
        '--model_action train --deterministic --deterministic_seed 42 --check_frequency 20',
    'llama2-7b':
        '--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 512 --precision float32 --model_action train '
        '--deterministic --deterministic_seed 42 --check_frequency 20',
    'mixtral-8x7b':
        '--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 '
        '--num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02 '
        '--deterministic --deterministic_seed 42 --check_frequency 20',
    'resnet101':
        '--batch_size 1 --precision float32 --num_warmup 1 --num_steps 120 --sample_count 8192 '
        '--pin_memory --model_action train --deterministic --deterministic_seed 42 --check_frequency 20',
    'lstm':
        '--batch_size 1 --num_steps 100 --num_warmup 1 --seq_len 64 --precision float16 '
        '--model_action train --deterministic --deterministic_seed 42 --check_frequency 20',
}


def main():
    """Main function for the determinism example file."""
    parser = argparse.ArgumentParser(description='Unified PyTorch deterministic training example.')
    parser.add_argument('--model', type=str, choices=MODEL_CHOICES, required=True, help='Model to run.')
    parser.add_argument(
        '--generate-log',
        nargs='?',
        const=True,
        default=None,
        help='Enable fingerprint log generation. Optionally specify a path to save the log.',
    )
    parser.add_argument(
        '--compare-log',
        type=str,
        default=None,
        help='Path to reference fingerprint log for comparison.',
    )
    parser.add_argument(
        '--deterministic-seed',
        type=int,
        default=42,
        help='Seed for deterministic training.',
    )
    args = parser.parse_args()

    parameters = DEFAULT_PARAMS[args.model]
    if args.deterministic_seed:
        parameters += f' --deterministic_seed {args.deterministic_seed}'
    if args.generate_log:
        parameters += ' --generate-log'
        if isinstance(args.generate_log, str):
            parameters += f' {args.generate_log}'
    if args.compare_log:
        parameters += f' --compare-log {args.compare_log}'

    context = BenchmarkRegistry.create_benchmark_context(args.model, parameters=parameters, framework=Framework.PYTORCH)
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    logger.info(f'Benchmark finished. Return code: {benchmark.return_code}')
    if hasattr(benchmark, '_model_run_metadata'):
        logger.info(f'Run metadata: {benchmark._model_run_metadata}')
    if hasattr(benchmark, '_model_run_losses'):
        logger.info(f'Losses: {benchmark._model_run_losses[:5]} ...')
    if hasattr(benchmark, '_model_run_periodic'):
        logger.info(f'Periodic: {benchmark._model_run_periodic}')


if __name__ == '__main__':
    main()
```
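One subtlety in the example above: `DEFAULT_PARAMS` already embeds `--deterministic_seed 42`, and the script unconditionally appends the user's value, so the behavior relies on argparse keeping the last occurrence of a repeated option. A minimal sketch of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--deterministic_seed', type=int, default=42)

# When the same option appears twice, argparse keeps the last value,
# so the appended user-supplied seed overrides the embedded default.
args = parser.parse_args('--deterministic_seed 42 --deterministic_seed 7'.split())
assert args.deterministic_seed == 7
```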
```diff
@@ -110,14 +110,95 @@ def parse_args(self, ignore_invalid=False):
```

```python
            logger.error('Invalid argument - benchmark: {}, message: {}.'.format(self._name, str(e)))
            return False, None, []

        if args is not None and 'compare_log' in [a.dest for a in self._parser._actions]:
            args = self._parse_args_override_step(args)

        ret = True
        ret = self._check_unknown_args(unknown)
```

Comment on lines 113 to +117

Member: the first `ret = True` line is unnecessary, since it is immediately overwritten.

```python
        return ret, args, unknown
```
```python
    def _parse_args_override_step(self, args):
        """Override arguments using metadata from a compare log file.

        Args:
            args: Parsed arguments.

        Returns:
            argparse.Namespace: Updated arguments with overridden values.
        """
        return self._override_args_with_compare_log(args)
```
```python
    def _override_args_with_compare_log(self, args):
        """Override arguments with metadata from a compare log file if available.

        Args:
            args: Parsed arguments.

        Returns:
            argparse.Namespace: Arguments updated with metadata values.
        """
        # Only override if compare_log is set and is a valid argument for this benchmark.
        if args is not None and hasattr(args, 'compare_log') and getattr(args, 'compare_log', None):
            logger.info(f'Original arguments before overriding from compare_log metadata for determinism: {args}')
            try:
                from superbench.common import model_log_utils
                log_data = model_log_utils.load_model_log(args.compare_log)
                metadata = log_data.get('metadata', {})
                try:
                    from superbench.benchmarks import Precision
                except ImportError:
                    Precision = None
                for key, value in metadata.items():
                    if hasattr(args, key):
                        if key == 'precision' and Precision is not None:
                            setattr(args, key, self._convert_precision_value(value, Precision))
                        else:
                            setattr(args, key, value)
                logger.info(f'Arguments overridden from compare_log metadata for determinism. New arguments: {args}')
            except Exception as e:
                logger.info(f'Failed to override args from compare_log metadata: {e}')
        return args
```
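The metadata override boils down to copying matching keys onto the parsed namespace; a standalone sketch (the `override_from_metadata` helper is hypothetical, not the SuperBench method):

```python
import argparse


def override_from_metadata(args, metadata):
    # Copy only keys that already exist on the namespace, mirroring the
    # hasattr() guard used when applying compare-log metadata.
    for key, value in metadata.items():
        if hasattr(args, key):
            setattr(args, key, value)
    return args


args = argparse.Namespace(batch_size=1, seq_len=64, precision='float32')
override_from_metadata(args, {'seq_len': 512, 'unknown_key': 'ignored'})
assert args.seq_len == 512
assert args.batch_size == 1
assert not hasattr(args, 'unknown_key')
```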
```python
    def _convert_precision_value(self, value, Precision):
        """Convert precision values to the appropriate format.

        Args:
            value: The precision value to convert.
            Precision: The Precision class or type to convert to.

        Returns:
            list: A list of converted precision values.
        """
        if isinstance(value, list):
            converted = []
            for v in value:
                if isinstance(v, Precision):
                    converted.append(v)
                else:
                    converted.append(Precision(v))
            return converted
        else:
            if isinstance(value, Precision):
                return [value]
            else:
                return [Precision(value)]
```
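The conversion normalizes a scalar or list of strings/enum members into a list of enum members. A self-contained sketch of the same idea, with a stand-in `Precision` enum (the real one lives in `superbench.benchmarks`):

```python
from enum import Enum


class Precision(Enum):
    # Hypothetical stand-in for superbench.benchmarks.Precision.
    FLOAT32 = 'float32'
    FLOAT16 = 'float16'


def to_precision_list(value):
    # Wrap scalars in a list, then convert any raw strings via the enum's
    # value-based constructor, leaving existing members untouched.
    values = value if isinstance(value, list) else [value]
    return [v if isinstance(v, Precision) else Precision(v) for v in values]


assert to_precision_list('float32') == [Precision.FLOAT32]
assert to_precision_list([Precision.FLOAT16, 'float32']) == [Precision.FLOAT16, Precision.FLOAT32]
```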
```python
    def _check_unknown_args(self, unknown):
        """Check for unknown arguments and log an error if any are found.

        Args:
            unknown (list): List of unknown arguments.

        Returns:
            bool: False if unknown arguments are found, True otherwise.
        """
        if len(unknown) > 0:
            logger.error(
                'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
            )
            return False
        return True
```
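`_check_unknown_args` inspects the leftover tokens that `argparse.ArgumentParser.parse_known_args` returns alongside the namespace; a minimal illustration of where that list comes from:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--deterministic', action='store_true')

# parse_known_args() returns (namespace, unknown); any tokens the parser
# does not recognize end up in the unknown list instead of raising.
args, unknown = parser.parse_known_args(['--deterministic', '--typo_flag', '1'])
assert args.deterministic is True
assert unknown == ['--typo_flag', '1']
```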
```python
    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.
```

Comment: The acronym 'SDC' is not defined. Consider either defining it on first use (e.g., 'Silent Data Corruption (SDC)') or using a more descriptive term like 'Deterministic Training' to match the PR description.