Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 8e88c60e authored by LIN Xiaoyu's avatar LIN Xiaoyu
Browse files

Merge branch 'master' into 'main'

update

See merge request !5
parents 93769e48 eddc4068
No related branches found
No related tags found
1 merge request!5update
......@@ -11,7 +11,7 @@ val_data_dir = data/synthetic_trajectories/val_data
[Network]
name = SRNN_dvae_single
name = SRNN_dvae_single_test
x_dim = 4
z_dim = 4
activation = tanh
......
......@@ -40,6 +40,11 @@ def train(cfg_file):
cfg = ConfigParser()
cfg.read(cfg_file)
# Set random seed
random_seed = cfg.getint('Training', 'random_seed')
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Create save log directory
save_log = SaveLog(cfg)
save_dir = save_log.save_dir
......@@ -91,10 +96,7 @@ def train(cfg_file):
# Initialize training parameters
n_epochs, early_stop_patience, \
total_steps, start_epoch, epoch_iter, iter_file_path, random_seed = init_training_params(cfg, save_dir, train_data_loader)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
total_steps, start_epoch, epoch_iter, iter_file_path = init_training_params(cfg, save_dir, train_data_loader)
# Start training
print('Start training...')
......
......@@ -94,7 +94,6 @@ def create_dvae_model(cfg, device, save_dir):
def init_training_params(cfg, save_dir, train_data_loader):
n_epochs = cfg.getint('Training', 'n_epochs')
early_stop_patience = cfg.getint('Training', 'early_stop_patience')
random_seed = cfg.getint('Training', 'random_seed')
iter_file_path = os.path.join(save_dir, 'iter.txt')
start_epoch, epoch_iter = 1, 0
......@@ -111,7 +110,7 @@ def init_training_params(cfg, save_dir, train_data_loader):
total_steps = (start_epoch - 1) * len(train_data_loader) + epoch_iter
return n_epochs, early_stop_patience, \
total_steps, start_epoch, epoch_iter, iter_file_path, random_seed
total_steps, start_epoch, epoch_iter, iter_file_path
def tracking_evaluation_onebatch(gt_seq, normalize_range, acc_list, eta_iter, x_mean_vem_iter):
total_iter = eta_iter.shape[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment