From 1d062b46a328945b60004799a6b14a9fad9c33f1 Mon Sep 17 00:00:00 2001
From: LIN Xiaoyu <xiaoyu.lin@inria.fr>
Date: Wed, 17 Aug 2022 09:04:52 +0000
Subject: [PATCH] Update save_model.py

---
 save_model.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/save_model.py b/save_model.py
index 88f34f6..11abe40 100644
--- a/save_model.py
+++ b/save_model.py
@@ -116,7 +116,7 @@ class SaveLog():
         model_save_path = os.path.join(dvae_models_save_path, 'model_batch{}.pt'.format(batch_idx))
         torch.save(dvae_model.state_dict(), model_save_path)
 
-    def save_model(self, epoch, epoch_iter, total_steps, model, iter_file_path, end_of_epoch=False, save_best=False):
+    def save_model(self, epoch, epoch_iter, total_steps, model_state_dict, iter_file_path, end_of_epoch=False, save_best=False):
         save_latest_freq = self.cfg.getint('Training', 'save_latest_freq')
         save_epoch_freq = self.cfg.getint('Training', 'save_epoch_freq')
         save_models_file = os.path.join(self.save_dir, 'models')
@@ -126,20 +126,20 @@ class SaveLog():
             if total_steps % save_latest_freq == 0:
                 print('Saving the latest model epoch %d, total_steps %d' % (epoch, total_steps))
                 save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt')
-                torch.save(model.state_dict(), save_latest_file)
+                torch.save(model_state_dict, save_latest_file)
                 np.savetxt(iter_file_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
 
         else:
             if save_best:
                 print('Saving the model with best validation loss at epoch %d, total_steps %d' % (epoch, total_steps))
                 save_epoch_file = os.path.join(save_models_file, 'model_best.pt')
-                torch.save(model.state_dict(), save_epoch_file)
+                torch.save(model_state_dict, save_epoch_file)
             if epoch % save_epoch_freq == 0:
                 print('Saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps))
                 save_latest_file = os.path.join(save_models_file, 'model_epoch_latest.pt')
-                torch.save(model.state_dict(), save_latest_file)
+                torch.save(model_state_dict, save_latest_file)
                 save_epoch_file = os.path.join(save_models_file, 'model_epoch_%s.pt' % epoch)
-                torch.save(model.state_dict(), save_epoch_file)
+                torch.save(model_state_dict, save_epoch_file)
                 np.savetxt(iter_file_path, (epoch+1, 0), delimiter=',', fmt='%d')
 
     def save_evaluation(self, summary_list, mota_list, total_iter):
-- 
GitLab