Mentions légales du service

Skip to content
Snippets Groups Projects
Commit cf05fa76 authored by Paul Wawerek-López's avatar Paul Wawerek-López
Browse files

optional temp_dir for cluster

parent 81a9eedd
No related branches found
No related tags found
1 merge request!2OSLO-IC
......@@ -7,13 +7,15 @@ import healpy as hp
import random
class HealpixDataset(Dataset):
def __init__(self, file_address, transform=None):
def __init__(self, file_address, transform=None, tmp_dir:str=''):
with open(file_address, 'r') as f:
# lines = f.readlines()
lines = f.read().splitlines()
#This line is a comment. The secong line after this comment is the healpix resolution used for sampling. The third line is the root directory where the files are located. Then, all file names are listed
self.resolution = int(lines[1])
self.root_dir = lines[2]
if tmp_dir:
self.root_dir = os.path.join(tmp_dir, os.path.basename(self.root_dir))
self.list_filenames = lines[3:]
nside = hp.order2nside(self.resolution) # == 2 ** sampling_resolution
self.nPix = hp.nside2npix(nside)
......
......@@ -407,6 +407,7 @@ parser.add_argument('--seed', type=int, help='Set random seed for reproducibilit
parser.add_argument('--train-data', '-i', type=str, default='./train.txt', help="A text file containing location of train images.")
parser.add_argument('--test-data', '-t', type=str, help="A text file containing location of test images. This represents test mode (no training will be done)")
parser.add_argument('--validation-data', '-v', type=str, help="A text file containing location of validation images.")
parser.add_argument('--tmp-dir', '-tmp', type=str, default='', help="A temporary path where images are saved.")
parser.add_argument('--batch-size-train', '-bst', type=int, default=10, help='batch size train dataset')
parser.add_argument('--batch-size-valtest', '-bsvt', type=int, default=10, help='batch size validation and test datasets')
parser.add_argument('--dataloader-num-workers', type=int, default=0, help='multi-process data loading with the specified number of loader worker processes')
......@@ -613,7 +614,7 @@ def main():
if args.test_data: # test phase. Note: this should be after args.checkpoint_file to load the latest model
if not args.checkpoint_file: net.update(force=True) # update the model CDFs parameters.
transform = torchvision.transforms.Compose([dataset.ToTensor(to_range_minusOne_to_plusOne=False)])
test_dataset = dataset.HealpixDataset(args.test_data, transform=transform)
test_dataset = dataset.HealpixDataset(args.test_data, transform=transform, tmp_dir=args.tmp_dir)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size_valtest, shuffle=False, num_workers=args.dataloader_num_workers)
assert test_dataset.resolution == args.healpix_res, "resolution of test dataset doesn't match with input dataset"
valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if args.foldername_valtest is not None else None
......@@ -626,7 +627,7 @@ def main():
return
transform = torchvision.transforms.Compose([dataset.ToTensor(to_range_minusOne_to_plusOne=False)])
train_dataset = dataset.HealpixDataset(args.train_data, transform=transform)
train_dataset = dataset.HealpixDataset(args.train_data, transform=transform, tmp_dir=args.tmp_dir)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.dataloader_num_workers)
assert train_dataset.resolution == args.healpix_res, "resolution of train dataset doesn't match with input dataset"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment