From 1d062b46a328945b60004799a6b14a9fad9c33f1 Mon Sep 17 00:00:00 2001 From: LIN Xiaoyu <xiaoyu.lin@inria.fr> Date: Wed, 17 Aug 2022 09:04:52 +0000 Subject: [PATCH] Update save_model.py --- save_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/save_model.py b/save_model.py index 88f34f6..11abe40 100644 --- a/save_model.py +++ b/save_model.py @@ -116,7 +116,7 @@ class SaveLog(): model_save_path = os.path.join(dvae_models_save_path, 'model_batch{}.pt'.format(batch_idx)) torch.save(dvae_model.state_dict(), model_save_path) - def save_model(self, epoch, epoch_iter, total_steps, model, iter_file_path, end_of_epoch=False, save_best=False): + def save_model(self, epoch, epoch_iter, total_steps, model_state_dict, iter_file_path, end_of_epoch=False, save_best=False): save_latest_freq = self.cfg.getint('Training', 'save_latest_freq') save_epoch_freq = self.cfg.getint('Training', 'save_epoch_freq') save_models_file = os.path.join(self.save_dir, 'models') @@ -126,20 +126,20 @@ class SaveLog(): if total_steps % save_latest_freq == 0: print('Saving the latest model epoch %d, total_steps %d' % (epoch, total_steps)) save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt') - torch.save(model.state_dict(), save_latest_file) + torch.save(model_state_dict, save_latest_file) np.savetxt(iter_file_path, (epoch, epoch_iter), delimiter=',', fmt='%d') else: if save_best: print('Saving the model with best validation loss at epoch %d, total_steps %d' % (epoch, total_steps)) save_epoch_file = os.path.join(save_models_file, 'model_best.pt') - torch.save(model.state_dict(), save_epoch_file) + torch.save(model_state_dict, save_epoch_file) if epoch % save_epoch_freq == 0: print('Saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps)) save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt') - torch.save(model.state_dict(), save_latest_file) + torch.save(model_state_dict, save_latest_file) save_epoch_file = os.path.join(save_models_file, 'model_epoch_%s.pt' % epoch) - torch.save(model.state_dict(), save_epoch_file) + torch.save(model_state_dict, save_epoch_file) np.savetxt(iter_file_path, (epoch+1, 0), delimiter=',', fmt='%d') def save_evaluation(self, summary_list, mota_list, total_iter): -- GitLab