Re-weighting

PADA: Partial Adversarial Domain Adaptation

class tllib.reweight.pada.ClassWeightModule(temperature=0.1)[source]

Calculates the class weight based on the output of the classifier, as introduced by Partial Adversarial Domain Adaptation (ECCV 2018)

Given classification logits outputs \(\{\hat{y}_i\}_{i=1}^n\), where \(n\) is the dataset size, the weight indicating the contribution of each class to the training can be calculated as follows

\[\gamma = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T),\]

where \(\gamma\) is a \(|\mathcal{C}|\)-dimensional weight vector quantifying the contribution of each class and \(T\) is a hyper-parameter called temperature.

In practice, some of the weights may be very small; we therefore normalize \(\gamma\) by dividing it by its largest element, i.e. \(\gamma \leftarrow \gamma / \max(\gamma)\). A minimal sketch of this computation follows the shape specification below.

Parameters

temperature (float, optional) – hyper-parameter \(T\). Default: 0.1

Shape:
  • Inputs: (minibatch, \(|\mathcal{C}|\))

  • Outputs: (\(|\mathcal{C}|\),)
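The class-weight rule above is compact enough to sketch directly. Below is a minimal, hedged re-implementation (not the library source; the standalone function is an illustration) that averages temperature-scaled softmax outputs over all logits and normalizes by the maximum:

>>> import torch
>>> import torch.nn.functional as F
>>> def class_weight(logits, temperature=0.1):
>>>     # gamma = (1/n) * sum_i softmax(y_hat_i / T), over logits of shape (n, |C|)
>>>     gamma = F.softmax(logits / temperature, dim=1).mean(dim=0)
>>>     # normalize so the largest class weight equals 1
>>>     return gamma / gamma.max()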

class tllib.reweight.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)[source]

Calculates the class weight based on the output of the classifier. See ClassWeightModule for details of the calculation. Every N iterations, the class weight is updated automatically.

Parameters
  • update_steps (int) – N, the interval (in iterations) at which the class weight is updated.

  • data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.

  • classifier (torch.nn.Module) – Classifier.

  • num_classes (int) – Number of classes.

  • device (torch.device) – The device to run classifier.

  • temperature (float, optional) – T, temperature in ClassWeightModule. Default: 0.1

  • partial_classes_index (list[int], optional) – The indexes of the partial classes. Note that this parameter is for debugging only, since in real-world datasets we have no access to the indexes of the partial classes. Default: None.

Examples:

>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>>     # weight for tllib.alignment.dann.DomainAdversarialLoss
>>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(source_labels)

get_class_weight_for_adversarial_loss(source_labels)[source]

Outputs: weights w_s and w_t for tllib.alignment.dann.DomainAdversarialLoss, where source_labels are the ground-truth labels of the source mini-batch.

Shape:
  • w_s: \((minibatch, )\)

  • w_t: \((minibatch, )\)
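A hedged sketch of how these two weights can be derived from the class weight (an assumption consistent with PADA, where each source instance is weighted by the weight of its class and target instances are weighted uniformly; the library implementation may differ):

>>> # class_weight: the (|C|,) vector gamma; source_labels: (minibatch,) ground-truth labels
>>> w_s = class_weight[source_labels]  # per-instance weight = weight of its class
>>> w_t = torch.ones_like(w_s)         # uniform weights for target instances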

get_class_weight_for_cross_entropy_loss()[source]

Outputs: weight for F.cross_entropy

Shape: \((C, )\) where C means the number of classes.
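For example, the returned vector plugs directly into the weight argument of F.cross_entropy (y_s and labels_s denote source logits and labels here):

>>> import torch.nn.functional as F
>>> w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>> cls_loss = F.cross_entropy(y_s, labels_s, weight=w_c)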

get_partial_classes_weight()[source]

Get class weight averaged on the partial classes and non-partial classes respectively.

Warning

This function is for debugging only, since in real-world datasets we have no access to the indexes of the partial classes; it will throw an error when partial_classes_index is None.

tllib.reweight.pada.collect_classification_results(data_loader, classifier, device)[source]

Fetches data from data_loader and uses classifier to collect classification results.

Parameters
  • data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.

  • classifier (torch.nn.Module) – Classifier.

  • device (torch.device) – The device to run classifier.

Returns

Classification results in shape (\(n\), \(|\mathcal{C}|\)), where \(n\) is the dataset size.
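A hedged sketch of what such a collection routine typically looks like (switching to eval mode, disabling gradients, and concatenating outputs over the whole loader); the actual library implementation may differ:

>>> import torch
>>> @torch.no_grad()
>>> def collect_classification_results(data_loader, classifier, device):
>>>     classifier.eval()
>>>     all_outputs = []
>>>     for images, _ in data_loader:
>>>         all_outputs.append(classifier(images.to(device)))
>>>     return torch.cat(all_outputs, dim=0)  # (n, |C|)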

IWAN: Importance Weighted Adversarial Nets

class tllib.reweight.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)[source]

Calculates importance weights for source instances based on the output of the domain discriminator, as introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018)

Parameters
  • discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is \((N, F)\) and output shape is \((N, 1)\)

  • partial_classes_index (list[int], optional) – The indexes of the partial classes. Note that this parameter is for debugging only, since in real-world datasets we have no access to the indexes of the partial classes. Default: None.

Examples:

>>> import torch
>>> from tllib.modules.domain_discriminator import DomainDiscriminator
>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     # feature from source domain
>>>     f_s = torch.randn(32, 1024)
>>>     # importance weights for source instance
>>>     w_s = importance_weight_module.get_importance_weight(f_s)

get_importance_weight(feature)[source]

Get importance weights for each instance.

Parameters

feature (tensor) – feature from source domain, in shape \((N, F)\)

Returns

instance weight in shape \((N, 1)\)
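A hedged sketch of the weighting rule from the IWAN paper: features that the discriminator confidently assigns to the source-only part of the space (D(f) close to 1) receive low weight, and the weights are normalized by their mean. The exact library implementation may differ:

>>> w = 1 - discriminator(f_s)   # (N, 1): low weight where D is confident the feature is source-only
>>> w = (w / w.mean()).detach()  # normalize to mean 1 and stop gradients through the weights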

get_partial_classes_weight(weights, labels)[source]

Get class weight averaged on the partial classes and non-partial classes respectively.

Parameters
  • weights (tensor) – instance weight in shape \((N, 1)\)

  • labels (tensor) – ground truth labels in shape \((N, 1)\)

Warning

This function is for debugging only, since in real-world datasets we have no access to the indexes of the partial classes; it will throw an error when partial_classes_index is None.

GroupDRO: Group Distributionally Robust Optimization

class tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)[source]

Maintains group weights based on the loss history of all domains, following Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).

Suppose we have \(N\) domains. During each iteration, we first calculate the unweighted loss of each domain, obtaining \(loss\in R^N\). Then we update the domain weights by

\[w_k \leftarrow w_k \cdot \text{exp}(\eta \cdot loss_k), \forall k \in [1, N]\]

where \(\eta\) is a hyper-parameter that controls how smoothly the weights change. As \(w \in R^N\) denotes a distribution, we normalize \(w\) by its sum. At last, the weighted loss is calculated as our objective

\[\text{objective} = \sum_{k=1}^N w_k \cdot loss_k\]
Parameters
  • num_domains (int) – The number of source domains.

  • eta (float) – The hyper-parameter \(\eta\).

  • device (torch.device) – The device to run on.

get_domain_weight(sampled_domain_idxes)[source]

Get domain weight to calculate final objective.

Inputs:
  • sampled_domain_idxes (list): sampled domain indexes in the current mini-batch

Shape:
  • sampled_domain_idxes: \((D, )\) where D means the number of sampled domains in the current mini-batch

  • Outputs: \((D, )\)

update(sampled_domain_losses, sampled_domain_idxes)[source]

Update domain weights using the losses of the current mini-batch.

Inputs:
  • sampled_domain_losses (tensor): losses of the sampled domains in the current mini-batch

  • sampled_domain_idxes (list): sampled domain indexes in the current mini-batch

Shape:
  • sampled_domain_losses: \((D, )\) where D means the number of sampled domains in the current mini-batch

  • sampled_domain_idxes: \((D, )\)
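
Examples:

The modules above each document an Examples section; an analogous hedged sketch for this module follows (the per-domain losses, sampled indexes, device, and iteration count are hypothetical placeholders):

>>> domain_weight_module = AutomaticUpdateDomainWeightModule(num_domains=4, eta=1e-2, device=device)
>>> for _ in range(num_iterations):
>>>     # unweighted losses of the domains sampled in this mini-batch (hypothetical values)
>>>     sampled_domain_idxes = [0, 2]
>>>     sampled_domain_losses = torch.stack([loss_domain_0, loss_domain_2])
>>>     domain_weight_module.update(sampled_domain_losses, sampled_domain_idxes)
>>>     w = domain_weight_module.get_domain_weight(sampled_domain_idxes)
>>>     objective = (w * sampled_domain_losses).sum()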
