diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 0dffe467..b7c1b4fb 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -22,7 +22,7 @@ from init2winit.init_lib import initializers from init2winit.model_lib import models from init2winit.trainer_lib import training_algorithm -from init2winit.trainer_lib import training_algorithms +from init2winit.trainer_lib import training_algorithms_registry from ml_collections.config_dict import config_dict @@ -138,7 +138,9 @@ def build_hparams( overrides_dict = json.loads(overrides_dict) # Training hparams come from the training algorithm. - algo_cls = training_algorithms.get_training_algorithm(training_algorithm_name) + algo_cls = training_algorithms_registry.get_training_algorithm( + training_algorithm_name + ) # For OptaxTrainingAlgorithm, pass optimizer_name (if overridden) and # model_name so it can resolve defaults using the 3-tier hierarchy. diff --git a/init2winit/main.py b/init2winit/main.py index e13451d7..c327e9c6 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -31,7 +31,7 @@ from init2winit.init_lib import initializers from init2winit.model_lib import models from init2winit.trainer_lib import trainers -from init2winit.trainer_lib import training_algorithms +from init2winit.trainer_lib import training_algorithms_registry import jax from jax.experimental import multihost_utils from ml_collections import config_flags @@ -157,8 +157,10 @@ def _run( num_device_prefetches=num_device_prefetches, num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls, )) - training_algorithm_class = training_algorithms.get_training_algorithm( - training_algorithm_name + training_algorithm_class = ( + training_algorithms_registry.get_training_algorithm( + training_algorithm_name + ) ) merged_hps = hyperparameters.build_hparams( diff --git a/init2winit/trainer_lib/training_algorithms.py b/init2winit/trainer_lib/training_algorithms_registry.py similarity index 99% rename from init2winit/trainer_lib/training_algorithms.py rename to init2winit/trainer_lib/training_algorithms_registry.py index 3c28a59e..40a539e0 100644 --- a/init2winit/trainer_lib/training_algorithms.py +++ b/init2winit/trainer_lib/training_algorithms_registry.py @@ -16,6 +16,7 @@ """Module for registering and retrieving training algorithms.""" from init2winit.trainer_lib import training_algorithm + # pylint: disable=g-bad-import-order @@ -34,4 +35,5 @@ def register_training_algorithm(name): def decorator(cls): _ALGORITHMS[name] = cls return cls + return decorator