Commit 906f1ee3 authored by LIN Xiaoyu

Update save_model.py

parent 11b4d1de
## DVAE-UMOT
## Copyright Inria
## Year 2022
## Contact : xiaoyu.lin@inria.fr
##
## DVAE-UMOT is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## DVAE-UMOT is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program, DVAE-UMOT. If not, see <http://www.gnu.org/licenses/> and the LICENSE file.

# DVAE-UMOT has code derived from
# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST.
# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA.

import os
import pickle

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


class SaveLog():
    """Creates the experiment output directory, a plain-text log and TensorBoard
    writers, and saves models, intermediate results and evaluation metrics."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.save_dir = self.create_save_directory()
        self.log_file = self.create_log_file()
        # One TensorBoard writer per phase: training, validation and DVAE-UMOT inference.
        self.tf_path = os.path.join(self.save_dir, 'tensorboard')
        self.tf_path_train = os.path.join(self.tf_path, 'training')
        self.tf_path_val = os.path.join(self.tf_path, 'val')
        self.tf_path_dvaeumot = os.path.join(self.tf_path, 'dvaeumot')
        self.summary_writer_training = SummaryWriter(self.tf_path_train)
        self.summary_writer_val = SummaryWriter(self.tf_path_val)
        self.summary_writer_dvaeumot = SummaryWriter(self.tf_path_dvaeumot)

    def create_save_directory(self):
        # Output directory: <save_root>/<dataset_name>_<model_name>
        save_root = self.cfg.get('User', 'save_root')
        model_name = self.cfg.get('Network', 'name')
        dataset_name = self.cfg.get('DataFrame', 'dataset_name')
        directory_name = '{}_{}'.format(dataset_name, model_name)
        save_dir = os.path.join(save_root, directory_name)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        return save_dir

    def save_config_file(self):
        save_path = os.path.join(self.save_dir, 'config.ini')
        with open(save_path, 'w') as configfile:
            self.cfg.write(configfile)

    def create_log_file(self):
        log_file = os.path.join(self.save_dir, 'log.txt')
        with open(log_file, "w") as f:
            f.write('Experiment Log\n')
        return log_file

    def print_info(self, info_list):
        with open(self.log_file, "a") as f:
            for info in info_list:
                f.write('%s\n' % info)

    def plot_current_training_loss(self, loss_dict, step):
        for k, v in loss_dict.items():
            self.summary_writer_training.add_scalar(k, v, step)

    def plot_current_val_loss(self, loss_dict, step):
        for k, v in loss_dict.items():
            self.summary_writer_val.add_scalar(k, v, step)

    def plot_dvaeumot_dvae_loss(self, loss_dict, step):
        for k, v in loss_dict.items():
            self.summary_writer_dvaeumot.add_scalar(k, v, step)

    def save_dvaeumot_results(self, batch_idx, results_list):
        # Dump the DVAE-UMOT results of one batch, one pickle file per tensor.
        self.results_save_path = os.path.join(self.save_dir, 'Results_VEM_initphidiag_batch{}'.format(batch_idx))
        results_file_name = ['x_mean_dvaeumot_iter.pkl', 'data_gt.pkl', 'data_obs.pkl']
        if not os.path.isdir(self.results_save_path):
            os.makedirs(self.results_save_path)
        for i in range(len(results_list)):
            with open(os.path.join(self.results_save_path, results_file_name[i]), 'wb') as file:
                pickle.dump(results_list[i].to('cpu'), file)

    def save_KF_results(self, batch_idx, results_list):
        self.results_save_path = os.path.join(self.save_dir, 'Results_VEM_initphidiag_batch{}'.format(batch_idx))
        results_file_name = ['Eta_iter.pkl', 'x_mean_dvaeumot_iter.pkl', 'x_var_dvaeumot_iter.pkl', 'Lambda_iter.pkl', 'data_obs.pkl', 'data_gt.pkl']
        if not os.path.isdir(self.results_save_path):
            os.makedirs(self.results_save_path)
        for i in range(len(results_list)):
            with open(os.path.join(self.results_save_path, results_file_name[i]), 'wb') as file:
                pickle.dump(results_list[i].to('cpu'), file)

    def save_dvaeumot_init_params(self, params_list, batch_idx):
        self.init_params_path = os.path.join(self.save_dir, 'InitParams_VEM_initphidiag_{}'.format(batch_idx))
        init_params_file_name = ['x_mean_dvaeumot_init.pkl', 'x_var_dvaeumot_init.pkl', 'x_sampled_init.pkl', 'Phi_init.pkl', 'Phi_inv_init.pkl', 'o.pkl']
        if not os.path.isdir(self.init_params_path):
            os.makedirs(self.init_params_path)
        for i in range(len(params_list)):
            with open(os.path.join(self.init_params_path, init_params_file_name[i]), 'wb') as file:
                pickle.dump(params_list[i].to('cpu'), file)

    def save_model_dvae(self, batch_idx, dvae_model):
        dvae_models_save_path = os.path.join(self.save_dir, 'DVAE_MODEL')
        if not os.path.isdir(dvae_models_save_path):
            os.makedirs(dvae_models_save_path)
        model_save_path = os.path.join(dvae_models_save_path, 'model_batch{}.pt'.format(batch_idx))
        torch.save(dvae_model.state_dict(), model_save_path)

    def save_model(self, epoch, epoch_iter, total_steps, model, iter_file_path, end_of_epoch=False, save_best=False):
        # Within an epoch: save a rolling 'latest' checkpoint every save_latest_freq steps.
        # At the end of an epoch: optionally save the best-validation model, plus a per-epoch
        # snapshot every save_epoch_freq epochs; iter_file_path records where to resume from.
        save_latest_freq = self.cfg.getint('Training', 'save_latest_freq')
        save_epoch_freq = self.cfg.getint('Training', 'save_epoch_freq')
        save_models_file = os.path.join(self.save_dir, 'models')
        if not os.path.isdir(save_models_file):
            os.makedirs(save_models_file)
        if not end_of_epoch:
            if total_steps % save_latest_freq == 0:
                print('Saving the latest model epoch %d, total_steps %d' % (epoch, total_steps))
                save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt')
                torch.save(model.state_dict(), save_latest_file)
                np.savetxt(iter_file_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
        else:
            if save_best:
                print('Saving the model with best validation loss at epoch %d, total_steps %d' % (epoch, total_steps))
                save_epoch_file = os.path.join(save_models_file, 'model_best.pt')
                torch.save(model.state_dict(), save_epoch_file)
            if epoch % save_epoch_freq == 0:
                print('Saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps))
                save_latest_file = os.path.join(self.save_dir, 'model_latest.pt')
                torch.save(model.state_dict(), save_latest_file)
                save_epoch_file = os.path.join(save_models_file, 'model_epoch_%s.pt' % epoch)
                torch.save(model.state_dict(), save_epoch_file)
            np.savetxt(iter_file_path, (epoch + 1, 0), delimiter=',', fmt='%d')

    def save_evaluation(self, summary_list, mota_list, total_iter):
        eval_path = os.path.join(self.save_dir, 'evaluation_metrics.txt')
        mota_path = os.path.join(self.save_dir, 'mota_list.txt')
        with open(eval_path, "w") as text_file:
            for iter_number in range(total_iter):
                text_file.write('#' * 20)
                text_file.write('Iteration {}'.format(iter_number))
                text_file.write('#' * 20)
                text_file.write('\n')
                text_file.write(summary_list[iter_number])
                text_file.write('\n')
        np.savetxt(mota_path, mota_list, delimiter=',')
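
For reference, a minimal usage sketch of SaveLog, assuming a configparser.ConfigParser whose sections and keys mirror the cfg.get()/cfg.getint() calls above; the concrete values (save_root, network name, dataset name, save frequencies) are illustrative placeholders rather than the project's actual defaults:

import configparser

from save_model import SaveLog

# Hypothetical configuration: the section/key names follow the calls made in SaveLog,
# but the values below are placeholders.
cfg = configparser.ConfigParser()
cfg['User'] = {'save_root': './saved_models'}
cfg['Network'] = {'name': 'SRNN'}
cfg['DataFrame'] = {'dataset_name': 'MOT17'}
cfg['Training'] = {'save_latest_freq': '1000', 'save_epoch_freq': '10'}

save_log = SaveLog(cfg)                  # creates ./saved_models/MOT17_SRNN, log.txt and TensorBoard writers
save_log.save_config_file()              # copies the configuration to config.ini in the save directory
save_log.print_info(['Start training'])
save_log.plot_current_training_loss({'loss_tot': 1.23}, step=0)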