From eddc4068b186b71dc90d566de3628588e5fd4509 Mon Sep 17 00:00:00 2001 From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr> Date: Wed, 17 Aug 2022 14:16:38 +0200 Subject: [PATCH] update --- config/cfg_dvae_single.ini | 2 +- train_dvae_single.py | 10 ++++++---- utils.py | 3 +-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/config/cfg_dvae_single.ini b/config/cfg_dvae_single.ini index b9a52ef..e43f105 100644 --- a/config/cfg_dvae_single.ini +++ b/config/cfg_dvae_single.ini @@ -11,7 +11,7 @@ val_data_dir = data/synthetic_trajectories/val_data [Network] -name = SRNN_dvae_single +name = SRNN_dvae_single_test x_dim = 4 z_dim = 4 activation = tanh diff --git a/train_dvae_single.py b/train_dvae_single.py index f5a5582..365ca6b 100644 --- a/train_dvae_single.py +++ b/train_dvae_single.py @@ -40,6 +40,11 @@ def train(cfg_file): cfg = ConfigParser() cfg.read(cfg_file) + # Set random seed + random_seed = cfg.getint('Training', 'random_seed') + torch.manual_seed(random_seed) + np.random.seed(random_seed) + # Create save log directory save_log = SaveLog(cfg) save_dir = save_log.save_dir @@ -91,10 +96,7 @@ def train(cfg_file): # Initialize training parameters n_epochs, early_stop_patience, \ - total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader) - - torch.manual_seed(random_seed) - np.random.seed(random_seed) + total_steps, start_epoch, epoch_iter, iter_file_path = init_training_params(cfg, save_dir, train_data_loader) # Start training print('Start training...') diff --git a/utils.py b/utils.py index 8958a06..bb2a320 100644 --- a/utils.py +++ b/utils.py @@ -94,7 +94,6 @@ def create_dvae_model(cfg, device, save_dir): def init_training_params(cfg, save_dir, train_data_loader): n_epochs = cfg.getint('Training', 'n_epochs') early_stop_patience = cfg.getint('Training', 'early_stop_patience') - random_seed = cfg.getint('Training', 'random_seed') iter_file_path = os.path.join(save_dir, 'iter.txt') start_epoch, epoch_iter = 1, 0 @@ -111,7 +110,7 @@ def init_training_params(cfg, save_dir, train_data_loader): total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter return n_epochs, early_stop_patience, \ - total_steps, start_epoch, epoch_iter, iter_file_path, random_seed + total_steps, start_epoch, epoch_iter, iter_file_path def tracking_evaluation_onebatch(gt_seq, normalize_range, acc_list, eta_iter, x_mean_vem_iter): total_iter = eta_iter.shape[0] -- GitLab