From 01e57dd019144520fd7599994bcbfb887a36196f Mon Sep 17 00:00:00 2001
From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr>
Date: Wed, 17 Aug 2022 13:53:00 +0200
Subject: [PATCH 1/2] update

---
 train_dvae_single.py |  2 +-
 utils.py             | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/train_dvae_single.py b/train_dvae_single.py
index 30e5e3e..f5a5582 100644
--- a/train_dvae_single.py
+++ b/train_dvae_single.py
@@ -89,7 +89,7 @@ def train(cfg_file):
     for info in data_info:
         print('%s' % info)
 
-    # Initialize training parameters
+    # Initialize training parameters 
     n_epochs, early_stop_patience, \
            total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader)
     
diff --git a/utils.py b/utils.py
index dbcfb03..86e1e03 100644
--- a/utils.py
+++ b/utils.py
@@ -81,7 +81,13 @@ def create_dvae_model(cfg, device, save_dir):
             else:
                 model.load_state_dict(torch.load(save_path, map_location=device))
         else:
-            print('No epoch specified, model will be trained from the begining.')
+            print('No epoch specified, model will be trained from the recorded latest epoch.')
+            save_filename = 'models/model_epoch_latest.pt'
+            save_path = os.path.join(save_dir, save_filename)
+            if not os.path.isfile(save_path):
+                raise ValueError('%s not exits!' % save_path)
+            else:
+                model.load_state_dict(torch.load(save_path, map_location=device))
 
     return model.to(device)
 
@@ -94,9 +100,13 @@ def init_training_params(cfg, save_dir, train_data_loader):
     start_epoch, epoch_iter = 1, 0
     continue_train = cfg.getboolean('Training', 'continue_train')
     if continue_train:
-        if os.path.exists(iter_file_path):
-            start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int)
-        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+        which_epoch = cfg.get('Training', 'which_epoch')
+        if which_epoch is not None:
+            print('Resuming from epoch %d' % which_epoch)
+        else:
+            if os.path.exists(iter_file_path):
+                start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int)
+            print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
 
     total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter
 
-- 
GitLab


From 65f58f873130db256eff8dd0827907f2b1247501 Mon Sep 17 00:00:00 2001
From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr>
Date: Wed, 17 Aug 2022 13:58:50 +0200
Subject: [PATCH 2/2] update

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 86e1e03..8958a06 100644
--- a/utils.py
+++ b/utils.py
@@ -102,7 +102,7 @@ def init_training_params(cfg, save_dir, train_data_loader):
     if continue_train:
         which_epoch = cfg.get('Training', 'which_epoch')
         if which_epoch is not None:
-            print('Resuming from epoch %d' % which_epoch)
+            print('Resuming from epoch %s' % which_epoch)
         else:
             if os.path.exists(iter_file_path):
                 start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int)
-- 
GitLab