
Domain Adversarial Training

DANN: Domain Adversarial Neural Network

class tllib.alignment.dann.DomainAdversarialLoss(domain_discriminator, reduction='mean', grl=None, sigmoid=True)[source]

The Domain Adversarial Loss proposed in Domain-Adversarial Training of Neural Networks (ICML 2015)

Domain adversarial loss measures the domain discrepancy through training a domain discriminator. Given domain discriminator \(D\), feature representation \(f\), the definition of DANN loss is

\[loss(\mathcal{D}_s, \mathcal{D}_t) = \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(f_i^s)] + \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(f_j^t)].\]
Parameters
  • domain_discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'

  • grl (WarmStartGradientReverseLayer, optional) – gradient reverse layer applied before the domain discriminator. If None, a default WarmStartGradientReverseLayer is constructed. Default: None.

Inputs:
  • f_s (tensor): feature representations on source domain, \(f^s\)

  • f_t (tensor): feature representations on target domain, \(f^t\)

  • w_s (tensor, optional): a rescaling weight given to each instance from source domain.

  • w_t (tensor, optional): a rescaling weight given to each instance from target domain.

Shape:
  • f_s, f_t: \((N, F)\) where F means the dimension of input features.

  • Outputs: scalar by default. If reduction is 'none', then \((N, )\).

Examples:

>>> from tllib.modules.domain_discriminator import DomainDiscriminator
>>> from tllib.alignment.dann import DomainAdversarialLoss
>>> import torch
>>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
>>> loss = DomainAdversarialLoss(discriminator, reduction='mean')
>>> # features from the source domain and the target domain
>>> f_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)
>>> # to assign a different weight to each instance, pass in w_s and w_t
>>> w_s, w_t = torch.randn(20), torch.randn(20)
>>> output = loss(f_s, f_t, w_s, w_t)

CDAN: Conditional Domain Adversarial Network

class tllib.alignment.cdan.ConditionalDomainAdversarialLoss(domain_discriminator, entropy_conditioning=False, randomized=False, num_classes=-1, features_dim=-1, randomized_dim=1024, reduction='mean', sigmoid=True)[source]

The Conditional Domain Adversarial Loss used in Conditional Adversarial Domain Adaptation (NIPS 2018)

The conditional domain adversarial loss measures the domain discrepancy by training a domain discriminator in a conditional manner. Given domain discriminator \(D\), feature representation \(f\) and classifier predictions \(g\), the definition of CDAN loss is

\[\begin{split}loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\ &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\\end{split}\]

where \(T\) is a MultiLinearMap or RandomizedMultiLinearMap, which converts two tensors into a single tensor.

Parameters
  • domain_discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)

  • entropy_conditioning (bool, optional) – If True, use entropy-aware weight to reweight each training example. Default: False

  • randomized (bool, optional) – If True, use RandomizedMultiLinearMap. Otherwise, use MultiLinearMap. Default: False

  • num_classes (int, optional) – Number of classes. Default: -1

  • features_dim (int, optional) – Dimension of input features. Default: -1

  • randomized_dim (int, optional) – Dimension of features after randomized. Default: 1024

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'

Note

You need to provide num_classes, features_dim and randomized_dim only when randomized is set to True.

Inputs:
  • g_s (tensor): unnormalized classifier predictions on source domain, \(g^s\)

  • f_s (tensor): feature representations on source domain, \(f^s\)

  • g_t (tensor): unnormalized classifier predictions on target domain, \(g^t\)

  • f_t (tensor): feature representations on target domain, \(f^t\)

Shape:
  • g_s, g_t: \((minibatch, C)\) where C means the number of classes.

  • f_s, f_t: \((minibatch, F)\) where F means the dimension of input features.

  • Output: scalar by default. If reduction is 'none', then \((minibatch, )\).

Examples:

>>> from tllib.modules.domain_discriminator import DomainDiscriminator
>>> from tllib.alignment.cdan import ConditionalDomainAdversarialLoss
>>> import torch
>>> num_classes = 2
>>> feature_dim = 1024
>>> batch_size = 10
>>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
>>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # logits output from source domain and target domain
>>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
>>> output = loss(g_s, f_s, g_t, f_t)
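
When randomized is set to True, the domain discriminator operates on the randomized_dim-dimensional mapped features instead of the \(F \times C\)-dimensional multilinear map. A hedged sketch continuing the snippet above (the variable names discriminator_r and loss_r are illustrative):

>>> # randomized variant: the discriminator input dimension equals randomized_dim
>>> randomized_dim = 1024
>>> discriminator_r = DomainDiscriminator(in_feature=randomized_dim, hidden_size=1024)
>>> loss_r = ConditionalDomainAdversarialLoss(discriminator_r, randomized=True,
...                                           num_classes=num_classes, features_dim=feature_dim,
...                                           randomized_dim=randomized_dim, reduction='mean')
>>> output_r = loss_r(g_s, f_s, g_t, f_t)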
class tllib.alignment.cdan.RandomizedMultiLinearMap(features_dim, num_classes, output_dim=1024)[source]

Randomized multilinear map

Given two inputs \(f\) and \(g\), the definition is

\[T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),\]

where \(\odot\) is the element-wise product, and \(R_f\) and \(R_g\) are random matrices sampled once and fixed during training.

Parameters
  • features_dim (int) – dimension of input \(f\)

  • num_classes (int) – dimension of input \(g\)

  • output_dim (int, optional) – dimension of output tensor. Default: 1024

Shape:
  • f: (minibatch, features_dim)

  • g: (minibatch, num_classes)

  • Outputs: (minibatch, output_dim)
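
Example (a minimal sketch; the feature and class dimensions below are illustrative assumptions):

>>> import torch
>>> from tllib.alignment.cdan import RandomizedMultiLinearMap
>>> # f: features, g: classifier predictions
>>> f, g = torch.randn(10, 1024), torch.randn(10, 31)
>>> T = RandomizedMultiLinearMap(features_dim=1024, num_classes=31, output_dim=1024)
>>> h = T(f, g)  # shape (10, 1024)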

class tllib.alignment.cdan.MultiLinearMap[source]

Multilinear map, which computes the flattened outer product of its two inputs.

Shape:
  • f: (minibatch, F)

  • g: (minibatch, C)

  • Outputs: (minibatch, F * C)
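
Example (a minimal sketch with illustrative dimensions):

>>> import torch
>>> from tllib.alignment.cdan import MultiLinearMap
>>> f, g = torch.randn(10, 1024), torch.randn(10, 31)
>>> T = MultiLinearMap()
>>> h = T(f, g)  # flattened outer product, shape (10, 1024 * 31)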

ADDA: Adversarial Discriminative Domain Adaptation

class tllib.alignment.adda.DomainAdversarialLoss[source]

Domain adversarial loss from Adversarial Discriminative Domain Adaptation (CVPR 2017). Similar to the original GAN paper, ADDA argues that replacing \(\text{log}(1-p)\) with \(-\text{log}(p)\) in the adversarial loss provides better gradients. See the ADDA paper for the detailed optimization procedure.

Inputs:
  • domain_pred (tensor): predictions of domain discriminator

  • domain_label (str, optional): whether the data comes from source or target. Must be ‘source’ or ‘target’. Default: ‘source’

Shape:
  • domain_pred: \((minibatch,)\).

  • Outputs: scalar.

Note

A gradient-reverse-layer variant, ADDAgrl, is also implemented and benchmarked in the library.
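
Example (a hedged sketch following the Inputs/Shape above; torch.rand stands in for the probabilities produced by a domain discriminator):

>>> import torch
>>> from tllib.alignment.adda import DomainAdversarialLoss
>>> adv_loss = DomainAdversarialLoss()
>>> # probabilities predicted by the domain discriminator on a mini-batch
>>> d_s, d_t = torch.rand(32), torch.rand(32)
>>> loss = adv_loss(d_s, 'source') + adv_loss(d_t, 'target')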

BSP: Batch Spectral Penalization

class tllib.alignment.bsp.BatchSpectralPenalizationLoss[source]

Batch spectral penalization loss from Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation (ICML 2019).

Given source features \(f_s\) and target features \(f_t\) in the current mini-batch, singular value decomposition is first performed

\[f_s = U_s\Sigma_sV_s^T\]
\[f_t = U_t\Sigma_tV_t^T\]

Then batch spectral penalization loss is calculated as

\[loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2)\]

where \(\sigma_{s,i}\) and \(\sigma_{t,i}\) refer to the \(i\)-th largest singular values of the source features and target features, respectively. We empirically set \(k=1\).

Inputs:
  • f_s (tensor): feature representations on source domain, \(f^s\)

  • f_t (tensor): feature representations on target domain, \(f^t\)

Shape:
  • f_s, f_t: \((N, F)\) where F means the dimension of input features.

  • Outputs: scalar.
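
Example (a minimal sketch; the batch size and feature dimension are illustrative):

>>> import torch
>>> from tllib.alignment.bsp import BatchSpectralPenalizationLoss
>>> bsp = BatchSpectralPenalizationLoss()
>>> f_s, f_t = torch.randn(32, 256), torch.randn(32, 256)
>>> loss = bsp(f_s, f_t)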

OSBP: Open Set Domain Adaptation by Backpropagation

class tllib.alignment.osbp.UnknownClassBinaryCrossEntropy(t=0.5)[source]

Binary cross entropy loss to make a boundary for unknown samples, proposed by Open Set Domain Adaptation by Backpropagation (ECCV 2018).

Given a target-domain sample \(x_t\) and its classification outputs \(y\), the binary cross entropy loss is defined as

\[L_{\text{adv}}(x_t) = -t \text{log}(p(y=C+1|x_t)) - (1-t)\text{log}(1-p(y=C+1|x_t))\]

where \(t\) is a predefined hyper-parameter and \(C\) is the number of known classes.

Parameters

t (float) – Predefined hyper-parameter. Default: 0.5

Inputs:
  • y (tensor): classification outputs (before softmax).

Shape:
  • y: \((minibatch, C+1)\) where C is the number of known classes.

  • Outputs: scalar
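
Example (a minimal sketch; the number of known classes is an illustrative assumption):

>>> import torch
>>> from tllib.alignment.osbp import UnknownClassBinaryCrossEntropy
>>> num_known_classes = 10
>>> adv_loss = UnknownClassBinaryCrossEntropy(t=0.5)
>>> # classifier outputs before softmax; the last column corresponds to the unknown class C+1
>>> y = torch.randn(32, num_known_classes + 1)
>>> loss = adv_loss(y)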

ADVENT: Adversarial Entropy Minimization for Semantic Segmentation

class tllib.alignment.advent.Discriminator(num_classes, ndf=64)[source]

Domain discriminator model from ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation (CVPR 2019)

Distinguishes, pixel by pixel, whether the input predictions come from the source domain or the target domain. The source-domain label is 1 and the target-domain label is 0.

Parameters
  • num_classes (int) – num of classes in the predictions

  • ndf (int) – dimension of the hidden features

Shape:
  • Inputs: \((minibatch, C, H, W)\) where \(C\) is the number of classes

  • Outputs: \((minibatch, 1, H, W)\)

class tllib.alignment.advent.DomainAdversarialEntropyLoss(discriminator)[source]

The Domain Adversarial Entropy Loss

Minimizes entropy with adversarial learning by training a domain discriminator.

Parameters

discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of predictions. Its input shape is \((minibatch, C, H, W)\) and output shape is \((minibatch, 1, H, W)\)

Inputs:
  • logits (tensor): logits output of segmentation model

  • domain_label (str, optional): whether the data comes from source or target. Choices: [‘source’, ‘target’]. Default: ‘source’

Shape:
  • logits: \((minibatch, C, H, W)\) where \(C\) means the number of classes

  • Outputs: scalar.

Examples:

>>> from tllib.alignment.advent import Discriminator, DomainAdversarialEntropyLoss
>>> import torch
>>> B, C, H, W = 2, 19, 512, 512
>>> discriminator = Discriminator(num_classes=C)
>>> dann = DomainAdversarialEntropyLoss(discriminator)
>>> # logits output on source domain and target domain
>>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
>>> loss = 0.5 * (dann(y_s, "source") + dann(y_t, "target"))
eval()[source]

Sets the module in evaluation mode. In evaluation mode, all the parameters in the discriminator are set to requires_grad=False.

This is equivalent to self.train(False).

forward(logits, domain_label='source')[source]
train(mode=True)[source]

Sets the discriminator in training mode. In training mode, all the parameters in the discriminator are set to requires_grad=True.

Parameters

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
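
Example (a hedged sketch of the alternating optimization these modes support, following the usual ADVENT recipe and reusing dann, y_s and y_t from the example above; backward passes and optimizer steps are omitted):

>>> # (1) adversarial step: update the segmentation model while the discriminator is frozen
>>> dann.eval()
>>> adv_loss = dann(y_t, 'source')  # target predictions labelled as source to fool the discriminator
>>> # (2) discriminator step: update the discriminator with the true domain labels
>>> dann.train()
>>> d_loss = 0.5 * (dann(y_s, 'source') + dann(y_t, 'target'))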

D-adapt: Decoupled Adaptation for Cross-Domain Object Detection

Proposed in Decoupled Adaptation for Cross-Domain Object Detection (ICLR 2022).

class tllib.alignment.d_adapt.proposal.Proposal(image_id, filename, pred_boxes, pred_classes, pred_scores, gt_classes=None, gt_boxes=None, gt_ious=None, gt_fg_classes=None)[source]

A data structure that stores the proposals for a single image.

Parameters
  • image_id (str) – unique image identifier

  • filename (str) – image filename

  • pred_boxes (numpy.ndarray) – predicted boxes

  • pred_classes (numpy.ndarray) – predicted classes

  • pred_scores (numpy.ndarray) – class confidence scores

  • gt_classes (numpy.ndarray, optional) – ground-truth classes, including background classes

  • gt_boxes (numpy.ndarray, optional) – ground-truth boxes

  • gt_ious (numpy.ndarray, optional) – IoU between predicted boxes and ground-truth boxes

  • gt_fg_classes (numpy.ndarray, optional) – ground-truth foreground classes, not including background classes
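
Example (a minimal construction sketch; the array shapes and the x1, y1, x2, y2 box format are illustrative assumptions, not prescribed by the library):

>>> import numpy as np
>>> from tllib.alignment.d_adapt.proposal import Proposal
>>> proposal = Proposal(
...     image_id='000001', filename='000001.jpg',
...     pred_boxes=np.array([[10., 20., 110., 220.]]),  # one predicted box per row
...     pred_classes=np.array([3]),
...     pred_scores=np.array([0.9]))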

class tllib.alignment.d_adapt.proposal.PersistentProposalList(filename=None)[source]

A data structure that stores the proposals for a dataset.

Parameters

filename (str, optional) – filename indicating where to cache

class tllib.alignment.d_adapt.proposal.ProposalDataset(proposal_list, transform=None, crop_func=None)[source]

A dataset for proposals.

Parameters
  • proposal_list (list) – list of Proposal

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop

  • crop_func (ExpandCrop, optional) – function used to crop out proposal regions from the image. Default: None
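
Example (a hedged sketch tying the proposal utilities together; it assumes PersistentProposalList behaves like a Python list and uses a hypothetical cache path 'proposals.pkl'):

>>> import torchvision.transforms as T
>>> from tllib.alignment.d_adapt.proposal import PersistentProposalList, ProposalDataset
>>> proposal_list = PersistentProposalList('proposals.pkl')  # hypothetical cache path
>>> proposal_list.append(proposal)  # the Proposal built in the example above
>>> dataset = ProposalDataset(proposal_list, transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]))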

class tllib.alignment.d_adapt.modeling.meta_arch.DecoupledGeneralizedRCNN(*args, **kwargs)[source]

Generalized R-CNN for Decoupled Adaptation (D-adapt). Similar to its counterpart in supervised learning, DecoupledGeneralizedRCNN has the following three components: 1. per-image feature extraction (a.k.a. backbone); 2. region proposal generation; 3. per-region feature extraction and prediction.

Different from its counterpart in supervised learning, DecoupledGeneralizedRCNN 1. accepts unlabeled images and uses the feedback from adaptors as supervision during training; 2. generates foreground and background proposals during inference.

Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • proposal_generator – a module that generates proposals using backbone features

  • roi_heads – a ROI head that performs per-region computation

  • pixel_mean, pixel_std – lists or tuples with #channels elements, representing the per-channel mean and std used to normalize the input image

  • input_format – describes the meaning of the channels of the input image. Needed by visualization

  • vis_period – the period to run visualization. Set to 0 to disable.

  • finetune (bool) – whether to finetune the detector or train from scratch. Default: True

Inputs:
  • batched_inputs: a list, batched outputs of DatasetMapper. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:

    • image: Tensor, image in (C, H, W) format.

    • instances (optional): groundtruth Instances

    • feedbacks (optional): Instances, feedback from the adaptors.

    • “height”, “width” (int): the output resolution of the model, used in inference. See postprocess() for details.

  • labeled (bool, optional): whether the inputs have ground-truth labels

Outputs:
  • outputs (during inference): A list of dicts where each dict is the output for one input image. The dict contains a key “instances” whose value is an Instances object with the following fields: “pred_boxes”, “pred_classes”, “scores”, “pred_masks”, “pred_keypoints”

  • losses (during training): A dict of different losses

class tllib.alignment.d_adapt.modeling.meta_arch.DecoupledRetinaNet(*args, max_samples_per_level=25, **kwargs)[source]

RetinaNet for Decoupled Adaptation (D-adapt).

Different from its counterpart in supervised learning, DecoupledRetinaNet 1. accepts unlabeled images and uses the feedback from adaptors as supervision during training; 2. generates foreground and background proposals during inference.

Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • head (nn.Module) – a module that predicts logits and regression deltas for each level from a list of per-level features

  • head_in_features (Tuple[str]) – Names of the input feature maps to be used in head

  • anchor_generator (nn.Module) – a module that creates anchors from a list of features. Usually an instance of AnchorGenerator

  • box2box_transform (Box2BoxTransform) – defines the transform from anchor boxes to instance boxes

  • anchor_matcher (Matcher) – labels the anchors by matching them with ground truth.

  • num_classes (int) – number of classes. Used to label background proposals.

  Loss parameters:

  • focal_loss_alpha (float) – alpha parameter of the focal loss

  • focal_loss_gamma (float) – gamma parameter of the focal loss

  • smooth_l1_beta (float) – beta parameter of the smooth L1 loss

  • box_reg_loss_type (str) – Options are “smooth_l1”, “giou”

  Inference parameters:

  • test_score_thresh (float) – Inference cls score threshold, only anchors with score > INFERENCE_TH are considered for inference (to improve speed)

  • test_topk_candidates (int) – Select topk candidates before NMS

  • test_nms_thresh (float) – Overlap threshold used for non-maximum suppression (suppress boxes with IoU >= this threshold)

  • max_detections_per_image (int) – Maximum number of detections to return per image during inference (100 is based on the limit established for the COCO dataset).

  Input parameters:

  • pixel_mean (Tuple[float]) – Values to be used for image normalization (BGR order). To train on images of different number of channels, set different mean & std. Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]

  • pixel_std (Tuple[float]) – When using pre-trained models in Detectron1 or any MSRA models, std has been absorbed into its conv1 weights, so the std needs to be set to 1. Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)

  • vis_period (int) – The period (in terms of steps) for minibatch visualization at train time. Set to 0 to disable.

  • input_format (str) – Whether the model needs RGB, YUV, HSV etc.

  • finetune (bool) – whether to finetune the detector or train from scratch. Default: True

Inputs:
  • batched_inputs: a list, batched outputs of DatasetMapper. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:

    • image: Tensor, image in (C, H, W) format.

    • instances (optional): groundtruth Instances

    • “height”, “width” (int): the output resolution of the model, used in inference. See postprocess() for details.

  • labeled (bool, optional): whether the inputs have ground-truth labels

Outputs:
  • outputs: A list of dicts where each dict is the output for one input image. The dict contains a key “instances” whose value is an Instances object and a key “features” whose value contains features from intermediate layers. The Instances object has the following fields: “pred_boxes”, “pred_classes”, “scores”, “pred_masks”, “pred_keypoints”

  • losses: A dict of different losses
