From 2016a13fd77312423acde6973e26c9f68e01b40f Mon Sep 17 00:00:00 2001
From: LIN Xiaoyu <xiaoyu.lin@inria.fr>
Date: Fri, 25 Feb 2022 08:42:07 +0000
Subject: [PATCH] Update KF_tracking.py

---
 KF_tracking.py | 200 +++++++++++++++++++++++++++----------------------
 1 file changed, 111 insertions(+), 89 deletions(-)

diff --git a/KF_tracking.py b/KF_tracking.py
index 883b839..34ac6e6 100644
--- a/KF_tracking.py
+++ b/KF_tracking.py
@@ -1,89 +1,111 @@
-import sys
-import os
-import torch
-import shutil
-import numpy as np
-from configparser import ConfigParser
-from data.data_loader import create_dataloader
-from save_model import SaveLog
-from models.vem_KF import VEM_KF_MODEL
-from utils import tracking_evaluation_onebatch_KF
-import motmetrics as mm
-
-from utils import get_basic_info
-
-
-def train(cfg_file):
-    # Read the config file
-    if not os.path.isfile(cfg_file):
-        raise ValueError('Invalid config file path')
-    cfg = ConfigParser()
-    cfg.read(cfg_file)
-
-    # Create save log directory
-    save_log = SaveLog(cfg)
-    save_dir = save_log.save_dir
-
-    # Save config file
-    save_cfg_path = os.path.join(save_dir, 'config.ini')
-    shutil.copy(cfg_file, save_cfg_path)
-
-    # Print basic information
-    use_cuda = cfg.getboolean('Training', 'use_cuda')
-    device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
-
-    basic_info = get_basic_info(device)
-    save_log.print_info(basic_info)
-    for info in basic_info:
-        print('%s' % info)
-
-    # Create and initialize model
-    vem_model = VEM_KF_MODEL(cfg, device, save_log)
-
-    # Load data
-    vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot')
-
-    # Print data information
-    data_info = []
-    data_info.append('========== DATA INFO ==========')
-    data_info.append('Training data: %s' % vem_data_size)
-    save_log.print_info(data_info)
-    for info in data_info:
-        print('%s' % info)
-
-    # Start training
-    print('Start training...')
-    total_iter = int(cfg.get('VEM', 'N_iter'))
-    save_frequency = int(cfg.get('Training', 'save_frequency'))
-    normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4)
-    acc_list = [[] for i in range(total_iter)]
-    for idx, data in enumerate(vem_data_loader):
-        print('batch {}\n'.format(idx))
-        data_obs = data['det'].to(device)
-        data_gt = data['gt'].to('cpu')
-        Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\
-            = vem_model.model_training(data_obs, data_gt, idx, save_frequency)
-
-        acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter)
-
-    summary_list = []
-    mota_list = [[] for i in range(total_iter)]
-    for iter_number in range(total_iter):
-        mh = mm.metrics.create()
-        name = ['sample_{}'.format(i) for i in range(vem_data_size)]
-        summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True)
-        mota_list[iter_number].append(summary.loc['OVERALL']['mota'])
-        strsummary = mm.io.render_summary(
-            summary,
-            formatters=mh.formatters,
-            namemap=mm.io.motchallenge_metric_names
-        )
-        summary_list.append(strsummary)
-    save_log.save_evaluation(summary_list, mota_list, total_iter)
-
-if __name__ == '__main__':
-    if len(sys.argv) == 2:
-        cfg_file = sys.argv[1]
-        train(cfg_file)
-    else:
-        print('Error: Please indicate config file path')
\ No newline at end of file
+## DVAE-UMOT
+## Copyright Inria
+## Year 2022
+## Contact : xiaoyu.lin@inria.fr
+
+## DVAE-UMOT is free software: you can redistribute it and/or modify
+## it under the terms of the GNU General Public License as published by
+## the Free Software Foundation, either version 3 of the License, or
+## (at your option) any later version.
+
+## DVAE-UMOT is distributed in the hope that it will be useful,
+## but WITHOUT ANY WARRANTY; without even the implied warranty of
+## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+## GNU General Public License for more details.
+##
+## You should have received a copy of the GNU General Public License
+## along with this program, DVAE-UMOT.  If not, see <http://www.gnu.org/licenses/> and the LICENSE file.
+
+# DVAE-UMOT has code derived from 
+# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST.
+# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA.
+
+import sys
+import os
+import torch
+import shutil
+import numpy as np
+from configparser import ConfigParser
+from data.data_loader import create_dataloader
+from save_model import SaveLog
+from models.vem_KF import VEM_KF_MODEL
+from utils import tracking_evaluation_onebatch_KF
+import motmetrics as mm
+
+from utils import get_basic_info
+
+
+def train(cfg_file):
+    # Read the config file
+    if not os.path.isfile(cfg_file):
+        raise ValueError('Invalid config file path')
+    cfg = ConfigParser()
+    cfg.read(cfg_file)
+
+    # Create save log directory
+    save_log = SaveLog(cfg)
+    save_dir = save_log.save_dir
+
+    # Save config file
+    save_cfg_path = os.path.join(save_dir, 'config.ini')
+    shutil.copy(cfg_file, save_cfg_path)
+
+    # Print basic information
+    use_cuda = cfg.getboolean('Training', 'use_cuda')
+    device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
+
+    basic_info = get_basic_info(device)
+    save_log.print_info(basic_info)
+    for info in basic_info:
+        print('%s' % info)
+
+    # Create and initialize model
+    vem_model = VEM_KF_MODEL(cfg, device, save_log)
+
+    # Load data
+    vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot')
+
+    # Print data information
+    data_info = []
+    data_info.append('========== DATA INFO ==========')
+    data_info.append('Training data: %s' % vem_data_size)
+    save_log.print_info(data_info)
+    for info in data_info:
+        print('%s' % info)
+
+    # Start training
+    print('Start training...')
+    total_iter = int(cfg.get('VEM', 'N_iter'))
+    save_frequency = int(cfg.get('Training', 'save_frequency'))
+    normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4)
+    acc_list = [[] for i in range(total_iter)]
+    for idx, data in enumerate(vem_data_loader):
+        print('batch {}\n'.format(idx))
+        data_obs = data['det'].to(device)
+        data_gt = data['gt'].to('cpu')
+        Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\
+            = vem_model.model_training(data_obs, data_gt, idx, save_frequency)
+
+        acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter)
+
+    summary_list = []
+    mota_list = [[] for i in range(total_iter)]
+    for iter_number in range(total_iter):
+        mh = mm.metrics.create()
+        name = ['sample_{}'.format(i) for i in range(vem_data_size)]
+        summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True)
+        mota_list[iter_number].append(summary.loc['OVERALL']['mota'])
+        strsummary = mm.io.render_summary(
+            summary,
+            formatters=mh.formatters,
+            namemap=mm.io.motchallenge_metric_names
+        )
+        summary_list.append(strsummary)
+    save_log.save_evaluation(summary_list, mota_list, total_iter)
+
+if __name__ == '__main__':
+    if len(sys.argv) == 2:
+        cfg_file = sys.argv[1]
+        train(cfg_file)
+    else:
+        print('Error: Please indicate config file path')
-- 
GitLab