Dear Authors,
Thank you for maintaining an excellent repository and keeping it intuitive for learners.
I wanted to ask whether there is a better way to compute validation loss during training. My current approach, shown below, takes inspiration from facebookresearch/detectron2#810 (comment) ("How do I compute validation loss during training?").
Here is my modified `do_train` (the `ValidationLoss` hook it registers is defined further below):

```python
def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # instantiate optimizer
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    # build training loader
    train_loader = instantiate(cfg.dataloader.train)

    # create ddp model
    model = create_ddp_model(model, **cfg.train.ddp)

    # build model ema
    ema.may_build_model_ema(cfg, model)

    trainer = Trainer(
        model=model,
        dataloader=train_loader,
        optimizer=optim,
        amp=cfg.train.amp.enabled,
        clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
    )

    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
        # save model ema
        **ema.may_get_ema_checkpointer(cfg, model),
    )

    if comm.is_main_process():
        # writers = default_writers(cfg.train.output_dir, cfg.train.max_iter)
        output_dir = cfg.train.output_dir
        PathManager.mkdirs(output_dir)
        writers = [
            CommonMetricPrinter(cfg.train.max_iter),
            JSONWriter(os.path.join(output_dir, "metrics.json")),
            TensorboardXWriter(output_dir),
        ]
        if cfg.train.wandb.enabled:
            PathManager.mkdirs(cfg.train.wandb.params.dir)
            writers.append(WandbWriter(cfg))

    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            ema.EMAHook(cfg, model) if cfg.train.model_ema.enabled else None,
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                writers,
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    # register the validation loss hook
    val_loss = ValidationLoss(cfg)
    trainer.register_hooks([val_loss])
    # swap the order of PeriodicWriter and ValidationLoss so the writer runs last
    # and picks up the validation scalars in the same iteration they are computed
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
```
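For reference, I call this `do_train` from a standard lazy-config entry point. A minimal sketch, assuming detectron2's `default_argument_parser`, `default_setup`, `launch`, and `LazyConfig` helpers (nothing here is specific to the hook, and the exact entry point may differ from this repo's training script):

```python
# Minimal entry-point sketch (assumption: mirrors detectron2's lazy-config
# training script; adapt to however this repo's tools/train_net.py is set up).
from detectron2.config import LazyConfig
from detectron2.engine import default_argument_parser, default_setup, launch


def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
```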
And here is the `ValidationLoss` hook itself:

```python
import copy

import torch
from torch.cuda.amp import autocast

from detectron2.config import instantiate
from detectron2.engine import HookBase


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = copy.deepcopy(cfg)
        # Run the test-set mapper in training mode so that ground-truth instances
        # are loaded and the model returns losses instead of predictions.
        # (self.cfg.dataloader.test has keys: dataset, mapper, num_workers, _target_)
        self.cfg.dataloader.test.mapper.is_train = True
        self.data_loader = instantiate(self.cfg.dataloader.test)

    def after_step(self):
        # evaluate every 25% of eval_period
        if self.trainer.iter % (self.cfg.train.eval_period // 4) == 0:
            self._compute_validation_loss()

    def _compute_validation_loss(self):
        assert self.trainer.model.training, "[Trainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!"

        total_loss = 0
        num_batches = 0
        loss_sums = {}  # accumulate each loss component across the validation set
        with torch.no_grad():
            for data in self.data_loader:
                # `data` contains instances because mapper.is_train is True
                with autocast(enabled=self.trainer.amp):
                    loss_dict = self.trainer.model(data)
                    if isinstance(loss_dict, torch.Tensor):
                        losses = loss_dict
                        loss_dict = {"total_loss": loss_dict}
                    else:
                        losses = sum(loss_dict.values())

                # accumulate each loss component separately
                for key, loss in loss_dict.items():
                    loss_sums[key] = loss_sums.get(key, 0) + loss.item()
                total_loss += losses.item()
                num_batches += 1

        # compute dataset-wide average losses and store them as scalars
        avg_losses = {key: value / num_batches for key, value in loss_sums.items()} if num_batches > 0 else {}
        for key, avg_loss in avg_losses.items():
            self.trainer.storage.put_scalar(f"validation_{key}", avg_loss)

        avg_val_loss = total_loss / num_batches if num_batches > 0 else 0
        self.trainer.storage.put_scalar("validation_loss", avg_val_loss)
        print(f"Validation Loss Breakdown: {avg_losses}")  # optional logging
```
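With this in place, the `validation_loss` and per-component `validation_*` scalars show up in TensorBoard and in `metrics.json` alongside the training losses. A small sketch for pulling them back out of `metrics.json` (written one JSON object per line by `JSONWriter`; `./output` is an assumption for `cfg.train.output_dir`):

```python
import json
import os

output_dir = "./output"  # assumption: whatever cfg.train.output_dir points to
with open(os.path.join(output_dir, "metrics.json")) as f:
    for line in f:
        record = json.loads(line)
        if "validation_loss" in record:
            print(record.get("iteration"), record["validation_loss"])
```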
Please consider adding validation loss evaluation as a built-in feature.
Thank you!