Source code for tllib.translation.cycada

"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torch import Tensor


class SemanticConsistency(nn.Module):
    r"""
    Semantic consistency loss introduced by
    `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) <https://arxiv.org/abs/1711.03213>`_

    This helps to prevent label flipping during image translation.

    Args:
        ignore_index (tuple, optional): Specifies target values that are ignored
            and do not contribute to the input gradient. When :attr:`reduction` is
            ``'mean'``, the loss is averaged over non-ignored targets. Default: ().
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> import torch
        >>> loss = SemanticConsistency()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, ignore_index=(), reduction='mean'):
        super(SemanticConsistency, self).__init__()
        self.ignore_index = ignore_index
        # Ignored classes are remapped to -1 in forward(), so -1 is the single
        # index the wrapped cross-entropy loss has to skip.
        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Work on a copy so the caller's label tensor is not modified in place.
        target = target.clone()
        for class_idx in self.ignore_index:
            target[target == class_idx] = -1
        return self.loss(input, target)
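
To make the effect of ``ignore_index`` concrete, the following is a minimal sketch of the same computation in plain Python, without PyTorch. It is not part of tllib: the helper names (``remap_ignored``, ``cross_entropy_mean``, ``semantic_consistency``) and the toy logits are assumptions chosen for illustration, and only the ``'mean'`` reduction is covered.

```python
import math

IGNORE = -1  # sentinel index that the wrapped cross-entropy loss skips


def remap_ignored(targets, ignore_index=()):
    """Return a copy of targets with every ignored class replaced by IGNORE."""
    return [IGNORE if t in ignore_index else t for t in targets]


def cross_entropy_mean(logits, targets):
    """Mean cross-entropy over the rows whose target is not IGNORE."""
    losses = []
    for row, t in zip(logits, targets):
        if t == IGNORE:
            continue  # ignored targets contribute nothing to the loss
        # log-sum-exp with the row maximum subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[t])
    # With no remaining targets the mean is undefined, so return NaN.
    return sum(losses) / len(losses) if losses else float("nan")


def semantic_consistency(logits, targets, ignore_index=()):
    """Hypothetical stand-in for SemanticConsistency with reduction='mean'."""
    return cross_entropy_mean(logits, remap_ignored(targets, ignore_index))


# Toy batch: 3 samples, 3 classes (values are illustrative only).
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
targets = [0, 2, 1]

full = semantic_consistency(logits, targets)
masked = semantic_consistency(logits, targets, ignore_index=(2,))
print(full, masked)
```

With class 2 ignored, the second sample drops out and the mean is taken over the two remaining samples only. Because the remapping builds a new list, the caller's ``targets`` is left unchanged.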
