Source code for distil.active_learning_strategies.core_set

import torch

from torch.utils.data import Dataset
from .strategy import Strategy

class CoreSet(Strategy):

    """
    Implementation of CoreSet :footcite:`sener2018active` Strategy. A diversity-based
    approach using coreset selection. The embedding of each example is computed by the network's
    penultimate layer, and the samples at each round are selected using a greedy furthest-first
    traversal conditioned on all labeled examples.

    Parameters
    ----------
    labeled_dataset: torch.utils.data.Dataset
        The labeled training dataset
    unlabeled_dataset: torch.utils.data.Dataset
        The unlabeled pool dataset
    net: torch.nn.Module
        The deep model to use
    nclasses: int
        Number of unique values for the target
    args: dict
        Specify additional parameters

        - **batch_size**: The batch size used internally for torch.utils.data.DataLoader objects. (int, optional)
        - **device**: The device to be used for computation. PyTorch constructs are transferred to this device. Usually is one of 'cuda' or 'cpu'. (string, optional)
        - **loss**: The loss function to be used in computations. (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional)
    """

    def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args={}):

        super(CoreSet, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args)

    def furthest_first(self, unlabeled_embeddings, labeled_embeddings, n):

        unlabeled_embeddings = unlabeled_embeddings.to(self.device)
        labeled_embeddings = labeled_embeddings.to(self.device)

        m = unlabeled_embeddings.shape[0]
        if labeled_embeddings.shape[0] == 0:
            # No labeled points yet: treat every unlabeled point as infinitely far away.
            # torch.full is used here; torch.tile cannot construct a tensor from a Python float.
            min_dist = torch.full((m,), float("inf"), device=self.device)
        else:
            # Distance from each unlabeled point to its nearest labeled point
            dist_ctr = torch.cdist(unlabeled_embeddings, labeled_embeddings, p=2)
            min_dist = torch.min(dist_ctr, dim=1)[0]

        idxs = []
        for i in range(n):
            # Greedily pick the point furthest from all labeled and previously selected points
            idx = torch.argmax(min_dist)
            idxs.append(idx.item())
            dist_new_ctr = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[[idx], :])
            min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])

        return idxs
    def select(self, budget):
        """
        Selects next set of points

        Parameters
        ----------
        budget: int
            Number of data points to select for labeling

        Returns
        ----------
        idxs: list
            List of selected data point indices with respect to unlabeled_dataset
        """

        class NoLabelDataset(Dataset):

            def __init__(self, wrapped_dataset):
                self.wrapped_dataset = wrapped_dataset

            def __getitem__(self, index):
                instance = self.wrapped_dataset[index]
                if type(instance) == dict:
                    if "labels" in instance:
                        del instance["labels"]
                    features = instance
                else:
                    features = instance[0]
                return features

            def __len__(self):
                return len(self.wrapped_dataset)

        self.model.eval()
        embedding_unlabeled = self.get_embedding(self.unlabeled_dataset)
        embedding_labeled = self.get_embedding(NoLabelDataset(self.labeled_dataset))
        chosen = self.furthest_first(embedding_unlabeled, embedding_labeled, budget)

        return chosen
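A minimal usage sketch follows. It is not part of the module above and assumes the model exposes its penultimate-layer embedding in the style DISTIL strategies rely on (a forward pass that can return (logits, embedding) and a get_embedding_dim method); the SimpleNet class, FeatureOnlyDataset wrapper, tensor shapes, and args values are illustrative placeholders, not names from this module.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset

from distil.active_learning_strategies.core_set import CoreSet

class SimpleNet(nn.Module):
    # Hypothetical two-layer model; forward(last=True) also returns the
    # penultimate-layer embedding that CoreSet uses for its distance computations.
    def __init__(self, in_dim=20, hidden_dim=16, n_classes=3):
        super().__init__()
        self.hidden = nn.Linear(in_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, n_classes)

    def forward(self, x, last=False, freeze=False):
        emb = torch.relu(self.hidden(x))
        logits = self.out(emb)
        if last:
            return logits, emb
        return logits

    def get_embedding_dim(self):
        return self.out.in_features

class FeatureOnlyDataset(Dataset):
    # Unlabeled pool wrapper: yields features only, as CoreSet.select expects.
    def __init__(self, features):
        self.features = features

    def __getitem__(self, index):
        return self.features[index]

    def __len__(self):
        return len(self.features)

# Toy labeled and unlabeled pools
labeled = TensorDataset(torch.randn(30, 20), torch.randint(0, 3, (30,)))
unlabeled = FeatureOnlyDataset(torch.randn(200, 20))

strategy = CoreSet(labeled, unlabeled, SimpleNet(), nclasses=3,
                   args={'batch_size': 32, 'device': 'cpu'})
selected_idxs = strategy.select(budget=10)  # indices into the unlabeled pool

The selected indices would then be labeled by an oracle and moved from the unlabeled pool into the labeled set before the next round of training and selection.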