Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 826f7f1a authored by PaulWawerek-L
Browse files

save always best model during validation

parent 1082bae0
No related branches found
No related tags found
1 merge request: !2 OSLO-IC
...@@ -327,20 +327,25 @@ def save_config(list_dicts, output_folder, filename=None): ...@@ -327,20 +327,25 @@ def save_config(list_dicts, output_folder, filename=None):
f.write(f"{key:_<40}: {str(val)}\n") # check this for all kinds of formatting f.write(f"{key:_<40}: {str(val)}\n") # check this for all kinds of formatting
f.write("="*60+"\n") f.write("="*60+"\n")
def save_checkpoint(state_dics, is_best, output_folder, filename=None): def save_checkpoint(state_dics, is_best, output_folder, filename=None, only_best_model=False):
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)
if filename: if not only_best_model:
if filename.endswith('.tar'): filename = filename.replace('.tar', '') if filename:
if filename.endswith('.pth'): filename = filename.replace('.pth', '') if filename.endswith('.tar'): filename = filename.replace('.tar', '')
else: if filename.endswith('.pth'): filename = filename.replace('.pth', '')
filename = f"checkpoint_{state_dics['epoch']+1:03d}" else:
output_fileAddr = os.path.join(output_folder, filename+'.pth.tar') filename = f"checkpoint_{state_dics['epoch']+1:04d}"
torch.save(state_dics, output_fileAddr)
print(f'saved checkpoint to {output_fileAddr}') output_fileAddr = os.path.join(output_folder, filename+'.pth.tar')
torch.save(state_dics, output_fileAddr)
print(f'saved checkpoint to {output_fileAddr}')
if is_best: if is_best:
fp_best = os.path.join(output_folder, 'checkpoint_best_loss.pth.tar') fp_best = os.path.join(output_folder, 'checkpoint_best_loss.pth.tar')
if only_best_model:
torch.save(state_dics, fp_best)
else:
shutil.copyfile(output_fileAddr, fp_best)
print(f'saved best model to {fp_best}') print(f'saved best model to {fp_best}')
shutil.copyfile(output_fileAddr, fp_best)
def get_model_name(string_name): def get_model_name(string_name):
for ch in [",", " ", "_", "-"]: # replace all with space for ch in [",", " ", "_", "-"]: # replace all with space
...@@ -585,9 +590,10 @@ def main(): ...@@ -585,9 +590,10 @@ def main():
if args.loss_file: if args.loss_file:
with open(os.path.join(args.out_dir, args.loss_file), 'a') as f: with open(os.path.join(args.out_dir, args.loss_file), 'a') as f:
f.write(loss_str + '\n') f.write(loss_str + '\n')
perform_validation = args.validation_data and ((epoch + 1) >= args.validation_start)
if args.validation_data: # validation data if args.validation_data: # validation data
if (epoch + 1) >= args.validation_start: if perform_validation:
saveVis = (((epoch+1) % args.interval_save_valtest == 0) or ((epoch+1) == args.max_epochs)) and (args.foldername_valtest is not None) saveVis = (((epoch+1) % args.interval_save_valtest == 0) or ((epoch+1) == args.max_epochs)) and (args.foldername_valtest is not None)
valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if saveVis else None valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if saveVis else None
loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch) loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch)
...@@ -604,10 +610,11 @@ def main(): ...@@ -604,10 +610,11 @@ def main():
if not args.no_scheduler: if not args.no_scheduler:
scheduler.step() # type: ignore scheduler.step() # type: ignore
scheduler_aux.step() # type: ignore scheduler_aux.step() # type: ignore
if (epoch+1) % args.checkpoint_interval == 0: save_regular_checkpoint = (epoch+1) % args.checkpoint_interval == 0
is_best = loss_validation["loss"] < best_loss if args.validation_data and ((epoch+1) >= args.validation_start) else False if save_regular_checkpoint or perform_validation:
best_loss = min(loss_validation["loss"], best_loss) if args.validation_data and ((epoch+1) >= args.validation_start) else best_loss is_best = loss_validation["loss"] < best_loss if perform_validation else False
best_loss = min(loss_validation["loss"], best_loss) if perform_validation else best_loss
states = { states = {
'epoch': epoch, 'epoch': epoch,
'net_state_dict': net.state_dict(), 'net_state_dict': net.state_dict(),
...@@ -619,7 +626,7 @@ def main(): ...@@ -619,7 +626,7 @@ def main():
if not args.no_scheduler: if not args.no_scheduler:
states["scheduler_state_dict"] = scheduler.state_dict() states["scheduler_state_dict"] = scheduler.state_dict()
states["scheduler_aux_state_dict"] = scheduler_aux.state_dict() states["scheduler_aux_state_dict"] = scheduler_aux.state_dict()
save_checkpoint(states, is_best, args.out_dir) save_checkpoint(states, is_best, args.out_dir, only_best_model=(not save_regular_checkpoint))
# Saving last results # Saving last results
is_best = loss_validation["loss"] < best_loss if args.validation_data else False is_best = loss_validation["loss"] < best_loss if args.validation_data else False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment