From eddc4068b186b71dc90d566de3628588e5fd4509 Mon Sep 17 00:00:00 2001
From: xiaoyu lin <xiaoyulin@wpa2-194-199-31-13-dyn.inrialpes.fr>
Date: Wed, 17 Aug 2022 14:16:38 +0200
Subject: [PATCH] update

---
 config/cfg_dvae_single.ini |  2 +-
 train_dvae_single.py       | 10 ++++++----
 utils.py                   |  3 +--
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/config/cfg_dvae_single.ini b/config/cfg_dvae_single.ini
index b9a52ef..e43f105 100644
--- a/config/cfg_dvae_single.ini
+++ b/config/cfg_dvae_single.ini
@@ -11,7 +11,7 @@ val_data_dir = data/synthetic_trajectories/val_data
 
 
 [Network]
-name = SRNN_dvae_single
+name = SRNN_dvae_single_test
 x_dim = 4
 z_dim = 4
 activation = tanh
diff --git a/train_dvae_single.py b/train_dvae_single.py
index f5a5582..365ca6b 100644
--- a/train_dvae_single.py
+++ b/train_dvae_single.py
@@ -40,6 +40,11 @@ def train(cfg_file):
     cfg = ConfigParser()
     cfg.read(cfg_file)
 
+    # Set random seed
+    random_seed = cfg.getint('Training', 'random_seed')
+    torch.manual_seed(random_seed)
+    np.random.seed(random_seed)
+
     # Create save log directory
     save_log = SaveLog(cfg)
     save_dir = save_log.save_dir
@@ -91,10 +96,7 @@ def train(cfg_file):
 
     # Initialize training parameters 
     n_epochs, early_stop_patience, \
-           total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader)
-    
-    torch.manual_seed(random_seed)
-    np.random.seed(random_seed)
+           total_steps, start_epoch, epoch_iter, iter_file_path = init_training_params(cfg, save_dir, train_data_loader)
     
     # Start training
     print('Start training...')
diff --git a/utils.py b/utils.py
index 8958a06..bb2a320 100644
--- a/utils.py
+++ b/utils.py
@@ -94,7 +94,6 @@ def create_dvae_model(cfg, device, save_dir):
 def init_training_params(cfg, save_dir, train_data_loader):
     n_epochs = cfg.getint('Training', 'n_epochs')
     early_stop_patience = cfg.getint('Training', 'early_stop_patience')
-    random_seed = cfg.getint('Training', 'random_seed')
 
     iter_file_path = os.path.join(save_dir, 'iter.txt')
     start_epoch, epoch_iter = 1, 0
@@ -111,7 +110,7 @@ def init_training_params(cfg, save_dir, train_data_loader):
     total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter
 
     return n_epochs, early_stop_patience, \
-           total_steps, start_epoch, epoch_iter, iter_file_path, random_seed
+           total_steps, start_epoch, epoch_iter, iter_file_path
 
 def tracking_evaluation_onebatch(gt_seq, normalize_range, acc_list, eta_iter, x_mean_vem_iter):
     total_iter = eta_iter.shape[0]
-- 
GitLab