Source code for tllib.translation.cycada

"""
@author: Junguang Jiang
"""
import torch.nn as nn
from torch import Tensor

class SemanticConsistency(nn.Module):
    r"""
    Semantic consistency loss is introduced by
    `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) <>`_

    This helps to prevent label flipping during image translation.

    Args:
        ignore_index (tuple, optional): Specifies target values that are ignored
            and do not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Default: ().
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in
            the meantime, specifying either of those two args will override
            :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> loss = SemanticConsistency()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, ignore_index=(), reduction='mean'):
        super(SemanticConsistency, self).__init__()
        self.ignore_index = ignore_index
        # -1 is the single sentinel label that the wrapped loss skips.
        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Remap every ignored class to the sentinel -1.
        # Note: this rewrites ``target`` in place.
        for class_idx in self.ignore_index:
            target[target == class_idx] = -1
        return self.loss(input, target)
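
The following is a minimal sketch of how ignore_index behaves on dense, segmentation-style targets. It redefines the class locally so that it runs without tllib installed. The tensor shapes, the choice of class 3 as the ignored class, and the variable names are assumptions made for illustration only; they do not come from the library.

```python
import torch
import torch.nn as nn


class SemanticConsistency(nn.Module):
    """Local copy of the class above, so this sketch is self-contained."""

    def __init__(self, ignore_index=(), reduction='mean'):
        super().__init__()
        self.ignore_index = ignore_index
        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)

    def forward(self, input, target):
        for class_idx in self.ignore_index:
            target[target == class_idx] = -1
        return self.loss(input, target)


torch.manual_seed(0)

# Hypothetical shapes: batch of 2, 4 classes, 3x3 label maps.
logits = torch.randn(2, 4, 3, 3, requires_grad=True)
target = torch.randint(0, 4, (2, 3, 3))
target[0, 0, 0] = 3  # make sure at least one pixel has the ignored class
target[0, 0, 1] = 0  # and at least one pixel is kept

ignored = target == 3

# reduction='none' returns one loss value per target element.
# forward() rewrites its target argument in place, so pass a clone
# whenever the original labels are still needed afterwards.
per_pixel = SemanticConsistency(ignore_index=(3,), reduction='none')(
    logits, target.clone()
)

# reduction='mean' averages over the non-ignored elements only.
mean_loss = SemanticConsistency(ignore_index=(3,))(logits, target.clone())
mean_loss.backward()
```

Pixels labelled with an ignored class receive zero loss under 'none', and the 'mean' result matches the average of per_pixel over the remaining pixels. Passing target.clone() is a design choice of this sketch rather than something the class requires: it keeps the caller's labels intact, since forward() overwrites ignored entries with -1.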

