Source code for tllib.utils.analysis

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tqdm

def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module, device: torch.device,
                    max_num_features=None) -> torch.Tensor:
    r"""
    Fetch data from `data_loader`, and then use `feature_extractor` to collect features.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader.
        feature_extractor (torch.nn.Module): A feature extractor.
        device (torch.device): Device the inputs are moved to before the forward pass.
        max_num_features (int, optional): The max number of **mini-batches** to read from
            ``data_loader``. Despite its name, this caps batches, not individual feature
            vectors. ``None`` (default) reads every mini-batch.

    Returns:
        The per-batch outputs of ``feature_extractor``, moved to CPU and concatenated along
        dim 0. They come from the first ``min(len(data_loader), max_num_features)``
        mini-batches, so the leading dimension is the total number of samples in those
        batches. When the extractor returns ``(batch, |F|)`` tensors, the result has shape
        (that sample count, :math:`|\mathcal{F}|`).

    Raises:
        RuntimeError: if no mini-batch was read (empty ``data_loader`` or
            ``max_num_features <= 0``).

    .. note::
        ``feature_extractor`` is switched to eval mode and is left in eval mode on return.
    """
    feature_extractor.eval()
    all_features = []
    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(data_loader)):
            # `i` counts mini-batches, so `max_num_features` limits the number of batches.
            if max_num_features is not None and i >= max_num_features:
                break
            # Assumes each batch is indexable with the model input first, e.g. (inputs, labels).
            # TODO confirm for every loader passed in.
            inputs = data[0].to(device)
            # Move each batch's features to CPU so the accumulated list does not stay on `device`.
            feature = feature_extractor(inputs).cpu()
            all_features.append(feature)
    if not all_features:
        # torch.cat on an empty list fails with an opaque message; explain the cause instead.
        raise RuntimeError("collect_feature: no mini-batches were read from data_loader "
                           "(empty loader or max_num_features <= 0)")
    return torch.cat(all_features, dim=0)


Access comprehensive documentation for Transfer Learning Library

View Docs


Get started with the Transfer Learning Library

Get Started

Paper List

Get started with transfer learning

View Resources