Shortcuts

Source code for tllib.translation.fourier_transform

"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import tqdm
import random
from PIL import Image
from typing import Optional, Sequence
import torch.nn as nn


[docs]def low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: Optional[int] = 1): """ Args: amp_src (numpy.ndarray): amplitude component of the Fourier transform of source image amp_trg (numpy.ndarray): amplitude component of the Fourier transform of target image beta (int, optional): the size of the center region to be replace. Default: 1 Returns: amplitude component of the Fourier transform of source image whose low-frequency component is replaced by that of the target image. """ # Shift the zero-frequency component to the center of the spectrum. a_src = np.fft.fftshift(amp_src, axes=(-2, -1)) a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1)) # The low-frequency component includes # the area where the horizontal and vertical distance from the center does not exceed beta _, h, w = a_src.shape c_h = np.floor(h / 2.0).astype(int) c_w = np.floor(w / 2.0).astype(int) h1 = c_h - beta h2 = c_h + beta + 1 w1 = c_w - beta w2 = c_w + beta + 1 # The low-frequency component of source amplitude is replaced by the target amplitude a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2] a_src = np.fft.ifftshift(a_src, axes=(-2, -1)) return a_src
[docs]class FourierTransform(nn.Module): """ Fourier Transform is introduced by `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_ Fourier Transform replace the low frequency component of the amplitude of the source image to that of the target image. Denote with :math:`M_{β}` a mask, whose value is zero except for the center region: .. math:: M_{β}(h,w) = \mathbb{1}_{(h, w)\in [-β,β, -β, β]} Given images :math:`x^s` from source domain and :math:`x^t` from target domain, the source image in the target style is .. math:: x^{s→t} = \mathcal{F}^{-1}([ M_{β}\circ\mathcal{F}^A(x^t) + (1-M_{β})\circ\mathcal{F}^A(x^s), \mathcal{F}^P(x^s) ]) where :math:`\mathcal{F}^A`, :math:`\mathcal{F}^P` are the amplitude and phase component of the Fourier Transform :math:`\mathcal{F}` of an RGB image. Args: image_list (sequence[str]): A sequence of image list from the target domain. amplitude_dir (str): Specifies the directory to put the amplitude component of the target image. beta (int, optional): :math:`β`. Default: 1. rebuild (bool, optional): whether rebuild the amplitude component of the target image in the given directory. Inputs: - image (PIL Image): image from the source domain, :math:`x^t`. Examples: >>> from tllib.translation.fourier_transform import FourierTransform >>> image_list = ["target_image_path1", "target_image_path2"] >>> amplitude_dir = "path/to/amplitude_dir" >>> fourier_transform = FourierTransform(image_list, amplitude_dir, beta=1, rebuild=False) >>> source_image = np.array((256, 256, 3)) # image form source domain >>> source_image_in_target_style = fourier_transform(source_image) .. note:: The meaning of :math:`β` is different from that of the origin paper. Experimentally, we found that the size of the center region in the frequency space should be constant when the image size increases. Thus we make the size of the center region independent of the image size. A recommended value for :math:`β` is 1. .. note:: The image structure of the source domain and target domain should be as similar as possible, thus for segemntation tasks, FourierTransform should be used before RandomResizeCrop and other transformations. .. note:: The image size of the source domain and the target domain need to be the same, thus before FourierTransform, you should use Resize to convert the source image to the target image size. Examples: >>> from tllib.translation.fourier_transform import FourierTransform >>> import tllibvision.datasets.segmentation.transforms as T >>> from PIL import Image >>> target_image_list = ["target_image_path1", "target_image_path2"] >>> amplitude_dir = "path/to/amplitude_dir" >>> # build a fourier transform that translate source images to the target style >>> fourier_transform = T.wrapper(FourierTransform)(target_image_list, amplitude_dir) >>> transforms=T.Compose([ ... # convert source image to the size of the target image before fourier transform ... T.Resize((2048, 1024)), ... fourier_transform, ... T.RandomResizedCrop((1024, 512)), ... T.RandomHorizontalFlip(), ... ]) >>> source_image = Image.open("path/to/source_image") # image form source domain >>> source_image_in_target_style = transforms(source_image) """ # TODO add image examples when beta is different def __init__(self, image_list: Sequence[str], amplitude_dir: str, beta: Optional[int] = 1, rebuild: Optional[bool] = False): super(FourierTransform, self).__init__() self.amplitude_dir = amplitude_dir if not os.path.exists(amplitude_dir) or rebuild: os.makedirs(amplitude_dir, exist_ok=True) self.build_amplitude(image_list, amplitude_dir) self.beta = beta self.length = len(image_list) @staticmethod def build_amplitude(image_list, amplitude_dir): # extract amplitudes from target domain for i, image_name in enumerate(tqdm.tqdm(image_list)): image = Image.open(image_name).convert('RGB') image = np.asarray(image, np.float32) image = image.transpose((2, 0, 1)) fft = np.fft.fft2(image, axes=(-2, -1)) amp = np.abs(fft) np.save(os.path.join(amplitude_dir, "{}.npy".format(i)), amp) def forward(self, image): # randomly sample a target image and load its amplitude component amp_trg = np.load(os.path.join(self.amplitude_dir, "{}.npy".format(random.randint(0, self.length-1)))) image = np.asarray(image, np.float32) image = image.transpose((2, 0, 1)) # get fft, amplitude on source domain fft_src = np.fft.fft2(image, axes=(-2, -1)) amp_src, pha_src = np.abs(fft_src), np.angle(fft_src) # mutate the amplitude part of source with target amp_src_ = low_freq_mutate(amp_src, amp_trg, beta=self.beta) # mutated fft of source fft_src_ = amp_src_ * np.exp(1j * pha_src) # get the mutated image src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1)) src_in_trg = np.real(src_in_trg) src_in_trg = src_in_trg.transpose((1, 2, 0)) src_in_trg = Image.fromarray(src_in_trg.clip(min=0, max=255).astype('uint8')).convert('RGB') return src_in_trg

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started

Paper List

Get started for transfer learning

View Resources