Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 30cfffc9 authored by PaulWawerek-L's avatar PaulWawerek-L
Browse files

print_best_model

parent 09706ed3
No related branches found
No related tags found
1 merge request!2OSLO-IC
...@@ -177,7 +177,7 @@ def compute_actual_bits(compressed_stream): ...@@ -177,7 +177,7 @@ def compute_actual_bits(compressed_stream):
total_bits_per_image = torch.sum(torch.stack(list_latent_bits, dim=0), dim=0).detach().cpu().long() total_bits_per_image = torch.sum(torch.stack(list_latent_bits, dim=0), dim=0).detach().cpu().long()
return total_bits_per_image return total_bits_per_image
def test_epoch(test_dataloader, struct_loader, model, criterion, patch_res=None, visFolder=None, print_freq=None, epoch = None, checkWithActualCompression = False): def test_epoch(test_dataloader, struct_loader, model, criterion, patch_res=None, visFolder=None, print_freq=None, epoch=None, checkWithActualCompression=False):
model.eval() model.eval()
device = next(model.parameters()).device device = next(model.parameters()).device
...@@ -337,7 +337,9 @@ def save_checkpoint(state_dics, is_best, output_folder, filename=None): ...@@ -337,7 +337,9 @@ def save_checkpoint(state_dics, is_best, output_folder, filename=None):
output_fileAddr += ".pth.tar" output_fileAddr += ".pth.tar"
torch.save(state_dics, output_fileAddr) torch.save(state_dics, output_fileAddr)
if is_best: if is_best:
shutil.copyfile(output_fileAddr, os.path.join(output_folder, 'checkpoint_best_loss.pth.tar')) fp_best = os.path.join(output_folder, 'checkpoint_best_loss.pth.tar')
print(f'saved best model to {fp_best}')
shutil.copyfile(output_fileAddr, fp_best)
def get_model_name(string_name): def get_model_name(string_name):
for ch in [",", " ", "_", "-"]: # replace all with space for ch in [",", " ", "_", "-"]: # replace all with space
...@@ -395,7 +397,7 @@ parser.add_argument('--clip_max_norm', type=float, default=1., help='gradient cl ...@@ -395,7 +397,7 @@ parser.add_argument('--clip_max_norm', type=float, default=1., help='gradient cl
# Network config # Network config
parser.add_argument('--max-epochs', '-e', type=int, default=1000, help='max epochs') parser.add_argument('--max-epochs', '-e', type=int, default=1000, help='max epochs')
parser.add_argument('--no-scheduler', '-ns', action='store_true', help='Disable scheduler') parser.add_argument('--no-scheduler', '-ns', action='store_true', help='Disable scheduler')
parser.add_argument('--scheduler-milestones', nargs='+', type=int, default=[30, 130], help=" List of epoch indices for scheduler to adjust learning rate.") parser.add_argument('--scheduler-milestones', nargs='+', type=int, default=[30, 130], help=" List of epoch indices for scheduler to adjust learning rate if no validation data is given.")
parser.add_argument('--scheduler-gamma', type=float, default=np.sqrt(0.1), help='Multiplicative factor of learning rate decay') parser.add_argument('--scheduler-gamma', type=float, default=np.sqrt(0.1), help='Multiplicative factor of learning rate decay')
parser.add_argument('--scheduler-patience', type=int, default=10, help='Patience in terms of number of epochs for scheduler ReduceLROnPlateau') parser.add_argument('--scheduler-patience', type=int, default=10, help='Patience in terms of number of epochs for scheduler ReduceLROnPlateau')
parser.add_argument('--gpu', '-g', action='store_true', help='enables cuda') parser.add_argument('--gpu', '-g', action='store_true', help='enables cuda')
...@@ -404,6 +406,7 @@ parser.add_argument('--conv', '-c', type=str, default='SDPAConv', help="Graph c ...@@ -404,6 +406,7 @@ parser.add_argument('--conv', '-c', type=str, default='SDPAConv', help="Graph c
parser.add_argument('--skip-connection-aggregation', '-sc', type=str, default='sum', help="Mode for jumping knowledge") parser.add_argument('--skip-connection-aggregation', '-sc', type=str, default='sum', help="Mode for jumping knowledge")
parser.add_argument('--pool-func', '-pf', type=str, default="stride", help="Pooling function.") parser.add_argument('--pool-func', '-pf', type=str, default="stride", help="Pooling function.")
parser.add_argument('--unpool-func', '-upf', type=str, default="pixel_shuffle", help="Unpooling function.") parser.add_argument('--unpool-func', '-upf', type=str, default="pixel_shuffle", help="Unpooling function.")
# TODO: parser.add_argument('--attention', '-at', action='store_true', help='use additional attention modules in the En- and Decoder')
# HealPix # HealPix
parser.add_argument("--use-4connectivity", action='store_true', help='use 4 neighboring for graph construction') parser.add_argument("--use-4connectivity", action='store_true', help='use 4 neighboring for graph construction')
parser.add_argument("--use-euclidean", action='store_true', help='Use geodesic distance for graph weights') parser.add_argument("--use-euclidean", action='store_true', help='Use geodesic distance for graph weights')
...@@ -560,7 +563,7 @@ def main(): ...@@ -560,7 +563,7 @@ def main():
print(f"Loss train: Epoch [{epoch:04n}/{args.max_epochs:04n}]:", ', '.join([f'{k}={v:6.4f}' for k, v in loss_train.items()]), flush=True) print(f"Loss train: Epoch [{epoch:04n}/{args.max_epochs:04n}]:", ', '.join([f'{k}={v:6.4f}' for k, v in loss_train.items()]), flush=True)
if args.validation_data: # validation data if args.validation_data: # validation data
saveVis = ((epoch % args.interval_save_valtest == 0) or (epoch == args.max_epochs-1)) and args.foldername_valtest is not None saveVis = ((epoch % args.interval_save_valtest == 0) or (epoch == args.max_epochs-1)) and (args.foldername_valtest is not None)
valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if saveVis else None valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if saveVis else None
loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch) loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch)
list_mean_losses_validation[epoch] = loss_validation list_mean_losses_validation[epoch] = loss_validation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment