-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
55 lines (38 loc) · 1.32 KB
/
train.py
File metadata and controls
55 lines (38 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models.customnet import CustomNet
from models.loss import JointsMSELoss
from utils.dataset import Dataset
from utils.trainer import Trainer
from IPython import embed
def main(config):
    """Build the model, data pipelines, optimizer and loss, then run training.

    Args:
        config: dict with keys 'device', 'lr', 'n_epoch', 'batch_size' and
            'dir_checkpoint' (consumed here and inside Trainer).
    """
    model = CustomNet()
    # NOTE(review): the model is never moved to config['device'] here —
    # presumably Trainer does that; confirm, otherwise add model.to(...).
    print('[Log] Preparing training\n')
    train_dataset = Dataset(task='train')
    val_dataset = Dataset(task='val')
    # Fix: shuffle the training set each epoch; without shuffle=True the
    # model sees samples in identical dataset order every epoch.
    # Validation keeps a fixed order for comparable metrics.
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config['batch_size'],
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config['batch_size'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = JointsMSELoss(use_target_weight=False)
    trainer = Trainer(config, model, train_dataloader, val_dataloader,
                      optimizer, criterion)
    trainer.train()
if __name__ == "__main__":
    # Seed every RNG and force deterministic cuDNN kernels so that
    # training runs are repeatable.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(0)
    random.seed(0)

    # Prefer the GPU when one is present, otherwise fall back to CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Training hyper-parameters; checkpoints are written under dir_checkpoint.
    config = {
        'device': device,
        'lr': 0.001,
        'n_epoch': 100,
        'batch_size': 4,
        'dir_checkpoint': './checkpoints/w_512_h_512_e_100',
    }
    main(config)