Reweighting
PADA: Partial Adversarial Domain Adaptation

class tllib.reweight.pada.ClassWeightModule(temperature=0.1)
Calculates class weights from the output of the classifier. Introduced by Partial Adversarial Domain Adaptation (ECCV 2018).
Given classification logits \(\{\hat{y}_i\}_{i=1}^n\), where \(n\) is the dataset size, the weight indicating the contribution of each class to training is calculated as
\[\gamma = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T),\]
where \(\gamma\) is a \(\mathcal{C}\)-dimensional weight vector quantifying the contribution of each class and \(T\) is a hyperparameter called the temperature.
In practice, some of the weights can be very small, so we normalize \(\gamma\) by dividing it by its largest element, i.e. \(\gamma \leftarrow \gamma / \max(\gamma)\).
 Parameters
temperature (float, optional) – Hyperparameter \(T\). Default: 0.1
 Shape:
Inputs: (minibatch, \(\mathcal{C}\))
Outputs: (\(\mathcal{C}\),)
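The formula above can be sketched in plain Python. This is a minimal illustration of the calculation, not the library's implementation: the function names `softmax` and `class_weight` are hypothetical, and lists stand in for tensors.

```python
import math


def softmax(logits, temperature):
    """Temperature-scaled softmax of one logit vector (numerically stabilized)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def class_weight(all_logits, temperature=0.1):
    """Average the softmax outputs over all samples, then divide by the max."""
    n = len(all_logits)
    num_classes = len(all_logits[0])
    gamma = [0.0] * num_classes
    for logits in all_logits:
        for c, p in enumerate(softmax(logits, temperature)):
            gamma[c] += p / n
    largest = max(gamma)
    return [g / largest for g in gamma]


# Three samples, three classes; class 2 is never predicted confidently.
logits = [[2.0, 0.5, -1.0], [1.5, 1.0, -2.0], [0.2, 2.2, -1.5]]
weights = class_weight(logits, temperature=0.1)
```

With a small temperature the softmax is close to one-hot, so a class that is rarely the top prediction ends up with a weight near zero, while the most frequently predicted class has weight exactly 1 after normalization.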

class tllib.reweight.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)
Calculates class weights from the output of the classifier. See ClassWeightModule for details of the calculation. The class weight is updated automatically every N iterations.
 Parameters
update_steps (int) – N, the number of iterations between class weight updates.
data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.
classifier (torch.nn.Module) – The classifier whose outputs are used to compute the class weight.
num_classes (int) – Number of classes.
device (torch.device) – The device to run the classifier on.
temperature (float, optional) – T, the temperature in ClassWeightModule. Default: 0.1
partial_classes_index (list[int], optional) – The indices of the partial classes. Note that this parameter is only for debugging, since in real-world datasets we have no access to the indices of the partial classes. Default: None.
Examples:
>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>>     # weight for tllib.alignment.dann.DomainAdversarialLoss
>>>     # source_labels: labels of the current source minibatch
>>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(source_labels)

get_class_weight_for_adversarial_loss(source_labels)
 Outputs:
w_s: source weight for DomainAdversarialLoss
w_t: target weight for DomainAdversarialLoss
 Shape:
w_s: \((minibatch, )\)
w_t: \((minibatch, )\)

get_class_weight_for_cross_entropy_loss()
Outputs: weight for F.cross_entropy
Shape: \((C, )\) where C means the number of classes.

get_partial_classes_weight()
Get the class weight averaged over the partial classes and over the non-partial classes, respectively.
Warning
This function is only for debugging, since in real-world datasets we have no access to the indices of the partial classes. It raises an error when partial_classes_index is None.
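As a rough illustration of this debugging check, the averaging can be sketched as follows. The function name `partial_classes_weight`, the return order, and the use of `ValueError` are assumptions made for this sketch, not the library's documented behavior.

```python
def partial_classes_weight(class_weight, partial_classes_index):
    """Mean weight over the partial classes and over the remaining classes."""
    if partial_classes_index is None:
        # Assumption: the exact error type raised by the library is not documented here.
        raise ValueError("partial_classes_index is required for this check")
    partial = set(partial_classes_index)
    inside = [w for c, w in enumerate(class_weight) if c in partial]
    outside = [w for c, w in enumerate(class_weight) if c not in partial]
    return sum(inside) / len(inside), sum(outside) / len(outside)


# Four classes, of which classes 0 and 1 are the partial classes.
partial_mean, non_partial_mean = partial_classes_weight([1.0, 0.5, 0.02, 0.04], [0, 1])
```

Comparing the two averages gives a quick signal of how well the weighting separates the partial classes from the rest.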

tllib.reweight.pada.collect_classification_results(data_loader, classifier, device)
Fetch data from data_loader, then use classifier to collect classification results.
 Parameters
data_loader (torch.utils.data.DataLoader) – Data loader.
classifier (torch.nn.Module) – A classifier.
device (torch.device) – The device to run the classifier on.
 Returns
Classification results in shape (\(n\), \(\mathcal{C}\)), where \(n\) is the number of samples drawn from data_loader.
IWAN: Importance Weighted Adversarial Nets

class tllib.reweight.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)
Calculates importance weights for source instances from the output of the discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018).
 Parameters
discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is \((N, F)\) and output shape is \((N, 1)\).
partial_classes_index (list[int], optional) – The indices of the partial classes. Note that this parameter is only for debugging, since in real-world datasets we have no access to the indices of the partial classes. Default: None.
Examples:
>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     # feature from source domain
>>>     f_s = torch.randn(32, 1024)
>>>     # importance weights for source instance
>>>     w_s = importance_weight_module.get_importance_weight(f_s)

get_importance_weight(feature)
Get importance weights for each instance.
 Parameters
feature (tensor) – feature from source domain, in shape \((N, F)\)
 Returns
instance weight in shape \((N, 1)\)
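This page does not spell out the weighting formula. In the IWAN paper, the weight of a source feature is one minus the discriminator's probability that the feature comes from the source domain, rescaled so that the weights average to 1. The sketch below assumes that formulation; the function name `importance_weight` is hypothetical, and whether the library normalizes in exactly this way is an assumption.

```python
def importance_weight(source_probabilities):
    """IWAN-style weights: w_i = 1 - D(f_i), rescaled so the batch mean is 1."""
    raw = [1.0 - p for p in source_probabilities]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]


# Discriminator outputs D(f) in (0, 1): the probability that a feature is
# from the source domain. Confidently "source" features get a low weight.
d_out = [0.9, 0.6, 0.5, 0.2]
w_s = importance_weight(d_out)
```

Source instances that the discriminator can easily tell apart from the target domain are likely to belong to classes absent from the target, so they contribute less to training.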

get_partial_classes_weight(weights, labels)
Get the instance weight averaged over the partial classes and over the non-partial classes, respectively.
 Parameters
weights (tensor) – instance weight in shape \((N, 1)\)
labels (tensor) – ground truth labels in shape \((N, 1)\)
Warning
This function is only for debugging, since in real-world datasets we have no access to the indices of the partial classes. It raises an error when partial_classes_index is None.
GroupDRO: Group Distributionally Robust Optimization

class tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)
Maintains group weights based on the loss history of all domains, following Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).
Suppose we have \(N\) domains. During each iteration, we first calculate the unweighted loss of each domain, resulting in \(loss\in R^N\). Then we update the domain weight by
\[w_k = w_k * \text{exp}(\eta \cdot loss_k), \forall k \in [1, N]\]
where \(\eta\) is a hyperparameter that controls how smoothly the weight changes. As \(w \in R^N\) denotes a distribution, we normalize \(w\) by its sum. Finally, the weighted loss is calculated as our objective
\[objective = \sum_{k=1}^N w_k * loss_k\]
 Parameters
num_domains (int) – The number of source domains.
eta (float) – Hyperparameter \(\eta\).
device (torch.device) – The device to run on.
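The update rule and objective can be sketched in plain Python. This is an illustration under one assumption: each weight is multiplied by \(\text{exp}(\eta \cdot loss_k)\), as in the GroupDRO paper. The function names are hypothetical and lists stand in for tensors.

```python
import math


def update_domain_weight(weights, losses, eta):
    """Multiply each weight by exp(eta * loss), then renormalize to sum to 1."""
    scaled = [w * math.exp(eta * l) for w, l in zip(weights, losses)]
    total = sum(scaled)
    return [s / total for s in scaled]


def weighted_objective(weights, losses):
    """objective = sum_k w_k * loss_k"""
    return sum(w * l for w, l in zip(weights, losses))


weights = [1 / 3, 1 / 3, 1 / 3]   # start from a uniform distribution
losses = [0.2, 1.5, 0.7]          # unweighted per-domain losses
weights = update_domain_weight(weights, losses, eta=0.1)
objective = weighted_objective(weights, losses)
```

Domains with a larger loss gain weight, so the objective leans toward the worst-performing domains; a smaller eta makes that shift more gradual.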

get_domain_weight(sampled_domain_idxes)
Get the domain weight used to calculate the final objective.
 Inputs:
sampled_domain_idxes (list): sampled domain indexes in current minibatch
 Shape:
sampled_domain_idxes: \((D, )\) where D means the number of sampled domains in current minibatch
Outputs: \((D, )\)

update(sampled_domain_losses, sampled_domain_idxes)
Update the domain weight using the losses of the current minibatch.
 Inputs:
sampled_domain_losses (tensor): losses of the sampled domains in the current minibatch
sampled_domain_idxes (list): sampled domain indexes in the current minibatch
 Shape:
sampled_domain_losses: \((D, )\) where D means the number of sampled domains in current minibatch
sampled_domain_idxes: \((D, )\)
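When a minibatch contains only some of the domains, the two methods work on that subset. The class below is a hypothetical stand-in written for illustration, not the library's code. It assumes that only the sampled domains are updated, that the returned weights are renormalized over the sampled domains, and that `update` is called before `get_domain_weight`.

```python
import math


class DomainWeights:
    """Hypothetical stand-in for a per-domain weight tracker."""

    def __init__(self, num_domains, eta):
        self.eta = eta
        self.weight = [1.0 / num_domains] * num_domains

    def get_domain_weight(self, sampled_domain_idxes):
        # Assumption: weights of the sampled domains are renormalized to sum to 1.
        picked = [self.weight[k] for k in sampled_domain_idxes]
        total = sum(picked)
        return [w / total for w in picked]

    def update(self, sampled_domain_losses, sampled_domain_idxes):
        # Only the domains present in this minibatch are updated.
        for loss, k in zip(sampled_domain_losses, sampled_domain_idxes):
            self.weight[k] *= math.exp(self.eta * loss)


tracker = DomainWeights(num_domains=4, eta=0.1)
idxes = [0, 2]            # domains sampled in this minibatch
losses = [0.3, 1.2]       # one loss per sampled domain, shape (D,)
tracker.update(losses, idxes)
w = tracker.get_domain_weight(idxes)      # shape (D,)
objective = sum(wk * lk for wk, lk in zip(w, losses))
```

Both the losses and the returned weights have one entry per sampled domain, which matches the \((D, )\) shapes listed above.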