# Re-weighting

## PADA: Partial Adversarial Domain Adaptation

class tllib.reweight.pada.ClassWeightModule(temperature=0.1)[source]

Calculates class weights based on the output of a classifier. Introduced by Partial Adversarial Domain Adaptation (ECCV 2018).

Given the classification logits $$\{\hat{y}_i\}_{i=1}^n$$, where $$n$$ is the dataset size, the weight indicating the contribution of each class to training is calculated as follows

$\gamma = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}(\hat{y}_i / T),$

where $$\gamma$$ is a $$|\mathcal{C}|$$-dimensional weight vector quantifying the contribution of each class, and $$T$$ is a hyper-parameter called the temperature.

In practice, some of the weights may be very small; thus, we normalize $$\gamma$$ by dividing by its largest element, i.e. $$\gamma \leftarrow \gamma / \max(\gamma)$$ (see the sketch below).

Parameters

temperature (float, optional) – hyper-parameter $$T$$. Default: 0.1

Shape:
• Inputs: (minibatch, $$|\mathcal{C}|$$)

• Outputs: ($$|\mathcal{C}|$$,)
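
The computation reduces to averaging temperature-scaled softmax outputs over the dataset and rescaling by the maximum. A minimal sketch with synthetic logits (the tensor sizes here are illustrative, not part of the API):

>>> import torch
>>> import torch.nn.functional as F
>>> from tllib.reweight.pada import ClassWeightModule
>>> y_hat = torch.randn(100, 31)                       # synthetic logits: n = 100 samples, |C| = 31 classes
>>> gamma = F.softmax(y_hat / 0.1, dim=1).mean(dim=0)  # average softened predictions, T = 0.1
>>> gamma = gamma / gamma.max()                        # normalize by the largest element
>>> # equivalently, via the module:
>>> gamma = ClassWeightModule(temperature=0.1)(y_hat)  # shape: (31,)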

class tllib.reweight.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)[source]

Calculates class weights based on the output of a classifier. See ClassWeightModule for details of the calculation. The class weight is updated automatically every N iterations.

Parameters
• update_steps (int) – N, the number of iterations between class weight updates.

• data_loader (torch.utils.data.DataLoader) – The data loader from which classification outputs are collected.

• classifier (torch.nn.Module) – Classifier.

• num_classes (int) – Number of classes.

• device (torch.device) – The device to run classifier.

• temperature (float, optional) – T, temperature in ClassWeightModule. Default: 0.1

• partial_classes_index (list[int], optional) – The indexes of partial classes. Note that this parameter is only for debugging, since we have no access to the indexes of partial classes in real-world datasets. Default: None.

Examples:

>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
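
The returned weight plugs directly into the weight argument of F.cross_entropy. A hedged sketch with synthetic data (y_s and labels_s stand in for a source-domain batch; 31 classes are assumed to match w_c):

>>> import torch
>>> import torch.nn.functional as F
>>> y_s, labels_s = torch.randn(32, 31), torch.randint(31, (32,))  # hypothetical batch
>>> cls_loss = F.cross_entropy(y_s, labels_s, weight=w_c)          # class-weighted loss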

get_class_weight_for_adversarial_loss(source_labels)[source]

Outputs:
• w_s: class weight for the source part of the adversarial loss

• w_t: class weight for the target part of the adversarial loss

Shape:
• w_s: $$(minibatch, )$$

• w_t: $$(minibatch, )$$
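
These per-instance weights re-weight the domain adversarial loss. A hedged sketch (labels_s are source labels as above; loss_d_s and loss_d_t are hypothetical un-reduced per-instance discriminator losses of shape (minibatch,)):

>>> w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
>>> adv_loss = (w_s * loss_d_s).mean() + (w_t * loss_d_t).mean()  # weighted adversarial loss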

get_class_weight_for_cross_entropy_loss()[source]

Outputs: weight for F.cross_entropy

Shape: $$(C, )$$, where C is the number of classes.

get_partial_classes_weight()[source]

Get class weights averaged over the partial classes and the non-partial classes, respectively.

Warning

This function is only for debugging, since we have no access to the indexes of partial classes in real-world datasets; it will raise an error when partial_classes_index is None.

tllib.reweight.pada.collect_classification_results(data_loader, classifier, device)[source]

Fetch data from data_loader, and then use classifier to collect classification results.

Parameters
• data_loader (torch.utils.data.DataLoader) – The data loader to fetch data from.

• classifier (torch.nn.Module) – The classifier used to produce classification results.

• device (torch.device) – The device to run classifier.

Returns

Classification results in shape $$(n, |\mathcal{C}|)$$, where $$n$$ is the dataset size.
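
A minimal usage sketch chaining this helper into ClassWeightModule (target_loader and classifier are hypothetical objects of the documented types):

>>> outputs = collect_classification_results(target_loader, classifier, device)
>>> class_weight = ClassWeightModule()(outputs)  # (n, |C|) -> (|C|,)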

## IWAN: Importance Weighted Adversarial Nets

class tllib.reweight.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)[source]

Calculates importance weights for source instances based on the output of a domain discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018).

Parameters
• discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domain of each feature. Its input shape is $$(N, F)$$ and its output shape is $$(N, 1)$$.

• partial_classes_index (list[int], optional) – The indexes of partial classes. Note that this parameter is only for debugging, since we have no access to the indexes of partial classes in real-world datasets. Default: None.

Examples:

>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     # feature from source domain
>>>     f_s = torch.randn(32, 1024)
>>>     # importance weights for source instance
>>>     w_s = importance_weight_module.get_importance_weight(f_s)

get_importance_weight(feature)[source]

Get importance weights for each instance.

Parameters

feature (tensor) – feature from source domain, in shape $$(N, F)$$

Returns

Instance weights in shape $$(N, 1)$$.
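
The weights are typically used to re-weight per-instance source losses. A hedged sketch continuing the example above (y_s and labels_s are hypothetical source logits and labels; 31 classes assumed):

>>> import torch
>>> import torch.nn.functional as F
>>> y_s, labels_s = torch.randn(32, 31), torch.randint(31, (32,))
>>> per_instance = F.cross_entropy(y_s, labels_s, reduction='none')  # per-instance loss, shape (N,)
>>> cls_loss = (w_s.squeeze(1) * per_instance).mean()                # w_s: (N, 1) -> (N,)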

get_partial_classes_weight(weights, labels)[source]

Get class weights averaged over the partial classes and the non-partial classes, respectively.

Parameters
• weights (tensor) – instance weight in shape $$(N, 1)$$

• labels (tensor) – ground truth labels in shape $$(N, 1)$$

Warning

This function is only for debugging, since we have no access to the indexes of partial classes in real-world datasets; it will raise an error when partial_classes_index is None.

## GroupDRO: Group Distributionally Robust Optimization

class tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)[source]

Maintains group weights based on the loss history of all domains, following Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).

Suppose we have $$N$$ domains. During each iteration, we first calculate the unweighted loss of each domain, obtaining $$loss \in R^N$$. Then we update the domain weights by

$w_k \leftarrow w_k \cdot \text{exp}(\eta \cdot loss_k), \forall k \in [1, N]$

where $$\eta$$ is a hyper-parameter that controls how fast the weights change. As $$w \in R^N$$ denotes a distribution over domains, we normalize $$w$$ by its sum. Finally, the weighted loss is used as our objective

$objective = \sum_{k=1}^N w_k \cdot loss_k$
Parameters
• num_domains (int) – The number of source domains.

• eta (float) – The hyper-parameter $$\eta$$.

• device (torch.device) – The device to run on.
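
Examples (a hedged sketch of one training step using the documented update and get_domain_weight methods; the per-domain losses here are synthetic, while in practice they come from forward passes on batches from each sampled domain):

>>> import torch
>>> from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule
>>> module = AutomaticUpdateDomainWeightModule(num_domains=4, eta=1e-2, device=torch.device('cpu'))
>>> sampled_domain_idxes = [0, 2]                       # domains present in this mini-batch
>>> losses = torch.tensor([0.9, 0.4])                   # unweighted loss per sampled domain, shape (D,)
>>> module.update(losses, sampled_domain_idxes)         # w_k <- w_k * exp(eta * loss_k), then renormalize
>>> w = module.get_domain_weight(sampled_domain_idxes)  # shape (D,)
>>> objective = (w * losses).sum()                      # weighted loss as final objective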

get_domain_weight(sampled_domain_idxes)[source]

Get domain weights to calculate the final objective.

Inputs:
• sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch

Shape:
• sampled_domain_idxes: $$(D, )$$, where D is the number of sampled domains in the current mini-batch

• Outputs: $$(D, )$$

update(sampled_domain_losses, sampled_domain_idxes)[source]

Update domain weights using the losses of the current mini-batch.

Inputs:
• sampled_domain_losses (tensor): losses of the domains sampled in the current mini-batch

• sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch

Shape:
• sampled_domain_losses: $$(D, )$$, where D is the number of sampled domains in the current mini-batch

• sampled_domain_idxes: $$(D, )$$