Source code for

@author: Junguang Jiang
import bisect
import os
import warnings
from typing import Any, Callable, Iterable, List, Optional, Tuple

from torch.utils.data.dataset import Dataset, IterableDataset, T_co
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader

class ImageList(datasets.VisionDataset):
    """A generic Dataset class for image classification.

    Args:
        root (str): Root directory of dataset
        classes (list[str]): The names of all the classes
        data_list_file (str): File to read the image list from.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note::
        In `data_list_file`, each line has 2 values in the following format.
        ::
            source_dir/dog_xxx.png 0
            source_dir/cat_123.png 1
            target_dir/dog_xxy.png 0
            target_dir/cat_nsdf3.png 1

        The first value is the relative path of an image, and the second value is the label of the corresponding image.
        Relative paths are resolved against ``root``; blank lines are ignored.
        If your data_list_file has a different format, please override :meth:`~ImageList.parse_data_file`.
    """

    def __init__(self, root: str, classes: List[str], data_list_file: str, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Class metadata is set before parsing so that subclasses overriding
        # ``parse_data_file`` can use ``self.classes`` / ``self.class_to_idx``.
        self.classes = classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = self.parse_data_file(data_list_file)
        self.targets = [s[1] for s in self.samples]
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Load one sample.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
        """Parse file to data list.

        Each non-blank line is split at its last run of whitespace: everything before it
        is the image path (internal whitespace preserved) and the final token is the
        integer class index. Relative paths are joined with ``self.root``; absolute paths
        are kept as-is.

        Args:
            file_name (str): The path of data file

        Returns:
            list: List of (image path, class_index) tuples

        Raises:
            ValueError: If a non-blank line does not contain both a path and an integer label.
        """
        data_list = []
        with open(file_name, "r") as f:
            for line_no, line in enumerate(f, start=1):
                # strip() first: rsplit keeps leading whitespace on the left-hand remainder.
                parts = line.strip().rsplit(maxsplit=1)
                if not parts:
                    continue  # blank line, e.g. a trailing newline at end of file
                if len(parts) != 2:
                    raise ValueError(
                        f"{file_name}:{line_no}: expected '<path> <label>', got {line.strip()!r}")
                path, target = parts
                if not os.path.isabs(path):
                    path = os.path.join(self.root, path)
                data_list.append((path, int(target)))
        return data_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domains in this dataset. Not implemented in the base class."""
        # NotImplementedError (an exception), not NotImplemented (a singleton value,
        # which would itself raise TypeError when raised).
        raise NotImplementedError
class MultipleDomainsDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple domain datasets, tagging each sample with its domain id.

    This class is useful to assemble different existing datasets.

    Args:
        domains (sequence): Datasets to be concatenated (one per domain)
        domain_names (sequence): Names of the domains; stored as ``self.domain_names``
        domain_ids (sequence): ``domain_ids[i]`` is appended to every sample drawn from ``domains[i]``

    .. note::
        Each underlying dataset must return a tuple from ``__getitem__`` (e.g. ``(image, target)``),
        because the domain id is appended with tuple concatenation.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` for each element ``e`` of ``sequence``."""
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None:
        super().__init__()
        # Materialize first so that any iterable (not only sized sequences) is accepted.
        self.datasets = list(domains)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "MultipleDomainsDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)
        self.domain_names = domain_names
        self.domain_ids = domain_ids

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return ``sample + (domain_id,)`` for the global index ``idx`` (negative indices allowed)."""
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # cumulative_sizes[k] is the number of samples in datasets[0..k]; bisect_right
        # finds the first dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx] + (self.domain_ids[dataset_idx],)

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started for Transfer Learning Library

Get Started

Paper List

Get started for transfer learning

View Resources