Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 1d062b46 authored by LIN Xiaoyu's avatar LIN Xiaoyu
Browse files

Update save_model.py

parent 70aba931
Branches
No related tags found
No related merge requests found
...@@ -116,7 +116,7 @@ class SaveLog(): ...@@ -116,7 +116,7 @@ class SaveLog():
model_save_path = os.path.join(dvae_models_save_path, 'model_batch{}.pt'.format(batch_idx)) model_save_path = os.path.join(dvae_models_save_path, 'model_batch{}.pt'.format(batch_idx))
torch.save(dvae_model.state_dict(), model_save_path) torch.save(dvae_model.state_dict(), model_save_path)
def save_model(self, epoch, epoch_iter, total_steps, model, iter_file_path, end_of_epoch=False, save_best=False): def save_model(self, epoch, epoch_iter, total_steps, model_state_dict, iter_file_path, end_of_epoch=False, save_best=False):
save_latest_freq = self.cfg.getint('Training', 'save_latest_freq') save_latest_freq = self.cfg.getint('Training', 'save_latest_freq')
save_epoch_freq = self.cfg.getint('Training', 'save_epoch_freq') save_epoch_freq = self.cfg.getint('Training', 'save_epoch_freq')
save_models_file = os.path.join(self.save_dir, 'models') save_models_file = os.path.join(self.save_dir, 'models')
...@@ -126,20 +126,20 @@ class SaveLog(): ...@@ -126,20 +126,20 @@ class SaveLog():
if total_steps % save_latest_freq == 0: if total_steps % save_latest_freq == 0:
print('Saving the latest model epoch %d, total_steps %d' % (epoch, total_steps)) print('Saving the latest model epoch %d, total_steps %d' % (epoch, total_steps))
save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt') save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt')
torch.save(model.state_dict(), save_latest_file) torch.save(model_state_dict, save_latest_file)
np.savetxt(iter_file_path, (epoch, epoch_iter), delimiter=',', fmt='%d') np.savetxt(iter_file_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
else: else:
if save_best: if save_best:
print('Saving the model with best validation loss at epoch %d, total_steps %d' % (epoch, total_steps)) print('Saving the model with best validation loss at epoch %d, total_steps %d' % (epoch, total_steps))
save_epoch_file = os.path.join(save_models_file, 'model_best.pt') save_epoch_file = os.path.join(save_models_file, 'model_best.pt')
torch.save(model.state_dict(), save_epoch_file) torch.save(model_state_dict, save_epoch_file)
if epoch % save_epoch_freq == 0: if epoch % save_epoch_freq == 0:
print('Saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps)) print('Saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps))
save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt') save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt')
torch.save(model.state_dict(), save_latest_file) torch.save(model_state_dict, save_latest_file)
save_epoch_file = os.path.join(save_models_file, 'model_epoch_%s.pt' % epoch) save_epoch_file = os.path.join(save_models_file, 'model_epoch_%s.pt' % epoch)
torch.save(model.state_dict(), save_epoch_file) torch.save(model_state_dict, save_epoch_file)
np.savetxt(iter_file_path, (epoch+1, 0), delimiter=',', fmt='%d') np.savetxt(iter_file_path, (epoch+1, 0), delimiter=',', fmt='%d')
def save_evaluation(self, summary_list, mota_list, total_iter): def save_evaluation(self, summary_list, mota_list, total_iter):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment