diff --git a/train_dvae_single.py b/train_dvae_single.py
index 30e5e3eb3b89a411c55af748d0123ff029f9417d..f5a5582566463a896821281ae476470705d37a0a 100644
--- a/train_dvae_single.py
+++ b/train_dvae_single.py
@@ -89,7 +89,7 @@ def train(cfg_file):
     for info in data_info:
         print('%s' % info)
 
-    # Initialize training parameters
+    # Initialize training parameters 
     n_epochs, early_stop_patience, \
            total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader)
     
diff --git a/utils.py b/utils.py
index dbcfb03d273281ba1d93a753c824266b8d5382a0..8958a0626c7eb26e3496233704e504c9ee9a6f67 100644
--- a/utils.py
+++ b/utils.py
@@ -81,7 +81,13 @@ def create_dvae_model(cfg, device, save_dir):
             else:
                 model.load_state_dict(torch.load(save_path, map_location=device))
         else:
-            print('No epoch specified, model will be trained from the begining.')
+            print('No epoch specified, model will be trained from the recorded latest epoch.')
+            save_filename = 'models/model_epoch_latest.pt'
+            save_path = os.path.join(save_dir, save_filename)
+            if not os.path.isfile(save_path):
+                raise ValueError('%s not exits!' % save_path)
+            else:
+                model.load_state_dict(torch.load(save_path, map_location=device))
 
     return model.to(device)
 
@@ -94,9 +100,13 @@ def init_training_params(cfg, save_dir, train_data_loader):
     start_epoch, epoch_iter = 1, 0
     continue_train = cfg.getboolean('Training', 'continue_train')
     if continue_train:
-        if os.path.exists(iter_file_path):
-            start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int)
-        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+        which_epoch = cfg.get('Training', 'which_epoch')
+        if which_epoch is not None:
+            print('Resuming from epoch %s' % which_epoch)
+        else:
+            if os.path.exists(iter_file_path):
+                start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int)
+            print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
 
     total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter