Commit 7b7f5f81 authored by VIOLANTE Nicolas's avatar VIOLANTE Nicolas
Browse files

Merge branch 'dataset' into 'master'

add dataset class

See merge request !3
parents e9cc30fc 1f99c588
import os
from pathlib import Path
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor
from PIL import Image
import numpy as np
def is_image_ext(filename: str):
ext = str(filename).split('.')[-1].lower()
return f'.{ext}' in Image.EXTENSION
EgoNetTransform = Compose([ToTensor()])
class PathDataset(Dataset):
def __init__(self, source_dir, transform) -> None:
super().__init__()
self._transform = transform
self._image_paths = self._get_image_paths(source_dir)
self._image_shape = list(self[0][1].shape)
def _get_image_paths(self, source_dir):
Image.init()
paths = [str(f) for f in Path(source_dir).rglob('*') if is_image_ext(f) and os.path.isfile(f)]
if not len(paths) > 0:
raise ValueError(f"No images found in {source_dir}")
return paths
def __len__(self):
return len(self._image_paths)
def __getitem__(self, idx):
image_array = np.array(Image.open(self._image_paths[idx]))
image_tensor = self._transform(image_array)
return self._image_paths[idx], image_tensor
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment