From 01e57dd019144520fd7599994bcbfb887a36196f Mon Sep 17 00:00:00 2001 From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr> Date: Wed, 17 Aug 2022 13:53:00 +0200 Subject: [PATCH 1/2] update --- train_dvae_single.py | 2 +- utils.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/train_dvae_single.py b/train_dvae_single.py index 30e5e3e..f5a5582 100644 --- a/train_dvae_single.py +++ b/train_dvae_single.py @@ -89,7 +89,7 @@ def train(cfg_file): for info in data_info: print('%s' % info) - # Initialize training parameters + # Initialize training parameters n_epochs, early_stop_patience, \ total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader) diff --git a/utils.py b/utils.py index dbcfb03..86e1e03 100644 --- a/utils.py +++ b/utils.py @@ -81,7 +81,13 @@ def create_dvae_model(cfg, device, save_dir): else: model.load_state_dict(torch.load(save_path, map_location=device)) else: - print('No epoch specified, model will be trained from the begining.') + print('No epoch specified, model will be trained from the recorded latest epoch.') + save_filename = 'models/model_epoch_latest.pt' + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + raise ValueError('%s not exits!' % save_path) + else: + model.load_state_dict(torch.load(save_path, map_location=device)) return model.to(device) @@ -94,9 +100,13 @@ def init_training_params(cfg, save_dir, train_data_loader): start_epoch, epoch_iter = 1, 0 continue_train = cfg.getboolean('Training', 'continue_train') if continue_train: - if os.path.exists(iter_file_path): - start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int) - print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + which_epoch = cfg.get('Training', 'which_epoch') + if which_epoch is not None: + print('Resuming from epoch %d' % which_epoch) + else: + if os.path.exists(iter_file_path): + start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int) + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter -- GitLab From 65f58f873130db256eff8dd0827907f2b1247501 Mon Sep 17 00:00:00 2001 From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr> Date: Wed, 17 Aug 2022 13:58:50 +0200 Subject: [PATCH 2/2] update --- utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.py b/utils.py index 86e1e03..8958a06 100644 --- a/utils.py +++ b/utils.py @@ -102,7 +102,7 @@ def init_training_params(cfg, save_dir, train_data_loader): if continue_train: which_epoch = cfg.get('Training', 'which_epoch') if which_epoch is not None: - print('Resuming from epoch %d' % which_epoch) + print('Resuming from epoch %s' % which_epoch) else: if os.path.exists(iter_file_path): start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int) -- GitLab