Source code for

@author: Junguang Jiang
from typing import Optional, Sequence
import os
from .._util import download as download_data, check_exits
from .image_regression import ImageRegression

class DSprites(ImageRegression):
    """DSprites dataset for image regression.

    Args:
        root (str): Root directory of dataset.
        task (str): The task (domain) to create dataset. Choices include
            ``'C'``: Color, ``'N'``: Noisy and ``'S'``: Scream.
        split (str, optional): The dataset split, supports ``train`` or ``test``.
        factors (sequence[str]): Factors selected; each must be one of
            :attr:`FACTORS`. Default: ('scale', 'position x', 'position y').
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. If false, only checks that
            the expected entries exist under ``root``.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, :class:`torchvision.transforms.RandomCrop`.
            Forwarded to the base class through ``**kwargs``.
        target_transform (callable, optional): A function/transform that takes
            in the target (already restricted to the selected factors) and
            transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            color/
                ...
            noisy/
            scream/
            image_list/
                color_train.txt
                noisy_train.txt
                scream_train.txt
                color_test.txt
                noisy_test.txt
                scream_test.txt
    """
    # Each entry is unpacked positionally into ``download_data(root, *entry)``.
    # NOTE(review): the archive name of "image_list" and every URL are empty
    # strings here -- they look stripped from this copy of the file; confirm
    # against the upstream repository before relying on ``download=True``.
    download_list = [
        ("image_list", "", ""),
        ("color", "color.tgz", ""),
        ("noisy", "noisy.tgz", ""),
        ("scream", "scream.tgz", ""),
    ]
    # Maps the one-letter task code to the prefix of its image-list file.
    image_list = {
        "C": "color",
        "N": "noisy",
        "S": "scream"
    }
    # Position in this tuple is the index used to pick a factor out of a target.
    FACTORS = ('none', 'shape', 'scale', 'orientation', 'position x', 'position y')

    def __init__(self, root: str, task: str, split: Optional[str] = 'train',
                 factors: Sequence[str] = ('scale', 'position x', 'position y'),
                 download: Optional[bool] = True, target_transform=None, **kwargs):
        assert task in self.image_list, \
            "task must be one of {}, got {!r}".format(list(self.image_list), task)
        assert split in ['train', 'test'], \
            "split must be 'train' or 'test', got {!r}".format(split)
        for factor in factors:
            assert factor in self.FACTORS, \
                "unknown factor {!r}; choices are {}".format(factor, self.FACTORS)

        factor_index = [self.FACTORS.index(factor) for factor in factors]

        # Keep the caller's transform under its own name. Rebinding
        # ``target_transform`` to a lambda that itself calls
        # ``target_transform`` would make the lambda call itself forever
        # (closures look names up at call time, not at definition time).
        user_transform = target_transform

        def select_factors(target):
            # Indexing with a list of positions; presumably ``target`` is a
            # numpy array or tensor that supports this -- verify against the
            # base class.
            selected = target[factor_index]
            if user_transform is None:
                return selected
            return user_transform(selected)

        data_list_file = os.path.join(
            root, "image_list", "{}_{}.txt".format(self.image_list[task], split))

        if download:
            for entry in self.download_list:
                download_data(root, *entry)
        else:
            # Entries are 3-tuples; only the first element is needed here.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        super(DSprites, self).__init__(root, factors, data_list_file=data_list_file,
                                       target_transform=select_factors, **kwargs)


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started for Transfer Learning Library

Get Started

Paper List

Get started for transfer learning

View Resources