Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 37a650f7 authored by PaulWawerek-L's avatar PaulWawerek-L
Browse files

separate best checkpoints

parent 8bcbf01a
No related branches found
No related tags found
1 merge request!2OSLO-IC
......@@ -330,7 +330,7 @@ def save_config(list_dicts, output_folder, filename=None):
f.write(f"{key:_<40}: {str(val)}\n") # check this for all kinds of formatting
f.write("="*60+"\n")
def save_checkpoint(state_dics, is_best, output_folder, filename=None, only_best_model=False):
def save_checkpoint(state_dics, is_best, output_folder, filename=None, only_best_model=False, best_checkpoint_milestones=None):
os.makedirs(output_folder, exist_ok=True)
if not only_best_model:
if filename:
......@@ -349,6 +349,13 @@ def save_checkpoint(state_dics, is_best, output_folder, filename=None, only_best
else:
shutil.copyfile(output_fileAddr, fp_best)
print(f'saved best model to {fp_best}')
if best_checkpoint_milestones:
for milestone in best_checkpoint_milestones:
if state_dics["epoch"] < milestone:
fp_milestone = os.path.join(output_folder, f'checkpoint_best_loss_{milestone:03d}.pth.tar')
shutil.copyfile(fp_best, fp_milestone)
print(f'saved best model to {fp_milestone}')
break
def get_model_name(string_name):
for ch in [",", " ", "_", "-"]: # replace all with space
......@@ -430,10 +437,15 @@ parser.add_argument("--sdpa-normalization", '-sn', type=str, default='non', help
# Weight Transfer
parser.add_argument("--weight-transfer", '-wt', action='store_true', help='Transfer weights from model pretrained on plain images')
parser.add_argument("--foldername-pretrained", type=str, default='./pretrained', help='Local folder to save pretrained model')
parser.add_argument('--save-best-milestones', nargs='*', type=int, default=[100, 200, 400], help="List of epoch indices to store best checkpoints separately (only for weight transfer).")
def main():
print('='*40+f' {datetime.now().strftime("%d-%m-%Y %H:%M:%S")} '+'='*40)
args = parser.parse_args()
if not args.weight_transfer:
args.save_best_milestones = None
if args.save_best_milestones:
args.save_best_milestones.sort()
print("=========== printing args ===========")
for key, val in args.__dict__.items():
......@@ -657,7 +669,7 @@ def main():
if not args.no_scheduler:
states["scheduler_state_dict"] = scheduler.state_dict()
states["scheduler_aux_state_dict"] = scheduler_aux.state_dict()
save_checkpoint(states, is_best, args.out_dir, only_best_model=(not save_regular_checkpoint))
save_checkpoint(states, is_best, args.out_dir, only_best_model=(not save_regular_checkpoint), best_checkpoint_milestones=args.save_best_milestones)
# Saving last results
is_best = loss_validation["loss"] < best_loss if perform_validation else False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment