Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 2016a13f authored by LIN Xiaoyu's avatar LIN Xiaoyu
Browse files

Update KF_tracking.py

parent 9b7feefa
No related branches found
No related tags found
No related merge requests found
import sys
import os
import torch
import shutil
import numpy as np
from configparser import ConfigParser
from data.data_loader import create_dataloader
from save_model import SaveLog
from models.vem_KF import VEM_KF_MODEL
from utils import tracking_evaluation_onebatch_KF
import motmetrics as mm
from utils import get_basic_info
def train(cfg_file):
# Read the config file
if not os.path.isfile(cfg_file):
raise ValueError('Invalid config file path')
cfg = ConfigParser()
cfg.read(cfg_file)
# Create save log directory
save_log = SaveLog(cfg)
save_dir = save_log.save_dir
# Save config file
save_cfg_path = os.path.join(save_dir, 'config.ini')
shutil.copy(cfg_file, save_cfg_path)
# Print basic information
use_cuda = cfg.getboolean('Training', 'use_cuda')
device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
basic_info = get_basic_info(device)
save_log.print_info(basic_info)
for info in basic_info:
print('%s' % info)
# Create and initialize model
vem_model = VEM_KF_MODEL(cfg, device, save_log)
# Load data
vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot')
# Print data information
data_info = []
data_info.append('========== DATA INFO ==========')
data_info.append('Training data: %s' % vem_data_size)
save_log.print_info(data_info)
for info in data_info:
print('%s' % info)
# Start training
print('Start training...')
total_iter = int(cfg.get('VEM', 'N_iter'))
save_frequency = int(cfg.get('Training', 'save_frequency'))
normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4)
acc_list = [[] for i in range(total_iter)]
for idx, data in enumerate(vem_data_loader):
print('batch {}\n'.format(idx))
data_obs = data['det'].to(device)
data_gt = data['gt'].to('cpu')
Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\
= vem_model.model_training(data_obs, data_gt, idx, save_frequency)
acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter)
summary_list = []
mota_list = [[] for i in range(total_iter)]
for iter_number in range(total_iter):
mh = mm.metrics.create()
name = ['sample_{}'.format(i) for i in range(vem_data_size)]
summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True)
mota_list[iter_number].append(summary.loc['OVERALL']['mota'])
strsummary = mm.io.render_summary(
summary,
formatters=mh.formatters,
namemap=mm.io.motchallenge_metric_names
)
summary_list.append(strsummary)
save_log.save_evaluation(summary_list, mota_list, total_iter)
if __name__ == '__main__':
if len(sys.argv) == 2:
cfg_file = sys.argv[1]
train(cfg_file)
else:
print('Error: Please indicate config file path')
\ No newline at end of file
## DVAE-UMOT
## Copyright Inria
## Year 2022
## Contact : xiaoyu.lin@inria.fr
## DVAE-UMOT is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
## DVAE-UMOT is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program, DVAE-UMOT. If not, see <http://www.gnu.org/licenses/> and the LICENSE file.
# DVAE-UMOT has code derived from
# (1) ArTIST, https://github.com/fatemeh-slh/ArTIST.
# (2) DVAE, https://github.com/XiaoyuBIE1994/DVAE, distributed under MIT License 2020 INRIA.
import sys
import os
import torch
import shutil
import numpy as np
from configparser import ConfigParser
from data.data_loader import create_dataloader
from save_model import SaveLog
from models.vem_KF import VEM_KF_MODEL
from utils import tracking_evaluation_onebatch_KF
import motmetrics as mm
from utils import get_basic_info
def train(cfg_file):
# Read the config file
if not os.path.isfile(cfg_file):
raise ValueError('Invalid config file path')
cfg = ConfigParser()
cfg.read(cfg_file)
# Create save log directory
save_log = SaveLog(cfg)
save_dir = save_log.save_dir
# Save config file
save_cfg_path = os.path.join(save_dir, 'config.ini')
shutil.copy(cfg_file, save_cfg_path)
# Print basic information
use_cuda = cfg.getboolean('Training', 'use_cuda')
device = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
basic_info = get_basic_info(device)
save_log.print_info(basic_info)
for info in basic_info:
print('%s' % info)
# Create and initialize model
vem_model = VEM_KF_MODEL(cfg, device, save_log)
# Load data
vem_data_loader, vem_data_size = create_dataloader(cfg, data_type='mot')
# Print data information
data_info = []
data_info.append('========== DATA INFO ==========')
data_info.append('Training data: %s' % vem_data_size)
save_log.print_info(data_info)
for info in data_info:
print('%s' % info)
# Start training
print('Start training...')
total_iter = int(cfg.get('VEM', 'N_iter'))
save_frequency = int(cfg.get('Training', 'save_frequency'))
normalize_range = np.array([int(i) for i in cfg.get('DataFrame', 'normalize_range').split(',')], dtype='float64').reshape(-1,4)
acc_list = [[] for i in range(total_iter)]
for idx, data in enumerate(vem_data_loader):
print('batch {}\n'.format(idx))
data_obs = data['det'].to(device)
data_gt = data['gt'].to('cpu')
Eta_iter, x_mean_vem_iter, x_var_vem_iter, Lambda_iter\
= vem_model.model_training(data_obs, data_gt, idx, save_frequency)
acc_list = tracking_evaluation_onebatch_KF(data_gt, normalize_range, acc_list, Eta_iter, x_mean_vem_iter)
summary_list = []
mota_list = [[] for i in range(total_iter)]
for iter_number in range(total_iter):
mh = mm.metrics.create()
name = ['sample_{}'.format(i) for i in range(vem_data_size)]
summary = mh.compute_many(acc_list[iter_number], metrics=mm.metrics.motchallenge_metrics, names=name, generate_overall=True)
mota_list[iter_number].append(summary.loc['OVERALL']['mota'])
strsummary = mm.io.render_summary(
summary,
formatters=mh.formatters,
namemap=mm.io.motchallenge_metric_names
)
summary_list.append(strsummary)
save_log.save_evaluation(summary_list, mota_list, total_iter)
if __name__ == '__main__':
if len(sys.argv) == 2:
cfg_file = sys.argv[1]
train(cfg_file)
else:
print('Error: Please indicate config file path')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment