Source code for tllib.alignment.coral

"""
@author: Baixu Chen
"""
import torch
import torch.nn as nn

class CorrelationAlignmentLoss(nn.Module):
    r"""The `Correlation Alignment Loss` in
    `Deep CORAL: Correlation Alignment for Deep Domain Adaptation (ECCV 2016) <>`_.

    Given source features :math:`f_S` and target features :math:`f_T`, the covariance matrices are given by

    .. math::
        C_S = \frac{1}{n_S-1}(f_S^Tf_S-\frac{1}{n_S}(\textbf{1}^Tf_S)^T(\textbf{1}^Tf_S))

    .. math::
        C_T = \frac{1}{n_T-1}(f_T^Tf_T-\frac{1}{n_T}(\textbf{1}^Tf_T)^T(\textbf{1}^Tf_T))

    where :math:`\textbf{1}` denotes a column vector with all elements equal to 1, and
    :math:`n_S, n_T` denote the number of source and target samples, respectively.
    We use :math:`d` to denote the feature dimension and :math:`{\Vert\cdot\Vert}^2_F`
    to denote the squared matrix `Frobenius norm`. The correlation alignment loss is given by

    .. math::
        l_{CORAL} = \frac{1}{4d^2}\Vert C_S-C_T \Vert^2_F

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s, f_t: :math:`(N, d)` where d means the dimension of input features, :math:`N=n_S=n_T` is mini-batch size.
        - Outputs: scalar.
    """

    def __init__(self):
        super(CorrelationAlignmentLoss, self).__init__()

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        # Per-feature means over the batch dimension, shape (1, d)
        mean_s = f_s.mean(0, keepdim=True)
        mean_t = f_t.mean(0, keepdim=True)
        # Center the features so the products below are covariances
        cent_s = f_s - mean_s
        cent_t = f_t - mean_t
        # Unbiased (d, d) covariance matrices, matching C_S and C_T in the docstring
        cov_s = torch.mm(cent_s.t(), cent_s) / (len(f_s) - 1)
        cov_t = torch.mm(cent_t.t(), cent_t) / (len(f_t) - 1)

        # Squared difference of the feature means, averaged over the d entries
        mean_diff = (mean_s - mean_t).pow(2).mean()
        # Squared difference of the covariances, averaged over the d * d entries
        cov_diff = (cov_s - cov_t).pow(2).mean()

        return mean_diff + cov_diff
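The following is a minimal usage sketch, not part of tllib. The helpers `covariance` and `coral_style_loss` are hypothetical standalone functions that mirror the forward pass above, and the batch size, feature dimension, and random inputs are arbitrary assumptions. The sketch checks two things: that centering the features and then multiplying gives the same matrix as the closed-form :math:`C_S` in the docstring, and that averaging the squared entries of :math:`C_S - C_T` equals :math:`\Vert C_S-C_T \Vert^2_F / d^2`.

```python
import torch


def covariance(f: torch.Tensor) -> torch.Tensor:
    """Unbiased (d, d) covariance of an (N, d) feature batch (hypothetical helper)."""
    cent = f - f.mean(0, keepdim=True)
    return cent.t() @ cent / (len(f) - 1)


def coral_style_loss(f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
    """Mean-difference term plus covariance-difference term, as in the forward pass."""
    mean_diff = (f_s.mean(0) - f_t.mean(0)).pow(2).mean()
    cov_diff = (covariance(f_s) - covariance(f_t)).pow(2).mean()
    return mean_diff + cov_diff


torch.manual_seed(0)
n, d = 32, 8  # assumed mini-batch size and feature dimension
f_s = torch.randn(n, d, dtype=torch.float64)
f_t = torch.randn(n, d, dtype=torch.float64) * 2.0 + 0.5  # shifted and rescaled "target"

loss = coral_style_loss(f_s, f_t)  # scalar tensor, positive for mismatched domains
same = coral_style_loss(f_s, f_s)  # identical inputs give zero loss

# Closed-form C_S from the docstring: (f^T f - (1^T f)^T (1^T f) / n) / (n - 1)
col_sum = f_s.sum(0, keepdim=True)  # 1^T f_S, shape (1, d)
cov_doc = (f_s.t() @ f_s - col_sum.t() @ col_sum / n) / (n - 1)
cov_code = covariance(f_s)

# The mean of squared entries equals the squared Frobenius norm divided by d^2
diff = covariance(f_s) - covariance(f_t)
cov_diff_mean = diff.pow(2).mean()
frob_sq = torch.linalg.norm(diff, ord="fro") ** 2

print(loss.item(), same.item())
```

Under these assumptions, the covariance term computed by the code corresponds to the docstring's :math:`l_{CORAL}` up to a constant factor (:math:`1/d^2` rather than :math:`1/(4d^2)`), and the returned value additionally includes the mean-difference term.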


