I want to build an (abstract) Dataset which is generic and just provides the framework for loading files.
The subclasses are then for specific types (here bother the samples and the annotations are of type np.ndarray
When I instantiate an object of type ImageDataset
I get
File "/home/maximilian/darts/tests/test_dataset.py", line 12, in test_simple_loading
dataset = ImageDataset(dataset_root)
File "/home/maximilian/darts/darts/datasets.py", line 72, in __init__
super(ImageDataset, self).__init__(path)
File "/home/maximilian/darts/darts/datasets.py", line 19, in __init__
self.__load_data(path)
File "/home/maximilian/darts/darts/datasets.py", line 48, in __load_data
self.annotations.update(annotations)
AttributeError: 'ImageDataset' object has no attribute 'annotations'
Can anybody tell me please what I am doing wrong here?
from collections import defaultdict
from abc import abstractmethod
from itertools import tee
from pathlib import Path
from typing import Iterator, TypeVar, Tuple, Dict, Mapping
import numpy as np
from cv2 import haveImageReader, imread
Key = str
Annotation = TypeVar('Annotation')
Sample = TypeVar('Sample')
AnnotatedSample = Tuple[Sample, Annotation]
class Dataset(Mapping[Key, AnnotatedSample]):
def __init__(self, path: Path):
self.__path = path
self.__load_data(path)
self.annotations: Dict[str, Annotation] = defaultdict(lambda: None)
self.samples: Dict[str, Sample] = defaultdict(lambda: None)
@abstractmethod
def _is_sample_file(self, file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _is_annotation_file(self, file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _load_annotation(self, file: Path) -> Annotation:
raise NotImplementedError()
@abstractmethod
def _load_sample(self, file: Path) -> Sample:
raise NotImplementedError()
def __load_data(self, path: Path):
files = filter(lambda file: not file.is_dir(), path.glob('*'))
it1, it2 = tee(files)
annotations_files = filter(lambda file: self._is_annotation_file(file), it1)
sample_files = filter(lambda file: self._is_sample_file(file), it2)
annotations = map(lambda file: (file.stem, self._load_annotation(file)), annotations_files)
samples = map(lambda file: (file.stem, self._load_sample(file)), sample_files)
self.annotations.update(annotations)
self.samples.update(samples)
annotation_keys = set(self.annotations)
samples_keys = set(self.samples)
annotations_without_sample = annotation_keys.difference(samples_keys)
if annotations_without_sample:
raise ValueError(
f"For each annotation a sample file must be given. Annotation without sample {annotations_without_sample} ")
def __getitem__(self, k: Key) -> AnnotatedSample:
return self.samples[k], self.annotations
def __len__(self) -> int:
return len(self.samples)
def __iter__(self) -> Iterator[Key]:
return self.samples.keys()
class ImageDataset(Dataset[np.ndarray, np.ndarray]):
ANNOTATION_EXTENSIONS = ['.npy']
def __init__(self, path : Path, transformations = []):
super(ImageDataset, self).__init__(path)
self.__transformations = transformations
def _is_annotation_file(self, file: Path) -> bool:
return haveImageReader(str(file))
def _is_sample_file(self, file: Path) -> bool:
return file.stem in ImageDataset.ANNOTATION_EXTENSIONS
def _load_annotation(self, file: Path) -> Annotation:
return np.load(str(file))
def _load_sample(self, file: Path) -> Sample:
image = imread(str(file))
for transform in self.__transformations:
image = transform(image)
return image
question from:
https://stackoverflow.com/questions/65860223/how-to-use-mappings-generic-in-python