# Re-weighting

## PADA: Partial Adversarial Domain Adaptation

class tllib.reweight.pada.ClassWeightModule(temperature=0.1)[source]

Calculates class weights based on the output of a classifier. Introduced by Partial Adversarial Domain Adaptation (ECCV 2018).

Given the classification logits $$\{\hat{y}_i\}_{i=1}^n$$, where $$n$$ is the dataset size, the weight indicating the contribution of each class to training is calculated as follows

$\gamma = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}(\hat{y}_i / T),$

where $$\gamma$$ is a $$|\mathcal{C}|$$-dimensional weight vector quantifying the contribution of each class, and $$T$$ is a hyper-parameter called the temperature.

In practice, some of the weights may be very small; thus, we normalize $$\gamma$$ by dividing by its largest element, i.e. $$\gamma \leftarrow \gamma / \max(\gamma)$$ (see the sketch below).

Parameters

temperature (float, optional) – hyper-parameter $$T$$. Default: 0.1

Shape:
• Inputs: (minibatch, $$|\mathcal{C}|$$)

• Outputs: ($$|\mathcal{C}|$$,)
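
The computation reduces to averaging temperature-scaled softmax outputs over the dataset and rescaling by the maximum. A minimal sketch with synthetic logits (the tensor sizes here are illustrative, not part of the API):

>>> import torch
>>> import torch.nn.functional as F
>>> from tllib.reweight.pada import ClassWeightModule
>>> y_hat = torch.randn(100, 31)                       # synthetic logits: n = 100 samples, |C| = 31 classes
>>> gamma = F.softmax(y_hat / 0.1, dim=1).mean(dim=0)  # average softened predictions, T = 0.1
>>> gamma = gamma / gamma.max()                        # normalize by the largest element
>>> # equivalently, via the module:
>>> gamma = ClassWeightModule(temperature=0.1)(y_hat)  # shape: (31,)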

class tllib.reweight.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)[source]

Calculates class weights based on the output of a classifier. See ClassWeightModule for details of the calculation. The class weight is updated automatically every N iterations.

Parameters
• update_steps (int) – N, the number of iterations between class weight updates.

• data_loader (torch.utils.data.DataLoader) – The data loader from which classification outputs are collected.

• classifier (torch.nn.Module) – Classifier.

• num_classes (int) – Number of classes.

• device (torch.device) – The device to run classifier.

• temperature (float, optional) – T, temperature in ClassWeightModule. Default: 0.1

• partial_classes_index (list[int], optional) – The indexes of partial classes. Note that this parameter is only for debugging, since we have no access to the indexes of partial classes in real-world datasets. Default: None.

Examples:

>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
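
The returned weight plugs directly into the weight argument of F.cross_entropy. A hedged sketch with synthetic data (y_s and labels_s stand in for a source-domain batch; 31 classes are assumed to match w_c):

>>> import torch
>>> import torch.nn.functional as F
>>> y_s, labels_s = torch.randn(32, 31), torch.randint(31, (32,))  # hypothetical batch
>>> cls_loss = F.cross_entropy(y_s, labels_s, weight=w_c)          # class-weighted loss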

get_class_weight_for_adversarial_loss(source_labels)[source]

Outputs:
• w_s: class weight for the source part of the adversarial loss

• w_t: class weight for the target part of the adversarial loss

Shape:
• w_s: $$(minibatch, )$$

• w_t: $$(minibatch, )$$
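
These per-instance weights re-weight the domain adversarial loss. A hedged sketch (labels_s are source labels as above; loss_d_s and loss_d_t are hypothetical un-reduced per-instance discriminator losses of shape (minibatch,)):

>>> w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
>>> adv_loss = (w_s * loss_d_s).mean() + (w_t * loss_d_t).mean()  # weighted adversarial loss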

get_class_weight_for_cross_entropy_loss()[source]

Outputs: weight for F.cross_entropy

Shape: $$(C, )$$, where C is the number of classes.

get_partial_classes_weight()[source]

Get class weights averaged over the partial classes and the non-partial classes, respectively.

Warning

This function is only for debugging, since we have no access to the indexes of partial classes in real-world datasets; it will raise an error when partial_classes_index is None.

tllib.reweight.pada.collect_classification_results(data_loader, classifier, device)[source]

Fetch data from data_loader, and then use classifier to collect classification results.

Parameters
• data_loader (torch.utils.data.DataLoader) – The data loader to fetch data from.

• classifier (torch.nn.Module) – The classifier used to produce classification results.

• device (torch.device) – The device to run classifier.

Returns

Classification results in shape $$(n, |\mathcal{C}|)$$, where $$n$$ is the dataset size.
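
A minimal usage sketch chaining this helper into ClassWeightModule (target_loader and classifier are hypothetical objects of the documented types):

>>> outputs = collect_classification_results(target_loader, classifier, device)
>>> class_weight = ClassWeightModule()(outputs)  # (n, |C|) -> (|C|,)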

## IWAN: Importance Weighted Adversarial Nets

class tllib.reweight.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)[source]

Calculates importance weights for source instances based on the output of a domain discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018).

Parameters
• discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domain of each feature. Its input shape is $$(N, F)$$ and its output shape is $$(N, 1)$$.

• partial_classes_index (list[int], optional) – The indexes of partial classes. Note that this parameter is only for debugging, since we have no access to the indexes of partial classes in real-world datasets. Default: None.

Examples:

>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     # feature from source domain
>>>     f_s = torch.randn(32, 1024)
>>>     # importance weights for source instance
>>>     w_s = importance_weight_module.get_importance_weight(f_s)

get_importance_weight(feature)[source]

Get importance weights for each instance.

Parameters

feature (tensor) – feature from source domain, in shape $$(N, F)$$

Returns

Instance weights in shape $$(N, 1)$$.
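
The weights are typically used to re-weight per-instance source losses. A hedged sketch continuing the example above (y_s and labels_s are hypothetical source logits and labels; 31 classes assumed):

>>> import torch
>>> import torch.nn.functional as F
>>> y_s, labels_s = torch.randn(32, 31), torch.randint(31, (32,))
>>> per_instance = F.cross_entropy(y_s, labels_s, reduction='none')  # per-instance loss, shape (N,)
>>> cls_loss = (w_s.squeeze(1) * per_instance).mean()                # w_s: (N, 1) -> (N,)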

get_partial_classes_weight(weights, labels)[source]

Get class weights averaged over the partial classes and the non-partial classes, respectively.

Parameters
• weights (tensor) – instance weight in shape $$(N, 1)$$

• labels (tensor) – ground truth labels in shape $$(N, 1)$$

Warning

This function is only for debugging, since we have no access to the indexes of partial classes in real-world datasets; it will raise an error when partial_classes_index is None.

## GroupDRO: Group Distributionally Robust Optimization

class tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)[source]

Maintains group weights based on the loss history of all domains, following Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).

Suppose we have $$N$$ domains. During each iteration, we first calculate the unweighted loss of each domain, obtaining $$loss \in R^N$$. Then we update the domain weights by

$w_k \leftarrow w_k \cdot \text{exp}(\eta \cdot loss_k), \forall k \in [1, N]$

where $$\eta$$ is a hyper-parameter that controls how fast the weights change. As $$w \in R^N$$ denotes a distribution over domains, we normalize $$w$$ by its sum. Finally, the weighted loss is used as our objective

$objective = \sum_{k=1}^N w_k \cdot loss_k$
Parameters
• num_domains (int) – The number of source domains.

• eta (float) – The hyper-parameter $$\eta$$.

• device (torch.device) – The device to run on.
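
Examples (a hedged sketch of one training step using the documented update and get_domain_weight methods; the per-domain losses here are synthetic, while in practice they come from forward passes on batches from each sampled domain):

>>> import torch
>>> from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule
>>> module = AutomaticUpdateDomainWeightModule(num_domains=4, eta=1e-2, device=torch.device('cpu'))
>>> sampled_domain_idxes = [0, 2]                       # domains present in this mini-batch
>>> losses = torch.tensor([0.9, 0.4])                   # unweighted loss per sampled domain, shape (D,)
>>> module.update(losses, sampled_domain_idxes)         # w_k <- w_k * exp(eta * loss_k), then renormalize
>>> w = module.get_domain_weight(sampled_domain_idxes)  # shape (D,)
>>> objective = (w * losses).sum()                      # weighted loss as final objective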

get_domain_weight(sampled_domain_idxes)[source]

Get domain weights to calculate the final objective.

Inputs:
• sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch

Shape:
• sampled_domain_idxes: $$(D, )$$, where D is the number of sampled domains in the current mini-batch

• Outputs: $$(D, )$$

update(sampled_domain_losses, sampled_domain_idxes)[source]

Update domain weights using the losses of the current mini-batch.

Inputs:
• sampled_domain_losses (tensor): losses of the domains sampled in the current mini-batch

• sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch

Shape:
• sampled_domain_losses: $$(D, )$$, where D is the number of sampled domains in the current mini-batch

• sampled_domain_idxes: $$(D, )$$