## train_dvae_single.py
## Author: xiaoyu lin
## DVAE-UMOT
## Copyright Inria
## Year 2022
## Contact : xiaoyu.lin@inria.fr
## DVAE-UMOT is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
## DVAE-UMOT is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program, DVAE-UMOT. If not, see <http://www.gnu.org/licenses/> and the LICENSE file.
# DVAE-UMOT has code derived from
# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST.
# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA.
import datetime
import os
from random import random
import shutil
import sys
from configparser import ConfigParser
import numpy as np
import torch
from save_model import SaveLog
from utils import get_basic_info, get_loss_info, initialize_optimizer, create_dvae_model, init_training_params
from data.data_loader import create_dataloader
def _log_lines(save_log, lines):
    """Write ``lines`` to the save log and echo each one on stdout."""
    save_log.print_info(lines)
    for line in lines:
        print('%s' % line)


def train(cfg_file):
    """Train a DVAE model as described by an INI configuration file.

    The function reads the config, seeds torch and numpy, builds the model,
    optimizer and data loaders through the project helpers, then runs the
    epoch loop. When the ``validation`` option is enabled, a validation pass
    is run after every epoch and training stops once the validation loss has
    failed to improve ``early_stop_patience`` times in a row. Checkpoints are
    written through ``SaveLog.save_model``.

    Args:
        cfg_file: Path to the config file. This function itself reads the
            ``random_seed``, ``use_cuda`` and ``validation`` options of the
            ``Training`` section; the whole config is also handed to the
            project helpers.

    Raises:
        ValueError: If ``cfg_file`` does not point to an existing file.
    """
    # Local import: only needed for snapshotting the best weights below.
    import copy

    # Read the config file
    if not os.path.isfile(cfg_file):
        raise ValueError('Invalid config file path')
    cfg = ConfigParser()
    cfg.read(cfg_file)

    # Set random seed
    random_seed = cfg.getint('Training', 'random_seed')
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Create save log directory
    save_log = SaveLog(cfg)
    save_dir = save_log.save_dir

    # Keep a copy of the config next to the logs/checkpoints
    save_cfg_path = os.path.join(save_dir, 'config.ini')
    shutil.copy(cfg_file, save_cfg_path)

    # Print basic information
    use_cuda = cfg.getboolean('Training', 'use_cuda')
    device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
    _log_lines(save_log, get_basic_info(device))

    # Create and initialize model, and put it in training mode
    model = create_dvae_model(cfg, device, save_dir)
    model.train()
    # NOTE(review): anomaly detection is a debugging aid that slows down the
    # backward pass; it is left enabled to preserve the existing behaviour.
    torch.autograd.set_detect_anomaly(True)

    # Print model information
    _log_lines(save_log, model.get_info())

    # Create optimizer
    optimizer = initialize_optimizer(cfg, model)

    # Load data
    train_data_loader, train_data_size = create_dataloader(cfg, data_type='dvae_train')
    validation = cfg.getboolean('Training', 'validation')
    if validation:
        val_data_loader, val_data_size = create_dataloader(cfg, data_type='dvae_val')

    # Print data information
    data_info = [
        '========== DATA INFO ==========',
        'Training data: %s' % train_data_size,
    ]
    if validation:
        data_info.append('Validation data: %s' % val_data_size)
    _log_lines(save_log, data_info)

    # Initialize training parameters
    n_epochs, early_stop_patience, \
        total_steps, start_epoch, epoch_iter, iter_file_path = init_training_params(cfg, save_dir, train_data_loader)

    # Start training
    print('Start training...')
    best_val_loss = np.inf
    cpt_patience = 0
    cur_best_epoch = n_epochs
    # Snapshot of the best weights seen on validation. Stays None when no
    # snapshot was taken (validation disabled, or no epoch improved), in
    # which case the final weights are saved as "best" at the end.
    best_state_dict = None

    for epoch in range(start_epoch, n_epochs):
        epoch_start_time = datetime.datetime.now()
        training_total_loss = 0
        training_recon_loss = 0
        training_KLD_loss = 0

        for data in train_data_loader:
            batch_size = data.shape[0]
            total_steps += batch_size
            epoch_iter += batch_size

            data = data.to(device)
            # The forward pass stores the losses on the model; the returned
            # reconstruction statistics are not used here.
            model(data, compute_loss=True)
            loss_dict = model.loss

            optimizer.zero_grad()
            loss_dict['loss_tot'].backward()
            optimizer.step()

            # Accumulate without autograd tracking so the running sums do not
            # keep every batch's graph alive until the end of the epoch.
            with torch.no_grad():
                training_total_loss += loss_dict['loss_tot'] * batch_size
                training_recon_loss += loss_dict['loss_recon'] * batch_size
                training_KLD_loss += loss_dict['loss_KLD'] * batch_size

            # Save latest model
            # NOTE(review): assumed to run once per batch (end_of_epoch=False),
            # with any save-frequency check done inside save_model -- confirm.
            save_log.save_model(epoch, epoch_iter, total_steps, model.state_dict(),
                                iter_file_path, end_of_epoch=False, save_best=False)

        # Per-sample averages over the epoch
        training_total_loss = training_total_loss / train_data_size
        training_recon_loss = training_recon_loss / train_data_size
        training_KLD_loss = training_KLD_loss / train_data_size

        # Display training loss
        # NOTE(review): this plots the losses of the last batch, not the epoch
        # averages computed just above -- confirm that this is intended.
        save_log.plot_current_training_loss(loss_dict, total_steps)

        # Validation
        if validation:
            val_total_loss = 0
            val_recon_loss = 0
            val_KLD_loss = 0
            with torch.no_grad():
                for val_data in val_data_loader:
                    val_batch_size = val_data.shape[0]
                    val_data = val_data.to(device)
                    model(val_data, compute_loss=True)
                    loss_dict_val = model.loss
                    val_total_loss += loss_dict_val['loss_tot'] * val_batch_size
                    val_recon_loss += loss_dict_val['loss_recon'] * val_batch_size
                    val_KLD_loss += loss_dict_val['loss_KLD'] * val_batch_size

            val_total_loss = val_total_loss / val_data_size
            val_recon_loss = val_recon_loss / val_data_size
            val_KLD_loss = val_KLD_loss / val_data_size
            avg_val_loss_dict = {'loss_tot': val_total_loss, 'loss_recon': val_recon_loss, 'loss_KLD': val_KLD_loss}
            save_log.plot_current_val_loss(avg_val_loss_dict, total_steps)
            torch.cuda.empty_cache()

            # Early stop patience
            if val_total_loss < best_val_loss:
                best_val_loss = val_total_loss
                cpt_patience = 0
                # state_dict() hands out references to the live parameters, so
                # it must be deep-copied; otherwise later optimizer steps would
                # overwrite the "best" weights in place.
                best_state_dict = copy.deepcopy(model.state_dict())
                cur_best_epoch = epoch
            else:
                cpt_patience += 1

        # End of epoch
        epoch_end_time = datetime.datetime.now()
        # total_seconds() also counts whole days, unlike the .seconds field.
        iter_time = (epoch_end_time - epoch_start_time).total_seconds() / 60
        training_info = 'End of epoch {} \t training time: {:.2f}m \t training loss {:.4f} \t'\
            .format(epoch, iter_time, training_total_loss)
        if validation:
            training_info += 'val loss {:.4f}'.format(val_total_loss)
        _log_lines(save_log, [training_info])

        # Stop training if early-stop triggers. Equality is kept on purpose:
        # a negative patience value never matches and so never stops early.
        if validation and cpt_patience == early_stop_patience:
            _log_lines(save_log, ['Early stop patience achieved'])
            break

        # Save model for this epoch
        save_log.save_model(epoch, epoch_iter, total_steps, model.state_dict(),
                            iter_file_path, end_of_epoch=True, save_best=False)

    # Save the final weights of network with the best validation loss
    if best_state_dict is None:
        best_state_dict = model.state_dict()
    save_log.save_model(cur_best_epoch, epoch_iter, total_steps, best_state_dict,
                        iter_file_path, end_of_epoch=True, save_best=True)
if __name__ == '__main__':
    # Exactly one command-line argument is expected: the config file path.
    if len(sys.argv) != 2:
        # Exit with a non-zero status (message goes to stderr) so that shells
        # and job schedulers can tell the run did not start.
        raise SystemExit('Error: Please indicate config file path')
    train(sys.argv[1])