From 2016a13fd77312423acde6973e26c9f68e01b40f Mon Sep 17 00:00:00 2001 From: LIN Xiaoyu <xiaoyu.lin@inria.fr> Date: Fri, 25 Feb 2022 08:42:07 +0000 Subject: [PATCH] Update KF_tracking.py --- KF_tracking.py | 200 +++++++++++++++++++++++++++---------------------- 1 file changed, 111 insertions(+), 89 deletions(-) diff --git a/KF_tracking.py b/KF_tracking.py index 883b839..34ac6e6 100644 --- a/KF_tracking.py +++ b/KF_tracking.py @@ -1,89 +1,111 @@ -import sys -import os -import torch -import shutil -import numpy as np -from configparser import ConfigParser -from data.data_loader import create_dataloader -from save_model import SaveLog -from models.vem_KF import VEM_KF_MODEL -from utils import tracking_evaluation_onebatch_KF -import motmetrics as mm - -from utils import get_basic_info - - -def train(cfg_file): - # Read the config file - if not os.path.isfile(cfg_file): - raise ValueError('Invalid config file path') - cfg = ConfigParser() - cfg.read(cfg_file) - - # Create save log directory - save_log = SaveLog(cfg) - save_dir = save_log.save_dir - - # Save config file - save_cfg_path = os.path.join(save_dir, 'config.ini') - shutil.copy(cfg_file, save_cfg_path) - - # Print basic information - use_cuda = cfg.getboolean('Training', 'use_cuda') - device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu' - - basic_info = get_basic_info(device) - save_log.print_info(basic_info) - for info in basic_info: - print('%s' % info) - - # Create and initialize model - vem_model = VEM_KF_MODEL(cfg, device, save_log) - - # Load data - vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot') - - # Print data information - data_info = [] - data_info.append('========== DATA INFO ==========') - data_info.append('Training data: %s' % vem_data_size) - save_log.print_info(data_info) - for info in data_info: - print('%s' % info) - - # Start training - print('Start training...') - total_iter = int(cfg.get('VEM', 'N_iter')) - save_frequency = int(cfg.get('Training', 'save_frequency')) - normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4) - acc_list = [[] for i in range(total_iter)] - for idx, data in enumerate(vem_data_loader): - print('batch {}\n'.format(idx)) - data_obs = data['det'].to(device) - data_gt = data['gt'].to('cpu') - Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\ - = vem_model.model_training(data_obs, data_gt, idx, save_frequency) - - acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter) - - summary_list = [] - mota_list = [[] for i in range(total_iter)] - for iter_number in range(total_iter): - mh = mm.metrics.create() - name = ['sample_{}'.format(i) for i in range(vem_data_size)] - summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True) - mota_list[iter_number].append(summary.loc['OVERALL']['mota']) - strsummary = mm.io.render_summary( - summary, - formatters=mh.formatters, - namemap=mm.io.motchallenge_metric_names - ) - summary_list.append(strsummary) - save_log.save_evaluation(summary_list, mota_list, total_iter) - -if __name__ == '__main__': - if len(sys.argv) == 2: - cfg_file = sys.argv[1] - train(cfg_file) - else: - print('Error: Please indicate config file path') \ No newline at end of file +## DVAE-UMOT +## Copyright Inria +## Year 2022 +## Contact : xiaoyu.lin@inria.fr + +## DVAE-UMOT is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. + +## DVAE-UMOT is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program, DVAE-UMOT. If not, see <http://www.gnu.org/licenses/> and the LICENSE file. + +# DVAE-UMOT has code derived from +# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST. +# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA. + +import sys +import os +import torch +import shutil +import numpy as np +from configparser import ConfigParser +from data.data_loader import create_dataloader +from save_model import SaveLog +from models.vem_KF import VEM_KF_MODEL +from utils import tracking_evaluation_onebatch_KF +import motmetrics as mm + +from utils import get_basic_info + + +def train(cfg_file): + # Read the config file + if not os.path.isfile(cfg_file): + raise ValueError('Invalid config file path') + cfg = ConfigParser() + cfg.read(cfg_file) + + # Create save log directory + save_log = SaveLog(cfg) + save_dir = save_log.save_dir + + # Save config file + save_cfg_path = os.path.join(save_dir, 'config.ini') + shutil.copy(cfg_file, save_cfg_path) + + # Print basic information + use_cuda = cfg.getboolean('Training', 'use_cuda') + device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu' + + basic_info = get_basic_info(device) + save_log.print_info(basic_info) + for info in basic_info: + print('%s' % info) + + # Create and initialize model + vem_model = VEM_KF_MODEL(cfg, device, save_log) + + # Load data + vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot') + + # Print data information + data_info = [] + data_info.append('========== DATA INFO ==========') + data_info.append('Training data: %s' % vem_data_size) + save_log.print_info(data_info) + for info in data_info: + print('%s' % info) + + # Start training + print('Start training...') + total_iter = int(cfg.get('VEM', 'N_iter')) + save_frequency = int(cfg.get('Training', 'save_frequency')) + normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4) + acc_list = [[] for i in range(total_iter)] + for idx, data in enumerate(vem_data_loader): + print('batch {}\n'.format(idx)) + data_obs = data['det'].to(device) + data_gt = data['gt'].to('cpu') + Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\ + = vem_model.model_training(data_obs, data_gt, idx, save_frequency) + + acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter) + + summary_list = [] + mota_list = [[] for i in range(total_iter)] + for iter_number in range(total_iter): + mh = mm.metrics.create() + name = ['sample_{}'.format(i) for i in range(vem_data_size)] + summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True) + mota_list[iter_number].append(summary.loc['OVERALL']['mota']) + strsummary = mm.io.render_summary( + summary, + formatters=mh.formatters, + namemap=mm.io.motchallenge_metric_names + ) + summary_list.append(strsummary) + save_log.save_evaluation(summary_list, mota_list, total_iter) + +if __name__ == '__main__': + if len(sys.argv) == 2: + cfg_file = sys.argv[1] + train(cfg_file) + else: + print('Error: Please indicate config file path') -- GitLab