Source code for tllib.reweight.iwan

@author: Baixu Chen
from typing import Optional, List, Dict
import torch
import torch.nn as nn

from tllib.modules.classifier import Classifier as ClassifierBase

class ImportanceWeightModule(object):
    r"""
    Calculating class weight based on the output of discriminator.
    Introduced by *Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018)*.

    Args:
        discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of
            features. Its input shape is :math:`(N, F)` and output shape is :math:`(N, 1)`
        partial_classes_index (list[int], optional): The index of partial classes. Note that this
            parameter is just for debugging, since in real-world dataset, we have no access to the
            index of partial classes. Default: None.

    Examples::

        >>> domain_discriminator = DomainDiscriminator(1024, 1024)
        >>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
        >>> num_iterations = 10000
        >>> for _ in range(num_iterations):
        >>>     # feature from source domain
        >>>     f_s = torch.randn(32, 1024)
        >>>     # importance weights for source instance
        >>>     w_s = importance_weight_module.get_importance_weight(f_s)
    """

    def __init__(self, discriminator: nn.Module, partial_classes_index: Optional[List[int]] = None):
        self.discriminator = discriminator
        self.partial_classes_index = partial_classes_index

    def get_importance_weight(self, feature):
        """
        Get importance weights for each instance.

        Args:
            feature (tensor): feature from source domain, in shape :math:`(N, F)`

        Returns:
            instance weight in shape :math:`(N, 1)`. The result is detached from the
            autograd graph and normalized so that its mean over the batch is 1.
        """
        # The weight is one minus the discriminator output for each instance.
        weight = 1. - self.discriminator(feature)
        # Rescale so the batch mean of the weights equals 1.
        weight = weight / weight.mean()
        # Weights are treated as constants: no gradient flows back through them.
        weight = weight.detach()
        return weight

    def get_partial_classes_weight(self, weights: torch.Tensor, labels: torch.Tensor):
        """
        Get class weight averaged on the partial classes and non-partial classes respectively.

        Args:
            weights (tensor): instance weight in shape :math:`(N, 1)`
            labels (tensor): ground truth labels in shape :math:`(N, 1)` (a flat
                :math:`(N,)` tensor is accepted as well)

        Returns:
            A tuple ``(partial_classes_weight, not_partial_classes_weight)`` of scalar
            tensors. A group that has no instance in the batch yields a zero with the
            same dtype and device as the non-empty case.

        .. warning::
            This function is just for debugging, since in real-world dataset, we have no access to the
            index of partial classes and this function will throw an error when
            `partial_classes_index` is None.
        """
        assert self.partial_classes_index is not None, \
            "partial_classes_index must be provided to compute partial classes weight"

        weights = weights.reshape(-1)
        labels = torch.as_tensor(labels).reshape(-1)
        partial_index = torch.as_tensor(
            list(self.partial_classes_index), dtype=labels.dtype, device=labels.device
        )
        # Compare every label against every partial class at once: (N, 1) == (1, K) -> (N, K).
        # An empty index list gives an (N, 0) comparison, whose row-wise `any` is all False.
        is_partial = (labels.unsqueeze(1) == partial_index.unsqueeze(0)).any(dim=1)
        # Float32 mask on the weights' device, matching the mask dtype of `torch.Tensor([...])`.
        is_partial = is_partial.to(device=weights.device, dtype=torch.float32)
        not_partial = 1. - is_partial

        partial_classes_weight = self._masked_mean(weights, is_partial)
        not_partial_classes_weight = self._masked_mean(weights, not_partial)
        return partial_classes_weight, not_partial_classes_weight

    @staticmethod
    def _masked_mean(weights: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Average `weights` over the entries selected by the 0/1 `mask`; zero if none is selected."""
        count = mask.sum()
        total = (weights * mask).sum()
        if count > 0:
            return total / count
        # Keep dtype and device consistent with the non-empty branch.
        return torch.zeros_like(total)
class ImageClassifier(ClassifierBase):
    r"""The Image Classifier for *Importance Weighted Adversarial Nets for Partial Domain Adaptation*.

    Builds a bottleneck of ``Linear -> BatchNorm1d -> ReLU`` that maps
    ``backbone.out_features`` to ``bottleneck_dim`` and forwards it, together with
    the remaining arguments, to the base classifier's constructor.

    Args:
        backbone (torch.nn.Module): feature extractor exposing an ``out_features`` attribute.
        num_classes (int): passed through to the base classifier.
        bottleneck_dim (int, optional): output width of the bottleneck. Default: 256.
        **kwargs: forwarded unchanged to the base classifier.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        ]
        super().__init__(backbone, num_classes, nn.Sequential(*layers), bottleneck_dim, **kwargs)


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started for Transfer Learning Library

Get Started

Paper List

Get started for transfer learning

View Resources