diff --git a/models/dvae_umot_model.py b/models/dvae_umot_model.py
index 2bbb3b4461c98ebd84aa1ab00ccc6f6dd908ca5a..9db90b32df191dd614f70a6dad5d32a975f2d0a9 100644
--- a/models/dvae_umot_model.py
+++ b/models/dvae_umot_model.py
@@ -1,417 +1,450 @@
-import datetime
-
-import torch
-
-
-class DVAE_UMOT_MODEL():
-    def __init__(self, cfg, device, dvae_model, save_log):
-        self.device = device
-
-        self.N_iter_total = cfg.getint('DVAE_UMOT', 'N_iter_total')
-        self.num_source = cfg.getint('DVAE_UMOT', 'num_source')
-        self.num_obs = cfg.getint('DVAE_UMOT', 'num_obs')
-        self.std_ratio = cfg.getfloat('DVAE_UMOT', 'std_ratio')
-        self.init_iter_number = cfg.getint('DVAE_UMOT', 'init_iter_number')
-        self.init_subseq_len = cfg.getint('DVAE_UMOT', 'init_subseq_len')
-        self.batch_size = cfg.getint('Training', 'batch_size')
-        self.seq_len = cfg.getint('DataFrame', 'sequence_len')
-        self.finetune = cfg.getboolean('Training', 'finetune')
-
-        self.dvae_model = dvae_model
-
-        self.o_dim = cfg.getint('DVAE_UMOT', 'o_dim')
-        self.x_dim = cfg.getint('Network', 'x_dim')
-        self.z_dim = cfg.getint('Network', 'z_dim')
-
-        self.lr = lr = cfg.getfloat('Training', 'lr')
-        self.optimizer = torch.optim.Adam(self.dvae_model.parameters(), lr=lr)
-        self.save_log = save_log
-
-    def model_training(self, data_gt, data_obs, batch_idx, save_frequence):
-        ### Tensor dimensions
-        ## Initialized tensors
-        # x_mean_dvaeumot_init: (num_source, batch_size, seq_len, x_dim)
-        # x_var_dvaeumot_init: (num_source, batch_size, seq_len, x_dim, x_dim)
-        # Phi_init: (batch_size, seq_len, num_obs, o_dim, o_dim)
-
-        ## Recording tensors
-        # Eta_iter: (N_iter_total, num_source, batch_size, seq_len, num_obs)
-        # x_mean_dvaeumot_iter: (N_iter_total, num_source, batch_size, seq_len, x_dim)
-        # x_var_dvaeumot_iter: (N_iter_total, num_source, batch_size, seq_len, x_dim, x_dim)
-        # Phi_iter: (N_iter_total, batch_size, seq_len, num_obs, o_dim, o_dim)
-        # Phi_inv_iter: (N_iter_total, batch_size, seq_len, num_obs, o_dim, o_dim)
-
-        ## Tensors inside iteration
-        # Eta_n: (batch_size, seq_len, num_obs)
-        # x_mean_dvaeumot_n: (batch_size, seq_len, x_dim)
-        # x_var_dvaeumot_n: (batch_size, seq_len, x_dim, x_dim)
-        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
-        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
-
-        # data_gt: (batch_size, seq_len, num_obs, o_dim)
-        # data_obs: (batch_size, seq_len, num_obs, o_dim)
-
-        # Parameters Initialization
-        data_obs = data_obs.float()
-        batch_size = data_obs.shape[0]
-
-        x_mean_dvaeumot_init, x_var_dvaeumot_init, Phi_init, Phi_inv_init = self.parameters_init_split(data_obs, iter_number=self.init_iter_number, split_len=self.init_subseq_len, std_ratio=self.std_ratio)
-
-        # Initialization of recording tensors
-        Eta_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.num_obs).to(self.device)
-        x_mean_dvaeumot_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.x_dim).to(self.device)
-        x_var_dvaeumot_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.x_dim, self.x_dim).to(self.device)
-        loss_elbo_iter = []
-
-        # Start DVAE_UMOT iterations
-        for i in range(self.N_iter_total):
-            iter_start = datetime.datetime.now()
-            ew_start = datetime.datetime.now()
-            Eta_n_sum = torch.zeros(batch_size, self.seq_len, self.num_obs).to(self.device)
-
-            # E-W Step
-            for n in range(self.num_source):
-                if i == 0:
-                    x_mean_dvaeumot_n = x_mean_dvaeumot_init[n, :, :, :]
-                    x_var_dvaeumot_n = x_var_dvaeumot_init[n, :, :, :, :]
-
-                else:
-                    x_mean_dvaeumot_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
-                    x_var_dvaeumot_n = x_var_dvaeumot_iter[i-1, n, :, :, :, :]
-                Phi = Phi_init
-                Phi_inv = Phi_inv_init
-                
-                Eta_n = self.compute_eta_n(data_obs, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n)
-                Eta_n_tosum = torch.clone(Eta_n)
-                Eta_n_tosum[torch.isnan(Eta_n_tosum)] = 0
-                Eta_n_sum += Eta_n_tosum
-                Eta_iter[i, n, :, :, :] = Eta_n
-            Eta_n_sum = Eta_n_sum.expand(self.num_source, batch_size, self.seq_len, self.num_obs)
-            Eta_iter[i, :, :, :, :] = Eta_iter[i, :, :, :, :] / Eta_n_sum
-            Eta_for_loss = torch.clone(Eta_iter[i, :, :, :, :])
-            Eta_for_loss[torch.isnan(Eta_for_loss)] = 0
-            loss_qw = torch.sum(Eta_for_loss * (Eta_for_loss + 0.0000000001).log()) / (self.num_source * batch_size * self.seq_len * self.num_obs)
-            ew_end = datetime.datetime.now()
-            ew_time = (ew_end - ew_start).seconds / 60
-            print('E-W time {:.2f}m'.format(ew_time))
-
-            # E-S and E-Z Step
-            es_start = datetime.datetime.now()
-            loss_esez = torch.zeros(1).to(self.device)
-            loss_recon = torch.zeros(1).to(self.device)
-            loss_kld = torch.zeros(1).to(self.device)
-            loss_dvae = torch.zeros(1).to(self.device)
-            if self.finetune:
-                for n in range(self.num_source):
-                    if i == 0:
-                        x_dvaeumot_im1_n = x_mean_dvaeumot_init[n, :, :, :]
-                    else:
-                        x_dvaeumot_im1_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
-                    Phi_inv = Phi_inv_init
-                    Eta_n = Eta_iter[i, n, :, :, :]
-                    frame0 = data_obs[:, 0, n, :]
-                    x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs, frame0, compute_loss=True)
-                    loss_dict = self.dvae_model.loss
-                    
-                    loss_dvae += loss_dict['loss_tot'].detach()
-                    loss_recon += loss_dict['loss_recon'].detach()
-                    loss_kld += loss_dict['loss_KLD'].detach()
-                    loss_esez += loss_dict['loss_tot']
-
-                    x_mean_dvaeumot_iter[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
-                    x_var_dvaeumot_iter[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()               
-                self.optimizer.zero_grad()
-                loss_esez.backward()
-                self.optimizer.step()                
-            else:
-                with torch.no_grad():
-                    for n in range(self.num_source):
-                        if i == 0:
-                            x_dvaeumot_im1_n = x_mean_dvaeumot_init[n, :, :, :]
-                        else:
-                            x_dvaeumot_im1_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
-                        Phi_inv = Phi_inv_init
-                        Eta_n = Eta_iter[i, n, :, :, :]
-                        frame0 = data_obs[:, 0, n, :]
-                        x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs, frame0, compute_loss=True)
-                        loss_dict = self.dvae_model.loss
-                        
-                        loss_dvae += loss_dict['loss_tot'].detach()
-                        loss_recon += loss_dict['loss_recon'].detach()
-                        loss_kld += loss_dict['loss_KLD'].detach()
-                        loss_esez += loss_dict['loss_tot']
-
-                        x_mean_dvaeumot_iter[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
-                        x_var_dvaeumot_iter[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()
-
-            loss_dvae = loss_dvae / self.num_source
-            loss_recon = loss_recon / self.num_source
-            loss_kld = loss_kld / self.num_source
-            loss_qs = - torch.sum(0.5 * torch.log(torch.det(x_var_dvaeumot_iter[i, :, :, :, :, :].view(batch_size * self.seq_len * self.num_source, self.x_dim, self.x_dim))))
-            loss_qs = loss_qs / (batch_size * self.seq_len * self.num_source)
-
-            es_end = datetime.datetime.now()
-            es_time = (es_end - es_start).seconds / 60
-            print('E-S time {:.2f}m'.format(es_time))
-
-            iter_end = datetime.datetime.now()
-            iter_time = (iter_end - iter_start).seconds / 60
-            print('Iter time {:.2f}m'.format(iter_time))
-
-            # Save the results
-            loss_qw_qs = self.compute_loss_qwqs(Eta_iter[i, :, :, :, :], data_obs, Phi, Phi_inv, x_mean_dvaeumot_iter[i, :, :, :, :], x_var_dvaeumot_iter[i, :, :, :, :, :])
-            loss_elbo = loss_qw_qs + loss_dvae + loss_qw + loss_qs
-            loss_elbo_iter.append({'loss_elbo': float(loss_elbo.to('cpu')), 'loss_qw_qs': float(loss_qw_qs.to('cpu')),
-                                   'loss_dvae': float(loss_dvae.to('cpu')), 'loss_recon': float(loss_recon.to('cpu')),
-                                   'loss_kld': float(loss_kld.to('cpu')), 'loss_qw': float(loss_qw.to('cpu')),
-                                   'loss_qs': float(loss_qs.to('cpu'))})
-            print('loss_elbo: {}'.format(loss_elbo))
-            print('loss_qw_qs: {}'.format(loss_qw_qs))
-            print('loss_dvae: {}'.format(loss_dvae))
-            print('loss_qw: {}'.format(loss_qw))
-            print('loss_qs: {}'.format(loss_qs))
-            result_list = [x_mean_dvaeumot_iter, data_gt, data_obs]
-            if batch_idx % save_frequence == 0:
-                self.save_log.save_dvaeumot_results(batch_idx, result_list)
-
-        return Eta_iter, x_mean_dvaeumot_iter, x_var_dvaeumot_iter
-
-    def parameters_init_split(self, data_obs, iter_number=10, split_len=50, std_ratio=0.04):
-        batch_size = data_obs.shape[0]
-
-        # Initialize the observation variance matrix Phi with the size of detection bounding boxes at the first frame
-        variance_matrix = torch.zeros(batch_size, self.seq_len, self.num_obs, self.o_dim, self.o_dim).to(self.device)
-        variance_matrix_inv = torch.zeros(batch_size, self.seq_len, self.num_obs, self.o_dim, self.o_dim).to(self.device)
-        for i in range(batch_size):
-            for j in range(self.num_obs):
-                std_onesource = torch.zeros(self.o_dim)
-                w = data_obs[i, 0, j, 2] - data_obs[i, 0, j, 0]
-                h = data_obs[i, 0, j, 3] - data_obs[i, 0, j, 1]
-                std_w = w * std_ratio
-                std_h = h * std_ratio
-                std_onesource[0] = std_w
-                std_onesource[2] = std_w
-                std_onesource[1] = std_h
-                std_onesource[3] = std_h
-                variance_matrix_onesource = torch.diag(torch.pow(std_onesource, 2))
-                variance_matrix_onesource_inv = torch.inverse(variance_matrix_onesource)
-                variance_matrix_onesource_seq = variance_matrix_onesource.expand(self.seq_len, self.o_dim, self.o_dim)
-                variance_matrix_onesource_inv_seq = variance_matrix_onesource_inv.expand(self.seq_len, self.o_dim, self.o_dim)
-
-                variance_matrix[i, :, j, :, :] = variance_matrix_onesource_seq
-                variance_matrix_inv[i, :, j, :, :] = variance_matrix_onesource_inv_seq
-
-        x_mean_dvaeumot_init = torch.zeros(self.num_source, batch_size, self.seq_len, self.x_dim).to(self.device)
-        x_var_dvaeumot_init = torch.zeros(self.num_source, batch_size, self.seq_len, self.x_dim, self.x_dim).to(self.device)
-
-        Phi_init = variance_matrix
-        Phi_inv_init = variance_matrix_inv
-
-        # Initialization by sub-sequences
-        start_frame = 0
-        while start_frame < self.seq_len:
-
-            # Pre-Initialization
-            x_mean_dvaeumot_init_split = torch.zeros(self.num_source, batch_size, split_len, self.x_dim).to(self.device)
-            x_var_dvaeumot_init_split = torch.zeros(self.num_source, batch_size, split_len, self.x_dim, self.x_dim).to(
-                self.device)
-            Phi_init_split = variance_matrix[:, start_frame:start_frame+split_len, :, :, :]
-            Phi_inv_init_split = variance_matrix_inv[:, start_frame:start_frame+split_len, :, :, :]
-
-            # Initialize the sequence of s with frame 0
-            for n in range(self.num_source):
-                if start_frame == 0:
-                    frame0 = data_obs[:, 0, n, :]
-                else:
-                    frame0 = x_mean_dvaeumot_iter_split[-1, n, :, -1, :]
-
-                x_var_dvaeumot_init_n = Phi_init_split[:, :, n, :, :].permute(1, 0, 2, 3)
-                x_mean_dvaeumot_init_n = frame0.expand(split_len, batch_size, self.x_dim)
-
-                x_mean_dvaeumot_init_split[n, :, :, :] = x_mean_dvaeumot_init_n.permute(1, 0, 2).squeeze()
-                x_var_dvaeumot_init_split[n, :, :, :, :] = x_var_dvaeumot_init_n.permute(1, 0, 2, 3).squeeze()
-
-            x_mean_dvaeumot_init[:, :, start_frame:start_frame+split_len, :] = x_mean_dvaeumot_init_split
-            x_var_dvaeumot_init[:, :, start_frame:start_frame+split_len, :, :] = x_var_dvaeumot_init_split
-
-            data_obs_split = data_obs[:, start_frame:start_frame+split_len, :, :]
-
-            Eta_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.num_obs).to(
-                self.device)
-            x_mean_dvaeumot_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.x_dim).to(
-                self.device)
-            x_var_dvaeumot_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.x_dim,
-                                         self.x_dim).to(self.device)
-
-            # Run the EM algorithm
-            for i in range(iter_number):
-                # E-W
-                Eta_n_sum = torch.zeros(batch_size, split_len, self.num_obs).to(self.device)
-                for n in range(self.num_source):
-                    if i == 0:
-                        x_mean_dvaeumot_n = x_mean_dvaeumot_init_split[n, :, :, :]
-                        x_var_dvaeumot_n = x_var_dvaeumot_init_split[n, :, :, :, :]
-                    else:
-                        x_mean_dvaeumot_n = x_mean_dvaeumot_iter_split[i - 1, n, :, :, :]
-                        x_var_dvaeumot_n = x_var_dvaeumot_iter_split[i - 1, n, :, :, :, :]                    
-                    Phi = Phi_init_split
-                    Phi_inv = Phi_inv_init_split
-
-                    Eta_n = self.compute_eta_n(data_obs_split, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n)
-                    Eta_n_tosum = torch.clone(Eta_n)
-                    Eta_n_tosum[torch.isnan(Eta_n_tosum)] = 0
-                    Eta_n_sum += Eta_n_tosum
-                    Eta_iter_split[i, n, :, :, :] = Eta_n
-                Eta_n_sum = Eta_n_sum.expand(self.num_source, batch_size, split_len, self.num_obs)
-                Eta_iter_split[i, :, :, :, :] = Eta_iter_split[i, :, :, :, :] / Eta_n_sum
-                # E-S/E-Z
-                with torch.no_grad():
-                    for n in range(self.num_source):
-                        if i == 0:
-                            x_dvaeumot_im1_n = x_mean_dvaeumot_init_split[n, :, :, :]
-                        else:
-                            x_dvaeumot_im1_n = x_mean_dvaeumot_iter_split[i - 1, n, :, :, :]
-                            
-                        Phi_inv = Phi_inv_init_split
-                        Eta_n = Eta_iter_split[i, n, :, :, :]
-                        frame0 = x_mean_dvaeumot_init_split[n, :, 0, :]
-                        x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs_split, frame0, compute_loss=False)
-
-                        x_mean_dvaeumot_iter_split[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
-                        x_var_dvaeumot_iter_split[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()
-
-            start_frame += split_len
-
-        return x_mean_dvaeumot_init, x_var_dvaeumot_init, Phi_init, Phi_inv_init
-
-    def compute_eta_n(self, o, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n):
-        ### Tensor dimensions
-        # o: (batch_size, seq_len, num_obs, o_dim)
-        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
-        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
-        # x_mean_dvaeumot_n: (batch_size, seq_len, x_dim)
-        # x_var_dvaeumot_n: (batch_size, seq_len, x_dim, x_dim)
-
-        # Eta_n: (batch_size, seq_len, num_obs)
-
-        seq_len = o.shape[1]
-        batch_size = o.shape[0]
-
-        Eta_n = torch.zeros(self.num_obs, batch_size, seq_len).to(self.device)
-
-        for k in range(self.num_obs):
-            Phi_k = Phi[:, :, k, :, :]
-            Phi_inv_k = Phi_inv[:, :, k, :, :]
-            o_k = o[:, :, k, :]
-
-            det_Phi_k = torch.det(Phi_k)
-            det_Phi_k_sqrt = torch.sqrt(det_Phi_k)
-
-            o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
-            o_ms_Phi = torch.matmul(o_ms.squeeze().unsqueeze(-2), Phi_inv_k)
-            o_ms_Phi_sq = torch.matmul(o_ms_Phi, o_ms)
-            o_ms_Phi_sq = 0.008 * o_ms_Phi_sq.squeeze()
-            gaussian_exp_term = torch.exp(-0.5*o_ms_Phi_sq)
-
-            Phi_Sigma = torch.matmul(Phi_inv_k, x_var_dvaeumot_n)
-            trace = 0.008 * torch.diagonal(Phi_Sigma, dim1=-2, dim2=-1).sum(-1)
-            exp_tr_term = torch.exp(-0.5 * trace)
-
-            Beta_n = (1 / det_Phi_k_sqrt) * gaussian_exp_term * exp_tr_term + 0.0000000001
-            Eta_n[k, :, :] = Beta_n / (self.num_source + 1)
-
-        Eta_n = Eta_n.permute(1, 2, 0)
-
-        return Eta_n
-
-    def compute_loss_qwqs(self, eta, o, Phi, Phi_inv, x_mean_dvaeumot, x_var_dvaeumot):
-        batch_size = eta.shape[1]
-        loss_qwqs = torch.zeros(batch_size, self.seq_len).to(self.device)
-        for n in range(self.num_source):
-            x_mean_dvaeumot_n = x_mean_dvaeumot[n, :, :, :]
-            x_var_dvaeumot_n = x_var_dvaeumot[n, :, :, :, :]
-            for k in range(self.num_obs):
-                Phi_k = Phi[:, :, k, :, :]
-                Phi_inv_k = Phi_inv[:, :, k, :, :]
-                o_k = o[:, :, k, :]
-                eta_kn = eta[n, :, :, k]
-
-                det_Phi_k = torch.det(Phi_k)
-                o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
-                o_ms_Phi = torch.matmul(o_ms.squeeze().unsqueeze(-2), Phi_inv_k)
-                o_ms_Phi_sq = torch.matmul(o_ms_Phi, o_ms)
-                o_ms_Phi_sq = o_ms_Phi_sq.squeeze()
-
-                Phi_Sigma = torch.matmul(Phi_inv_k, x_var_dvaeumot_n)
-                trace = torch.diagonal(Phi_Sigma, dim1=-2, dim2=-1).sum(-1)
-
-                loss_qwqs_kn = eta_kn * 0.5 * (torch.log(det_Phi_k) + o_ms_Phi_sq + trace)
-                loss_qwqs_kn[torch.isnan(loss_qwqs_kn)] = 0
-                loss_qwqs += loss_qwqs_kn
-
-        loss_qwqs = torch.sum(loss_qwqs) / (self.num_source * batch_size * self.seq_len * self.num_obs)
-
-        return loss_qwqs
-
-    def compute_phi(self, o, x_mean_dvaeumot, x_var_dvaeumot, Eta, eps=0.000001):
-        ### Tensor dimensions
-        # o: (batch_size, seq_len, num_obs, o_dim)
-        # x_mean_dvaeumot: (num_source, batch_size, seq_len, x_dim)
-        # x_var_dvaeumot: (num_source, batch_size, seq_len, x_dim, x_dim)
-        # Eta: (num_source, batch_size, seq_len, num_obs)
-
-        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
-        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
-
-        batch_size = o.shape[0]
-        seq_len = o.shape[1]
-        num_obs = o.shape[2]
-        Phi = torch.zeros(num_obs, batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
-        Phi_inv = torch.zeros(num_obs, batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
-
-        for k in range(num_obs):
-            Phi_k = torch.zeros(batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
-            o_k = o[:, :, k, :]
-            for n in range(self.num_source):
-                x_mean_dvaeumot_n = x_mean_dvaeumot[n, :, :, :]
-                x_var_dvaeumot_n = x_var_dvaeumot[n, :, :, :, :]
-                o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
-                o_ms_sq = torch.matmul(o_ms, o_ms.squeeze().unsqueeze(-2))
-
-                one_source = (Eta[n, :, :, k].unsqueeze(-1) * (x_var_dvaeumot_n + o_ms_sq).view(batch_size, seq_len, self.o_dim*self.o_dim)).view(batch_size, seq_len, self.o_dim, self.o_dim)
-                Phi_k += one_source
-
-            Phi_k = Phi_k + eps * torch.eye(self.o_dim).to(self.device)
-            Phi[k, :, :, :, :] = Phi_k
-            try:
-                for i in range(batch_size):
-                    for t in range(seq_len):
-                        if Phi_k[i, t].sum().isnan():
-                            Phi_inv[k, i, t, :, :] = Phi_k[i, t]
-                        else:
-                            u = torch.cholesky(Phi_k[i, t])
-                            Phi_inv[k, i, t, :, :] = torch.cholesky_inverse(u)
-            except RuntimeError:
-                print('Phi: {}'.format(Phi_k))
-                print('Phi_inv: {}'.format(Phi_inv[k, :, :, :, :]))
-                print('o_ms_sq: {}'.format(o_ms_sq))
-                print('o: {}'.format(o[i, t, k, :]))
-                print('x_mean_dvaeumot: {}'.format(x_mean_dvaeumot[n, i, t, :]))
-                print('Eta: {}'.format(Eta[n, i, t, k]))
-                print('x_var_dvaeumot: {}'.format(x_var_dvaeumot[n, i, t, :, :]))
-        
-        Phi = Phi.permute(1, 2, 0, 3, 4)
-        Phi_inv = Phi_inv.permute(1, 2, 0, 3, 4)
-
-        return Phi, Phi_inv
-
-
-
-
-
-
-
-
-
-
-
+## DVAE-UMOT
+## Copyright Inria
+## Year 2022
+## Contact : xiaoyu.lin@inria.fr
+
+## DVAE-UMOT is free software: you can redistribute it and/or modify
+## it under the terms of the GNU General Public License as published by
+## the Free Software Foundation, either version 3 of the License, or
+## (at your option) any later version.
+
+## DVAE-UMOT is distributed in the hope that it will be useful,
+## but WITHOUT ANY WARRANTY; without even the implied warranty of
+## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+## GNU General Public License for more details.
+##
+## You should have received a copy of the GNU General Public License
+## along with this program, DVAE-UMOT.  If not, see <http://www.gnu.org/licenses/> and the LICENSE file.
+
+# DVAE-UMOT has code derived from 
+# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST.
+# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA.
+
+import datetime
+import torch
+
+class DVAE_UMOT_MODEL():
+    def __init__(self, cfg, device, dvae_model, save_log):
+        self.device = device
+
+        self.N_iter_total = cfg.getint('DVAE_UMOT', 'N_iter_total')
+        self.num_source = cfg.getint('DVAE_UMOT', 'num_source')
+        self.num_obs = cfg.getint('DVAE_UMOT', 'num_obs')
+        self.std_ratio = cfg.getfloat('DVAE_UMOT', 'std_ratio')
+        self.init_iter_number = cfg.getint('DVAE_UMOT', 'init_iter_number')
+        self.init_subseq_len = cfg.getint('DVAE_UMOT', 'init_subseq_len')
+        self.batch_size = cfg.getint('Training', 'batch_size')
+        self.seq_len = cfg.getint('DataFrame', 'sequence_len')
+        self.finetune = cfg.getboolean('Training', 'finetune')
+
+        self.dvae_model = dvae_model
+
+        self.o_dim = cfg.getint('DVAE_UMOT', 'o_dim')
+        self.x_dim = cfg.getint('Network', 'x_dim')
+        self.z_dim = cfg.getint('Network', 'z_dim')
+
+        self.lr = lr = cfg.getfloat('Training', 'lr')
+        self.optimizer = torch.optim.Adam(self.dvae_model.parameters(), lr=lr)
+        self.save_log = save_log
+
+    def model_training(self, data_gt, data_obs, batch_idx, save_frequence):
+        ### Tensor dimensions
+        ## Initialized tensors
+        # x_mean_dvaeumot_init: (num_source, batch_size, seq_len, x_dim)
+        # x_var_dvaeumot_init: (num_source, batch_size, seq_len, x_dim, x_dim)
+        # Phi_init: (batch_size, seq_len, num_obs, o_dim, o_dim)
+
+        ## Recording tensors
+        # Eta_iter: (N_iter_total, num_source, batch_size, seq_len, num_obs)
+        # x_mean_dvaeumot_iter: (N_iter_total, num_source, batch_size, seq_len, x_dim)
+        # x_var_dvaeumot_iter: (N_iter_total, num_source, batch_size, seq_len, x_dim, x_dim)
+        # Phi_iter: (N_iter_total, batch_size, seq_len, num_obs, o_dim, o_dim)
+        # Phi_inv_iter: (N_iter_total, batch_size, seq_len, num_obs, o_dim, o_dim)
+
+        ## Tensors inside iteration
+        # Eta_n: (batch_size, seq_len, num_obs)
+        # x_mean_dvaeumot_n: (batch_size, seq_len, x_dim)
+        # x_var_dvaeumot_n: (batch_size, seq_len, x_dim, x_dim)
+        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
+        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
+
+        # data_gt: (batch_size, seq_len, num_obs, o_dim)
+        # data_obs: (batch_size, seq_len, num_obs, o_dim)
+
+        # Parameters Initialization
+        data_obs = data_obs.float()
+        batch_size = data_obs.shape[0]
+
+        x_mean_dvaeumot_init, x_var_dvaeumot_init, Phi_init, Phi_inv_init = self.parameters_init_split(data_obs, iter_number=self.init_iter_number, split_len=self.init_subseq_len, std_ratio=self.std_ratio)
+
+        # Initialization of recording tensors
+        Eta_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.num_obs).to(self.device)
+        x_mean_dvaeumot_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.x_dim).to(self.device)
+        x_var_dvaeumot_iter = torch.zeros(self.N_iter_total, self.num_source, batch_size, self.seq_len, self.x_dim, self.x_dim).to(self.device)
+        loss_elbo_iter = []
+
+        # Start DVAE_UMOT iterations
+        for i in range(self.N_iter_total):
+            iter_start = datetime.datetime.now()
+            ew_start = datetime.datetime.now()
+            Eta_n_sum = torch.zeros(batch_size, self.seq_len, self.num_obs).to(self.device)
+
+            # E-W Step
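+            # For each source n, compute (unnormalized) assignment weights Eta_n of the observations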
+            for n in range(self.num_source):
+                if i == 0:
+                    x_mean_dvaeumot_n = x_mean_dvaeumot_init[n, :, :, :]
+                    x_var_dvaeumot_n = x_var_dvaeumot_init[n, :, :, :, :]
+
+                else:
+                    x_mean_dvaeumot_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
+                    x_var_dvaeumot_n = x_var_dvaeumot_iter[i-1, n, :, :, :, :]
+                Phi = Phi_init
+                Phi_inv = Phi_inv_init
+                
+                Eta_n = self.compute_eta_n(data_obs, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n)
+                Eta_n_tosum = torch.clone(Eta_n)
+                Eta_n_tosum[torch.isnan(Eta_n_tosum)] = 0
+                Eta_n_sum += Eta_n_tosum
+                Eta_iter[i, n, :, :, :] = Eta_n
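+            # Normalize over sources so that the weights of each observation sum to 1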
+            Eta_n_sum = Eta_n_sum.expand(self.num_source, batch_size, self.seq_len, self.num_obs)
+            Eta_iter[i, :, :, :, :] = Eta_iter[i, :, :, :, :] / Eta_n_sum
+            Eta_for_loss = torch.clone(Eta_iter[i, :, :, :, :])
+            Eta_for_loss[torch.isnan(Eta_for_loss)] = 0
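+            # Negative entropy of q(w), averaged per entry; the 1e-10 offset guards against log(0)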
+            loss_qw = torch.sum(Eta_for_loss * (Eta_for_loss + 1e-10).log()) / (self.num_source * batch_size * self.seq_len * self.num_obs)
+            ew_end = datetime.datetime.now()
+            ew_time = (ew_end - ew_start).total_seconds() / 60
+            print('E-W time {:.2f}m'.format(ew_time))
+
+            # E-S and E-Z Step
+            es_start = datetime.datetime.now()
+            loss_esez = torch.zeros(1).to(self.device)
+            loss_recon = torch.zeros(1).to(self.device)
+            loss_kld = torch.zeros(1).to(self.device)
+            loss_dvae = torch.zeros(1).to(self.device)
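+            # With finetune=True the DVAE parameters are updated by backprop on the accumulated
+            # loss; otherwise the same E-S/E-Z pass runs under torch.no_grad()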
+            if self.finetune:
+                for n in range(self.num_source):
+                    if i == 0:
+                        x_dvaeumot_im1_n = x_mean_dvaeumot_init[n, :, :, :]
+                    else:
+                        x_dvaeumot_im1_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
+                    Phi_inv = Phi_inv_init
+                    Eta_n = Eta_iter[i, n, :, :, :]
+                    frame0 = data_obs[:, 0, n, :]
+                    x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs, frame0, compute_loss=True)
+                    loss_dict = self.dvae_model.loss
+                    
+                    loss_dvae += loss_dict['loss_tot'].detach()
+                    loss_recon += loss_dict['loss_recon'].detach()
+                    loss_kld += loss_dict['loss_KLD'].detach()
+                    loss_esez += loss_dict['loss_tot']
+
+                    x_mean_dvaeumot_iter[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
+                    x_var_dvaeumot_iter[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()               
+                self.optimizer.zero_grad()
+                loss_esez.backward()
+                self.optimizer.step()                
+            else:
+                with torch.no_grad():
+                    for n in range(self.num_source):
+                        if i == 0:
+                            x_dvaeumot_im1_n = x_mean_dvaeumot_init[n, :, :, :]
+                        else:
+                            x_dvaeumot_im1_n = x_mean_dvaeumot_iter[i-1, n, :, :, :]
+                        Phi_inv = Phi_inv_init
+                        Eta_n = Eta_iter[i, n, :, :, :]
+                        frame0 = data_obs[:, 0, n, :]
+                        x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs, frame0, compute_loss=True)
+                        loss_dict = self.dvae_model.loss
+                        
+                        loss_dvae += loss_dict['loss_tot'].detach()
+                        loss_recon += loss_dict['loss_recon'].detach()
+                        loss_kld += loss_dict['loss_KLD'].detach()
+                        loss_esez += loss_dict['loss_tot']
+
+                        x_mean_dvaeumot_iter[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
+                        x_var_dvaeumot_iter[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()
+
+            loss_dvae = loss_dvae / self.num_source
+            loss_recon = loss_recon / self.num_source
+            loss_kld = loss_kld / self.num_source
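+            # loss_qs: negative differential entropy of the Gaussian q(s), up to additive constants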
+            loss_qs = - torch.sum(0.5 * torch.log(torch.det(x_var_dvaeumot_iter[i, :, :, :, :, :].view(batch_size * self.seq_len * self.num_source, self.x_dim, self.x_dim))))
+            loss_qs = loss_qs / (batch_size * self.seq_len * self.num_source)
+
+            es_end = datetime.datetime.now()
+            es_time = (es_end - es_start).total_seconds() / 60
+            print('E-S time {:.2f}m'.format(es_time))
+
+            iter_end = datetime.datetime.now()
+            iter_time = (iter_end - iter_start).total_seconds() / 60
+            print('Iter time {:.2f}m'.format(iter_time))
+
+            # Save the results
+            loss_qw_qs = self.compute_loss_qwqs(Eta_iter[i, :, :, :, :], data_obs, Phi, Phi_inv, x_mean_dvaeumot_iter[i, :, :, :, :], x_var_dvaeumot_iter[i, :, :, :, :, :])
+            loss_elbo = loss_qw_qs + loss_dvae + loss_qw + loss_qs
+            loss_elbo_iter.append({'loss_elbo': float(loss_elbo.to('cpu')), 'loss_qw_qs': float(loss_qw_qs.to('cpu')),
+                                   'loss_dvae': float(loss_dvae.to('cpu')), 'loss_recon': float(loss_recon.to('cpu')),
+                                   'loss_kld': float(loss_kld.to('cpu')), 'loss_qw': float(loss_qw.to('cpu')),
+                                   'loss_qs': float(loss_qs.to('cpu'))})
+            print('loss_elbo: {}'.format(loss_elbo))
+            print('loss_qw_qs: {}'.format(loss_qw_qs))
+            print('loss_dvae: {}'.format(loss_dvae))
+            print('loss_qw: {}'.format(loss_qw))
+            print('loss_qs: {}'.format(loss_qs))
+            result_list = [x_mean_dvaeumot_iter, data_gt, data_obs]
+            if batch_idx % save_frequence == 0:
+                self.save_log.save_dvaeumot_results(batch_idx, result_list)
+
+        return Eta_iter, x_mean_dvaeumot_iter, x_var_dvaeumot_iter
+
+    def parameters_init_split(self, data_obs, iter_number=10, split_len=50, std_ratio=0.04):
+        batch_size = data_obs.shape[0]
+
+        # Initialize the observation covariance matrix Phi from the sizes of the detection bounding boxes at the first frame
+        variance_matrix = torch.zeros(batch_size, self.seq_len, self.num_obs, self.o_dim, self.o_dim).to(self.device)
+        variance_matrix_inv = torch.zeros(batch_size, self.seq_len, self.num_obs, self.o_dim, self.o_dim).to(self.device)
+        for i in range(batch_size):
+            for j in range(self.num_obs):
+                std_onesource = torch.zeros(self.o_dim)
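+                # boxes are (x1, y1, x2, y2), so x coordinates get std_w and y coordinates get std_h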
+                w = data_obs[i, 0, j, 2] - data_obs[i, 0, j, 0]
+                h = data_obs[i, 0, j, 3] - data_obs[i, 0, j, 1]
+                std_w = w * std_ratio
+                std_h = h * std_ratio
+                std_onesource[0] = std_w
+                std_onesource[2] = std_w
+                std_onesource[1] = std_h
+                std_onesource[3] = std_h
+                variance_matrix_onesource = torch.diag(torch.pow(std_onesource, 2))
+                variance_matrix_onesource_inv = torch.inverse(variance_matrix_onesource)
+                variance_matrix_onesource_seq = variance_matrix_onesource.expand(self.seq_len, self.o_dim, self.o_dim)
+                variance_matrix_onesource_inv_seq = variance_matrix_onesource_inv.expand(self.seq_len, self.o_dim, self.o_dim)
+
+                variance_matrix[i, :, j, :, :] = variance_matrix_onesource_seq
+                variance_matrix_inv[i, :, j, :, :] = variance_matrix_onesource_inv_seq
+
+        x_mean_dvaeumot_init = torch.zeros(self.num_source, batch_size, self.seq_len, self.x_dim).to(self.device)
+        x_var_dvaeumot_init = torch.zeros(self.num_source, batch_size, self.seq_len, self.x_dim, self.x_dim).to(self.device)
+
+        Phi_init = variance_matrix
+        Phi_inv_init = variance_matrix_inv
+
+        # Initialization by sub-sequences
+        start_frame = 0
+        while start_frame < self.seq_len:
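+            # Each sub-sequence of length split_len is warmed up with a short EM run (iter_number iterations)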
+
+            # Pre-Initialization
+            x_mean_dvaeumot_init_split = torch.zeros(self.num_source, batch_size, split_len, self.x_dim).to(self.device)
+            x_var_dvaeumot_init_split = torch.zeros(self.num_source, batch_size, split_len, self.x_dim, self.x_dim).to(
+                self.device)
+            Phi_init_split = variance_matrix[:, start_frame:start_frame+split_len, :, :, :]
+            Phi_inv_init_split = variance_matrix_inv[:, start_frame:start_frame+split_len, :, :, :]
+
+            # Initialize the sequence of s for this sub-sequence
+            for n in range(self.num_source):
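+                # Seed with the first detections, or with the last estimated mean of the previous sub-sequence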
+                if start_frame == 0:
+                    frame0 = data_obs[:, 0, n, :]
+                else:
+                    frame0 = x_mean_dvaeumot_iter_split[-1, n, :, -1, :]
+
+                x_var_dvaeumot_init_n = Phi_init_split[:, :, n, :, :].permute(1, 0, 2, 3)
+                x_mean_dvaeumot_init_n = frame0.expand(split_len, batch_size, self.x_dim)
+
+                x_mean_dvaeumot_init_split[n, :, :, :] = x_mean_dvaeumot_init_n.permute(1, 0, 2).squeeze()
+                x_var_dvaeumot_init_split[n, :, :, :, :] = x_var_dvaeumot_init_n.permute(1, 0, 2, 3).squeeze()
+
+            x_mean_dvaeumot_init[:, :, start_frame:start_frame+split_len, :] = x_mean_dvaeumot_init_split
+            x_var_dvaeumot_init[:, :, start_frame:start_frame+split_len, :, :] = x_var_dvaeumot_init_split
+
+            data_obs_split = data_obs[:, start_frame:start_frame+split_len, :, :]
+
+            Eta_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.num_obs).to(
+                self.device)
+            x_mean_dvaeumot_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.x_dim).to(
+                self.device)
+            x_var_dvaeumot_iter_split = torch.zeros(iter_number, self.num_source, batch_size, split_len, self.x_dim,
+                                         self.x_dim).to(self.device)
+
+            # Run the EM algorithm
+            for i in range(iter_number):
+                # E-W
+                Eta_n_sum = torch.zeros(batch_size, split_len, self.num_obs).to(self.device)
+                for n in range(self.num_source):
+                    if i == 0:
+                        x_mean_dvaeumot_n = x_mean_dvaeumot_init_split[n, :, :, :]
+                        x_var_dvaeumot_n = x_var_dvaeumot_init_split[n, :, :, :, :]
+                    else:
+                        x_mean_dvaeumot_n = x_mean_dvaeumot_iter_split[i - 1, n, :, :, :]
+                        x_var_dvaeumot_n = x_var_dvaeumot_iter_split[i - 1, n, :, :, :, :]                    
+                    Phi = Phi_init_split
+                    Phi_inv = Phi_inv_init_split
+
+                    Eta_n = self.compute_eta_n(data_obs_split, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n)
+                    Eta_n_tosum = torch.clone(Eta_n)
+                    Eta_n_tosum[torch.isnan(Eta_n_tosum)] = 0
+                    Eta_n_sum += Eta_n_tosum
+                    Eta_iter_split[i, n, :, :, :] = Eta_n
+                Eta_n_sum = Eta_n_sum.expand(self.num_source, batch_size, split_len, self.num_obs)
+                Eta_iter_split[i, :, :, :, :] = Eta_iter_split[i, :, :, :, :] / Eta_n_sum
+                # E-S/E-Z
+                with torch.no_grad():
+                    for n in range(self.num_source):
+                        if i == 0:
+                            x_dvaeumot_im1_n = x_mean_dvaeumot_init_split[n, :, :, :]
+                        else:
+                            x_dvaeumot_im1_n = x_mean_dvaeumot_iter_split[i - 1, n, :, :, :]
+                            
+                        Phi_inv = Phi_inv_init_split
+                        Eta_n = Eta_iter_split[i, n, :, :, :]
+                        frame0 = x_mean_dvaeumot_init_split[n, :, 0, :]
+                        x_mean_dvaeumot_n, x_var_dvaeumot_n = self.dvae_model(x_dvaeumot_im1_n, Eta_n, Phi_inv, data_obs_split, frame0, compute_loss=False)
+
+                        x_mean_dvaeumot_iter_split[i, n, :, :, :] = x_mean_dvaeumot_n.detach()
+                        x_var_dvaeumot_iter_split[i, n, :, :, :, :] = x_var_dvaeumot_n.detach()
+
+            start_frame += split_len
+
+        return x_mean_dvaeumot_init, x_var_dvaeumot_init, Phi_init, Phi_inv_init
+
+    def compute_eta_n(self, o, Phi, Phi_inv, x_mean_dvaeumot_n, x_var_dvaeumot_n):
+        ### Tensor dimensions
+        # o: (batch_size, seq_len, num_obs, o_dim)
+        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
+        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
+        # x_mean_dvaeumot_n: (batch_size, seq_len, x_dim)
+        # x_var_dvaeumot_n: (batch_size, seq_len, x_dim, x_dim)
+
+        # Eta_n: (batch_size, seq_len, num_obs)
+
+        seq_len = o.shape[1]
+        batch_size = o.shape[0]
+
+        Eta_n = torch.zeros(self.num_obs, batch_size, seq_len).to(self.device)
+
+        for k in range(self.num_obs):
+            Phi_k = Phi[:, :, k, :, :]
+            Phi_inv_k = Phi_inv[:, :, k, :, :]
+            o_k = o[:, :, k, :]
+
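+            # Unnormalized responsibility of the source for observation k: a Gaussian term in
+            # (o_k - x_mean) times an exp(-0.5 * trace) correction accounting for the variance of s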
+            det_Phi_k = torch.det(Phi_k)
+            det_Phi_k_sqrt = torch.sqrt(det_Phi_k)
+
+            o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
+            o_ms_Phi = torch.matmul(o_ms.squeeze().unsqueeze(-2), Phi_inv_k)
+            o_ms_Phi_sq = torch.matmul(o_ms_Phi, o_ms)
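+            # the 0.008 factor (also applied to the trace below) damps the exponents, presumably to avoid underflow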
+            o_ms_Phi_sq = 0.008 * o_ms_Phi_sq.squeeze()
+            gaussian_exp_term = torch.exp(-0.5*o_ms_Phi_sq)
+
+            Phi_Sigma = torch.matmul(Phi_inv_k, x_var_dvaeumot_n)
+            trace = 0.008 * torch.diagonal(Phi_Sigma, dim1=-2, dim2=-1).sum(-1)
+            exp_tr_term = torch.exp(-0.5 * trace)
+
+            Beta_n = (1 / det_Phi_k_sqrt) * gaussian_exp_term * exp_tr_term + 1e-10
+            Eta_n[k, :, :] = Beta_n / (self.num_source + 1)
+
+        Eta_n = Eta_n.permute(1, 2, 0)
+
+        return Eta_n
+
+    def compute_loss_qwqs(self, eta, o, Phi, Phi_inv, x_mean_dvaeumot, x_var_dvaeumot):
+        batch_size = eta.shape[1]
+        loss_qwqs = torch.zeros(batch_size, self.seq_len).to(self.device)
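+        # Eta-weighted expected negative log-likelihood of the observations (up to constants):
+        # log det(Phi), Mahalanobis distance of (o_k - x_mean_n), and a trace term from the variance of s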
+        for n in range(self.num_source):
+            x_mean_dvaeumot_n = x_mean_dvaeumot[n, :, :, :]
+            x_var_dvaeumot_n = x_var_dvaeumot[n, :, :, :, :]
+            for k in range(self.num_obs):
+                Phi_k = Phi[:, :, k, :, :]
+                Phi_inv_k = Phi_inv[:, :, k, :, :]
+                o_k = o[:, :, k, :]
+                eta_kn = eta[n, :, :, k]
+
+                det_Phi_k = torch.det(Phi_k)
+                o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
+                o_ms_Phi = torch.matmul(o_ms.squeeze().unsqueeze(-2), Phi_inv_k)
+                o_ms_Phi_sq = torch.matmul(o_ms_Phi, o_ms)
+                o_ms_Phi_sq = o_ms_Phi_sq.squeeze()
+
+                Phi_Sigma = torch.matmul(Phi_inv_k, x_var_dvaeumot_n)
+                trace = torch.diagonal(Phi_Sigma, dim1=-2, dim2=-1).sum(-1)
+
+                loss_qwqs_kn = eta_kn * 0.5 * (torch.log(det_Phi_k) + o_ms_Phi_sq + trace)
+                loss_qwqs_kn[torch.isnan(loss_qwqs_kn)] = 0
+                loss_qwqs += loss_qwqs_kn
+
+        loss_qwqs = torch.sum(loss_qwqs) / (self.num_source * batch_size * self.seq_len * self.num_obs)
+
+        return loss_qwqs
+
+    def compute_phi(self, o, x_mean_dvaeumot, x_var_dvaeumot, Eta, eps=1e-6):
+        ### Tensor dimensions
+        # o: (batch_size, seq_len, num_obs, o_dim)
+        # x_mean_dvaeumot: (num_source, batch_size, seq_len, x_dim)
+        # x_var_dvaeumot: (num_source, batch_size, seq_len, x_dim, x_dim)
+        # Eta: (num_source, batch_size, seq_len, num_obs)
+
+        # Phi: (batch_size, seq_len, num_obs, o_dim, o_dim)
+        # Phi_inv: (batch_size, seq_len, num_obs, o_dim, o_dim)
+
+        batch_size = o.shape[0]
+        seq_len = o.shape[1]
+        num_obs = o.shape[2]
+        Phi = torch.zeros(num_obs, batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
+        Phi_inv = torch.zeros(num_obs, batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
+
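+        # M-step update of Phi: for each observation k, an Eta-weighted sum over sources of the
+        # posterior covariance of s plus the outer product of the residual (o_k - x_mean_n)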
+        for k in range(num_obs):
+            Phi_k = torch.zeros(batch_size, seq_len, self.o_dim, self.o_dim).to(self.device)
+            o_k = o[:, :, k, :]
+            for n in range(self.num_source):
+                x_mean_dvaeumot_n = x_mean_dvaeumot[n, :, :, :]
+                x_var_dvaeumot_n = x_var_dvaeumot[n, :, :, :, :]
+                o_ms = o_k.unsqueeze(-1) - x_mean_dvaeumot_n.unsqueeze(-1)
+                o_ms_sq = torch.matmul(o_ms, o_ms.squeeze().unsqueeze(-2))
+
+                one_source = (Eta[n, :, :, k].unsqueeze(-1) * (x_var_dvaeumot_n + o_ms_sq).view(batch_size, seq_len, self.o_dim*self.o_dim)).view(batch_size, seq_len, self.o_dim, self.o_dim)
+                Phi_k += one_source
+
+            Phi_k = Phi_k + eps * torch.eye(self.o_dim).to(self.device)
+            Phi[k, :, :, :, :] = Phi_k
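+            # Invert each (o_dim, o_dim) block via Cholesky; NaN blocks are copied through so failures stay visible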
+            try:
+                for i in range(batch_size):
+                    for t in range(seq_len):
+                        if Phi_k[i, t].sum().isnan():
+                            Phi_inv[k, i, t, :, :] = Phi_k[i, t]
+                        else:
+                            u = torch.linalg.cholesky(Phi_k[i, t])  # torch.cholesky is deprecated
+                            Phi_inv[k, i, t, :, :] = torch.cholesky_inverse(u)
+            except RuntimeError:
+                print('Phi: {}'.format(Phi_k))
+                print('Phi_inv: {}'.format(Phi_inv[k, :, :, :, :]))
+                print('o_ms_sq: {}'.format(o_ms_sq))
+                print('o: {}'.format(o[i, t, k, :]))
+                print('x_mean_dvaeumot: {}'.format(x_mean_dvaeumot[n, i, t, :]))
+                print('Eta: {}'.format(Eta[n, i, t, k]))
+                print('x_var_dvaeumot: {}'.format(x_var_dvaeumot[n, i, t, :, :]))
+        
+        Phi = Phi.permute(1, 2, 0, 3, 4)
+        Phi_inv = Phi_inv.permute(1, 2, 0, 3, 4)
+
+        return Phi, Phi_inv
+