Source code for the partial domain adaptation datasets package

@author: Junguang Jiang
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from ..officecaltech import OfficeCaltech
from .imagenet_caltech import ImageNetCaltech
from .caltech_imagenet import CaltechImageNet
from import ImageNetCaltech
from typing import Sequence, ClassVar

# Public names re-exported by ``from <this package> import *``.
# ``partial`` and ``default_partial`` are not listed here and must be
# imported explicitly by name.
__all__ = [
    "Office31",
    "OfficeHome",
    "VisDA2017",
    "CaltechImageNet",
    "ImageNetCaltech",
]

def partial(dataset_class: type, partial_classes: Sequence[str]) -> type:
    """Convert a dataset class into its partial version.

    Samples whose category does not belong to ``partial_classes`` are
    discarded. The label space of ``dataset_class`` is NOT changed: kept
    samples retain their original label indices and ``classes`` still lists
    every category.

    Args:
        dataset_class (class): Dataset class. Only subclasses of ``ImageList``
            can be made partial.
        partial_classes (sequence[str]): Categories to keep in the partial
            dataset. Each element must belong to the ``classes`` list of
            ``dataset_class``.

    Returns:
        A subclass of ``dataset_class`` whose instances only contain samples
        from ``partial_classes``. Instances additionally expose
        ``partial_classes`` (the sequence passed here) and
        ``partial_classes_idx`` (the label index of each kept class, in the
        order of ``partial_classes``).

    Raises:
        TypeError: If ``dataset_class`` is not a subclass of ``ImageList``.
        ValueError: When the returned class is instantiated, if some element
            of ``partial_classes`` is not in the dataset's ``classes``.

    Examples::

        >>> partial_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> # create a partial dataset class
        >>> PartialOffice31 = partial(Office31, partial_classes)
        >>> # create an instance of the partial dataset
        >>> dataset = PartialOffice31(root="data/office31", task="A")
    """
    # TypeError subclasses Exception, so callers that caught the previous
    # bare ``Exception`` keep working.
    if not issubclass(dataset_class, ImageList):
        raise TypeError("Only subclass of ImageList can be partial")

    class PartialDataset(dataset_class):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # ``classes`` and ``samples`` only exist once the base dataset is
            # built, so validation and filtering must happen after super().
            # An explicit check (unlike ``assert``) is not stripped by ``-O``.
            unknown = [c for c in partial_classes if c not in self.classes]
            if unknown:
                raise ValueError(
                    "partial_classes contains categories that are not in the "
                    "dataset: {}".format(unknown))
            # Build the lookup set once; the sample list may be large.
            kept = set(partial_classes)
            self.samples = [(path, label) for path, label in self.samples
                            if self.classes[label] in kept]
            # NOTE(review): only ``samples`` is filtered. If the base class keeps
            # other per-sample attributes derived from it (e.g. a ``targets``
            # list), they are left unfiltered here. Confirm against ImageList.
            self.partial_classes = partial_classes
            self.partial_classes_idx = [self.class_to_idx[c] for c in partial_classes]

    return PartialDataset
def default_partial(dataset_class: ClassVar) -> ClassVar:
    """Return the partial version of a dataset using its conventional class subset.

    Args:
        dataset_class (class): Dataset class. Currently it must be one of
            :class:`Office31`, :class:`OfficeHome`, :class:`VisDA2017`,
            :class:`ImageNetCaltech` and :class:`CaltechImageNet`.

    Returns:
        The partial dataset class built by :func:`partial`.

    Raises:
        NotImplementedError: If ``dataset_class`` is not one of the datasets
            listed above.
    """
    if dataset_class == Office31:
        # Office-31 keeps the categories listed in OfficeCaltech.CLASSES.
        return partial(dataset_class, OfficeCaltech.CLASSES)
    if dataset_class == OfficeHome:
        # The 25 lexicographically smallest category names.
        return partial(dataset_class, sorted(OfficeHome.CLASSES)[:25])
    if dataset_class == VisDA2017:
        # The 6 lexicographically smallest category names.
        return partial(dataset_class, sorted(VisDA2017.CLASSES)[:6])
    if dataset_class in (ImageNetCaltech, CaltechImageNet):
        # These datasets keep their own full CLASSES list.
        return partial(dataset_class, dataset_class.CLASSES)
    raise NotImplementedError(
        "Unknown partial domain adaptation dataset: {}".format(dataset_class.__name__))


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started with Transfer Learning Library

Get Started

Paper List

Get started with transfer learning

View Resources