diff --git a/train_dvae_single.py b/train_dvae_single.py index 30e5e3eb3b89a411c55af748d0123ff029f9417d..f5a5582566463a896821281ae476470705d37a0a 100644 --- a/train_dvae_single.py +++ b/train_dvae_single.py @@ -89,7 +89,7 @@ def train(cfg_file): for info in data_info: print('%s' % info) - # Initialize training parameters + # Initialize training parameters n_epochs, early_stop_patience, \ total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader) diff --git a/utils.py b/utils.py index dbcfb03d273281ba1d93a753c824266b8d5382a0..8958a0626c7eb26e3496233704e504c9ee9a6f67 100644 --- a/utils.py +++ b/utils.py @@ -81,7 +81,13 @@ def create_dvae_model(cfg, device, save_dir): else: model.load_state_dict(torch.load(save_path, map_location=device)) else: - print('No epoch specified, model will be trained from the begining.') + print('No epoch specified, model will be trained from the recorded latest epoch.') + save_filename = 'models/model_epoch_latest.pt' + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + raise ValueError('%s not exits!' % save_path) + else: + model.load_state_dict(torch.load(save_path, map_location=device)) return model.to(device) @@ -94,9 +100,13 @@ def init_training_params(cfg, save_dir, train_data_loader): start_epoch, epoch_iter = 1, 0 continue_train = cfg.getboolean('Training', 'continue_train') if continue_train: - if os.path.exists(iter_file_path): - start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int) - print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + which_epoch = cfg.get('Training', 'which_epoch') + if which_epoch is not None: + print('Resuming from epoch %s' % which_epoch) + else: + if os.path.exists(iter_file_path): + start_epoch, epoch_iter = np.loadtxt(iter_file_path , delimiter=',', dtype=int) + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter