diff --git a/config/cfg_dvae_single.ini b/config/cfg_dvae_single.ini index b9a52ef5cb48c7ce3d4e4f004bc806c6078ced9f..e43f105e6d201f20a99f9a22ee0cd9b16f02b4ce 100644 --- a/config/cfg_dvae_single.ini +++ b/config/cfg_dvae_single.ini @@ -11,7 +11,7 @@ val_data_dir = data/synthetic_trajectories/val_data [Network] -name = SRNN_dvae_single +name = SRNN_dvae_single_test x_dim = 4 z_dim = 4 activation = tanh diff --git a/train_dvae_single.py b/train_dvae_single.py index f5a5582566463a896821281ae476470705d37a0a..365ca6b251d4c7d161a393ce5f808dd453211272 100644 --- a/train_dvae_single.py +++ b/train_dvae_single.py @@ -40,6 +40,11 @@ def train(cfg_file): cfg = ConfigParser() cfg.read(cfg_file) + # Set random seed + random_seed = cfg.getint('Training', 'random_seed') + torch.manual_seed(random_seed) + np.random.seed(random_seed) + # Create save log directory save_log = SaveLog(cfg) save_dir = save_log.save_dir @@ -91,10 +96,7 @@ def train(cfg_file): # Initialize training parameters n_epochs, early_stop_patience, \ - total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader) - - torch.manual_seed(random_seed) - np.random.seed(random_seed) + total_steps, start_epoch, epoch_iter, iter_file_path = init_training_params(cfg, save_dir, train_data_loader) # Start training print('Start training...') diff --git a/utils.py b/utils.py index 8958a0626c7eb26e3496233704e504c9ee9a6f67..bb2a320a27e6cec63dd300601ad236ea96944f5a 100644 --- a/utils.py +++ b/utils.py @@ -94,7 +94,6 @@ def create_dvae_model(cfg, device, save_dir): def init_training_params(cfg, save_dir, train_data_loader): n_epochs = cfg.getint('Training', 'n_epochs') early_stop_patience = cfg.getint('Training', 'early_stop_patience') - random_seed = cfg.getint('Training', 'random_seed') iter_file_path = os.path.join(save_dir, 'iter.txt') start_epoch, epoch_iter = 1, 0 @@ -111,7 +110,7 @@ def init_training_params(cfg, save_dir, train_data_loader): total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter return n_epochs, early_stop_patience, \ - total_steps, start_epoch, epoch_iter, iter_file_path, random_seed + total_steps, start_epoch, epoch_iter, iter_file_path def tracking_evaluation_onebatch(gt_seq, normalize_range, acc_list, eta_iter, x_mean_vem_iter): total_iter = eta_iter.shape[0]