"""
Source code for tllib.self_training.uda

@author: Baixu Chen
"""

import torch.nn as nn
import torch.nn.functional as F

class StrongWeakConsistencyLoss(nn.Module):
    """
    Consistency loss between strong and weak augmented samples from `Unsupervised Data Augmentation for
    Consistency Training (NeurIPS 2020) <https://arxiv.org/abs/1904.12848>`_.

    For every sample whose weak-augmentation prediction has a maximum softmax probability strictly
    greater than ``threshold``, the loss is the KL divergence between the (detached) softmax of ``y``
    and the softmax of ``y_strong / temperature``. The result is averaged over those confident samples
    only, and is 0 when no sample passes the threshold.

    Args:
        threshold (float): Confidence threshold.
        temperature (float): Temperature; the strong-augmentation logits are divided by it.

    Inputs:
        - y_strong: unnormalized classifier predictions on strong augmented samples.
        - y: unnormalized classifier predictions on weak augmented samples. ``y`` is detached,
          so no gradient flows through it.

    Shape:
        - y, y_strong: :math:`(minibatch, C)` where C means the number of classes.
        - Output: scalar.
    """

    def __init__(self, threshold: float, temperature: float):
        super(StrongWeakConsistencyLoss, self).__init__()
        self.threshold = threshold
        self.temperature = temperature

    def forward(self, y_strong, y):
        # Target distribution from the weak view. It is detached so it acts as a fixed pseudo-target,
        # and computed once because it serves both the confidence mask and the KL target.
        prob = F.softmax(y.detach(), dim=1)
        confidence, _ = prob.max(dim=1)
        # 1.0 for samples confident enough to be used as targets, 0.0 otherwise.
        mask = (confidence > self.threshold).float()
        # NOTE(review): the UDA paper applies temperature sharpening to the *target* distribution;
        # here the strong-augmentation logits are scaled instead. Kept as-is for backward
        # compatibility -- confirm this is intended.
        log_prob = F.log_softmax(y_strong / self.temperature, dim=1)
        # Per-sample KL(prob || softmax(y_strong / T)), summed over the class dimension.
        con_loss = F.kl_div(log_prob, prob, reduction='none').sum(dim=1)
        # Average over confident samples only. clamp keeps the denominator >= 1 (loss is 0 when no
        # sample passes) without converting a device tensor to a Python bool, unlike built-in max().
        return (con_loss * mask).sum() / mask.sum().clamp(min=1)


