Source code for

@author: Junguang Jiang
import itertools
import random
from typing import TypeVar, Iterable, Dict, List

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')

def send_to_device(tensor, device):
    """
    Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.

    Args:
        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
            The data to send to a given device.
        device (:obj:`torch.device`):
            The device to send the data to

    Returns:
        The same data structure as :obj:`tensor` with all tensors sent to the proper device.
        Leaves that have no ``to`` attribute (e.g. ints, strings) are returned unchanged.
    """
    if isinstance(tensor, (list, tuple)):
        # Rebuild the same container type from the converted elements.
        return type(tensor)(send_to_device(t, device) for t in tensor)
    elif isinstance(tensor, dict):
        return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()})
    elif not hasattr(tensor, "to"):
        # Not movable: pass the object through untouched.
        return tensor
    return tensor.to(device)
class ForeverDataIterator:
    r"""A data iterator that will never stop producing data.

    When the wrapped data loader is exhausted, a fresh iterator is created
    from it and iteration continues from its first element.

    Args:
        data_loader (DataLoader): the loader to cycle over. It is re-iterated
            with ``iter(data_loader)`` every time it runs out.
        device (optional): if not None, every produced batch is passed through
            ``send_to_device(batch, device)`` before being returned.
    """

    def __init__(self, data_loader: DataLoader, device=None):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)
        self.device = device

    def __iter__(self):
        # Required by the iterator protocol so that ``iter(obj)`` and ``for``
        # loops work; the object is its own iterator.
        return self

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # The loader is exhausted: restart it. If the loader is empty this
            # second ``next`` raises StopIteration, which is propagated.
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        if self.device is not None:
            data = send_to_device(data, self.device)
        return data

    def __len__(self):
        """Return the length of the wrapped data loader (one pass over it)."""
        return len(self.data_loader)
class RandomMultipleGallerySampler(Sampler):
    r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification (ICCV 2017)`.

    Assume there are :math:`N` identities in the dataset, this implementation simply samples
    :math:`K` images for every identity to form an iter of size :math:`N\times K`. During training,
    we will call ``__iter__`` method of pytorch dataloader once we reach a ``StopIteration``,
    this guarantees every image in the dataset will eventually be selected and we are not
    wasting any training data.

    For every identity one anchor image is drawn; the remaining :math:`K-1` images are drawn
    from images of the same identity taken by a *different* camera when such images exist,
    otherwise from the other images of that identity. An identity with a single image
    contributes only its anchor, so the number of yielded indices can be smaller than
    ``len(sampler)``.

    Args:
        dataset(list): each element of this list is a tuple (image_path, person_id, camera_id)
        num_instances(int, optional): number of images to sample for every identity
            (:math:`K` here)
    """

    def __init__(self, dataset, num_instances=4):
        # ``Sampler`` keeps no state of its own, so its ``__init__`` is intentionally not
        # called: forwarding ``dataset`` as ``data_source`` is deprecated in recent PyTorch.
        self.dataset = dataset
        self.num_instances = num_instances
        self.idx_to_pid = {}
        self.cid_list_per_pid = {}
        self.idx_list_per_pid = {}

        # Group dataset positions and camera ids by person id; the two per-pid lists are
        # kept aligned (same position <-> same image).
        for idx, (_, pid, cid) in enumerate(dataset):
            if pid not in self.cid_list_per_pid:
                self.cid_list_per_pid[pid] = []
                self.idx_list_per_pid[pid] = []
            self.idx_to_pid[idx] = pid
            self.cid_list_per_pid[pid].append(cid)
            self.idx_list_per_pid[pid].append(idx)

        self.pid_list = list(self.idx_list_per_pid.keys())
        self.num_samples = len(self.pid_list)

    def __len__(self):
        """Return ``N * K`` (an upper bound on the number of yielded indices)."""
        return self.num_samples * self.num_instances

    @staticmethod
    def _positions_differing_from(element_list, target_element):
        """Return the positions in ``element_list`` whose value differs from ``target_element``."""
        assert isinstance(element_list, list)
        return [pos for pos, element in enumerate(element_list) if element != target_element]

    def __iter__(self):
        num_extra = self.num_instances - 1
        final_idxes = []

        # Visit the identities in a random order.
        for perm_id in torch.randperm(len(self.pid_list)).tolist():
            pid = self.pid_list[perm_id]
            idx_list = self.idx_list_per_pid[pid]
            cid_list = self.cid_list_per_pid[pid]

            # Anchor image of this identity.
            anchor_idx = random.choice(idx_list)
            _, _, anchor_cid = self.dataset[anchor_idx]
            final_idxes.append(anchor_idx)

            # Prefer images of the same identity taken by another camera ...
            candidates = self._positions_differing_from(cid_list, anchor_cid)
            if not candidates:
                # ... otherwise fall back to any other image of the identity.
                candidates = self._positions_differing_from(idx_list, anchor_idx)
                if not candidates:
                    # Single-image identity: only the anchor is emitted.
                    continue

            # Sample with replacement only when there are too few candidates.
            with_replacement = len(candidates) < self.num_instances
            chosen = np.random.choice(candidates, size=num_extra, replace=with_replacement)
            for pos in chosen:
                final_idxes.append(idx_list[pos])

        return iter(final_idxes)
class CombineDataset(Dataset[T_co]):
    r"""Dataset as a combination of multiple datasets.

    The element of each dataset must be a list, and the i-th element of the combined dataset
    is a list splicing of the i-th element of each sub dataset.
    The length of the combined dataset is the minimum of the lengths of all sub datasets.

    Arguments:
        datasets (sequence): List of datasets to be concatenated. Any iterable is accepted;
            it is materialised into a list on construction.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(CombineDataset, self).__init__()
        # Materialise first: ``datasets`` is only promised to be iterable, so calling
        # ``len`` on it directly would fail for generators.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'

    def __len__(self):
        # Only indices valid in every sub dataset can be combined.
        return min(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Splice the idx-th item of every sub dataset into one flat list.
        return list(itertools.chain.from_iterable(d[idx] for d in self.datasets))
def concatenate(tensors):
    """concatenate multiple batches into one batch.

    ``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must be the same data format.

    Args:
        tensors (list): non-empty list of batches. The type of the first batch decides how
            all batches are merged: tensors are concatenated along dimension 0, lists are
            merged position by position, dicts are merged key by key (using the keys of the
            first batch).

    Returns:
        A single batch with the same structure as ``tensors[0]``. Leaves of any other type
        (e.g. strings) are not merged and yield ``None``.
    """
    first = tensors[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(tensors, dim=0)
    elif isinstance(first, list):
        return [concatenate([t[i] for t in tensors]) for i in range(len(first))]
    elif isinstance(first, dict):
        return {k: concatenate([t[k] for t in tensors]) for k in first.keys()}
    # Unsupported leaf type: kept as the original fall-through result.
    return None


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started for Transfer Learning Library

Get Started

Paper List

Get started for transfer learning

View Resources