"""Source code for tllib.modules.kernels.

@author: Junguang Jiang
"""
from typing import Optional
import torch
import torch.nn as nn

# Public API: only the kernel class is exported by ``from ... import *``.
__all__ = ['GaussianKernel']

class GaussianKernel(nn.Module):
    r"""Gaussian Kernel Matrix

    Gaussian Kernel k is defined by

    .. math::
        k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)

    where :math:`x_1, x_2 \in R^d` are 1-d tensors.

    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`

    .. math::
        K(X)_{i,j} = k(x_i, x_j)

    By default (``track_running_stats=True``), :math:`\sigma^2` is re-estimated from
    the current input on every forward call, in both training and evaluation mode, as
    :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`, where :math:`n`
    is the minibatch size. This is a per-batch estimate, not a moving average across
    batches. The value used by the most recent call is kept in ``self.sigma_square``.
    If every pairwise distance in the batch is zero (e.g. a single sample or identical
    features), the mean distance is replaced by 1 so that the kernel is the well-defined
    all-ones matrix instead of NaN.

    If :attr:`track_running_stats` is set to ``False``, no estimation is done and the
    fixed :math:`\sigma` given at construction is used instead.

    Args:
        sigma (float, optional): bandwidth :math:`\sigma`. Required when
            ``track_running_stats`` is ``False``; otherwise it is overwritten by the
            estimate on the first forward call. Default: None
        track_running_stats (bool, optional): If ``True``, this module estimates
            :math:`\sigma^2` from each input batch. Otherwise, it won't estimate and
            always uses the fixed :math:`\sigma^2`. Default: ``True``
        alpha (float, optional): :math:`\alpha` which decides the magnitude of
            :math:`\sigma^2` when ``track_running_stats`` is set to ``True``. Default: 1.

    Inputs:
        - X (tensor): input group :math:`X`

    Shape:
        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
        - Outputs: :math:`(minibatch, minibatch)`
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        # Without estimation, the fixed bandwidth is the only source of sigma.
        assert track_running_stats or sigma is not None, \
            "sigma must be provided when track_running_stats is False"
        self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Broadcasting (1, n, F) - (n, 1, F) materializes an (n, n, F) tensor;
        # summing over the feature axis yields all pairwise squared L2 distances.
        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)

        if self.track_running_stats:
            # detach(): the bandwidth is treated as a constant, so no gradient
            # flows through the estimate of sigma^2.
            mean_distance = torch.mean(l2_distance_square.detach())
            # Distances are non-negative, so a zero mean means every pair is at
            # distance zero and the kernel would be 0 / 0 = NaN. Any positive
            # bandwidth gives the correct all-ones kernel here, so use 1 instead.
            # torch.where avoids a host-device sync that a Python `if` would force.
            mean_distance = torch.where(mean_distance == 0,
                                        torch.ones_like(mean_distance),
                                        mean_distance)
            self.sigma_square = self.alpha * mean_distance

        return torch.exp(-l2_distance_square / (2 * self.sigma_square))


