# Source code for tllib.vision.datasets.imagelist
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import warnings
from typing import Optional, Callable, Tuple, Any, List, Iterable
import bisect
from torch.utils.data.dataset import Dataset, T_co, IterableDataset
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
class ImageList(datasets.VisionDataset):
    """A generic Dataset class for image classification.

    Args:
        root (str): Root directory of dataset
        classes (list[str]): The names of all the classes
        data_list_file (str): File to read the image list from.
        transform (callable, optional): A function/transform that takes in an PIL image \
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `data_list_file`, each line has 2 values in the following format.
        ::
            source_dir/dog_xxx.png 0
            source_dir/cat_123.png 1
            target_dir/dog_xxy.png 0
            target_dir/cat_nsdf3.png 1

        The first value is the relative path of an image, and the second value is the label
        of the corresponding image. If your data_list_file has different formats, please
        over-ride :meth:`~ImageList.parse_data_file`.
    """

    def __init__(self, root: str, classes: List[str], data_list_file: str,
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.samples = self.parse_data_file(data_list_file)
        self.targets = [s[1] for s in self.samples]
        self.classes = classes
        self.class_to_idx = {cls: idx
                             for idx, cls in enumerate(self.classes)}
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the index of the target class.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # Unlabeled samples may carry target=None; skip the target transform then.
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
        """Parse a data-list file into a list of samples.

        Args:
            file_name (str): The path of the data file

        Returns:
            list: List of (image path, class_index) tuples
        """
        data_list = []
        with open(file_name, "r") as f:
            # Iterate the file object directly rather than readlines():
            # same lines, without materializing the whole file in memory.
            for line in f:
                split_line = line.split()
                if not split_line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # crashing with IndexError on split_line[-1].
                    continue
                # The label is the last whitespace-separated token; everything
                # before it is the path (which may itself contain spaces).
                target = int(split_line[-1])
                path = ' '.join(split_line[:-1])
                if not os.path.isabs(path):
                    path = os.path.join(self.root, path)
                data_list.append((path, target))
        return data_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domains in this dataset"""
        # Fix: the original `raise NotImplemented` raises a TypeError in
        # Python 3 because NotImplemented is not an exception; subclasses
        # are expected to override this method.
        raise NotImplementedError
class MultipleDomainsDataset(Dataset[T_co]):
    r"""Dataset formed by concatenating several per-domain datasets end to end.

    Each returned sample is the underlying dataset's sample tuple with the
    domain id of its source dataset appended.

    Args:
        domains (sequence): the datasets (one per domain) to concatenate
        domain_names (sequence): a name for each domain
        domain_ids (sequence): an id appended to every sample of that domain
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals of the dataset lengths, e.g. sizes [3, 1, 2] -> [3, 4, 6].
        totals = []
        running = 0
        for dataset in sequence:
            running += len(dataset)
            totals.append(running)
        return totals

    def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None:
        super(MultipleDomainsDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(domains) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        self.datasets = self.domains = list(domains)
        for domain in self.domains:
            assert not isinstance(domain, IterableDataset), "MultipleDomainsDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.domains)
        self.domain_names = domain_names
        self.domain_ids = domain_ids

    def __len__(self):
        # The last running total is the combined size of all domains.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Normalize negative indices to the equivalent non-negative ones.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)
        # Binary-search which domain the flat index lands in, then compute
        # the offset of the sample within that domain.
        domain_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = idx if domain_idx == 0 else idx - self.cumulative_sizes[domain_idx - 1]
        return self.domains[domain_idx][sample_idx] + (self.domain_ids[domain_idx],)

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes