Commit 1f99c588 authored by nviolante's avatar nviolante
Browse files

add dataset class

parent e9cc30fc
import os
from pathlib import Path
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor
from PIL import Image
import numpy as np
def is_image_ext(filename: str):
ext = str(filename).split('.')[-1].lower()
return f'.{ext}' in Image.EXTENSION
EgoNetTransform = Compose([ToTensor()])
class PathDataset(Dataset):
def __init__(self, source_dir, transform) -> None:
super().__init__()
self._transform = transform
self._image_paths = self._get_image_paths(source_dir)
self._image_shape = list(self[0][1].shape)
def _get_image_paths(self, source_dir):
Image.init()
paths = [str(f) for f in Path(source_dir).rglob('*') if is_image_ext(f) and os.path.isfile(f)]
if not len(paths) > 0:
raise ValueError(f"No images found in {source_dir}")
return paths
def __len__(self):
return len(self._image_paths)
def __getitem__(self, idx):
image_array = np.array(Image.open(self._image_paths[idx]))
image_tensor = self._transform(image_array)
return self._image_paths[idx], image_tensor
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment