import torch
from torch.utils.data import Dataset
from .strategy import Strategy
[docs]class CoreSet(Strategy):
"""
Implementation of CoreSet :footcite:`sener2018active` Strategy. A diversity-based
approach using coreset selection. The embedding of each example is computed by the network’s
penultimate layer and the samples at each round are selected using a greedy furthest-first
traversal conditioned on all labeled examples.
Parameters
----------
labeled_dataset: torch.utils.data.Dataset
The labeled training dataset
unlabeled_dataset: torch.utils.data.Dataset
The unlabeled pool dataset
net: torch.nn.Module
The deep model to use
nclasses: int
Number of unique values for the target
args: dict
Specify additional parameters
- **batch_size**: The batch size used internally for torch.utils.data.DataLoader objects. (int, optional)
- **device**: The device to be used for computation. PyTorch constructs are transferred to this device. Usually is one of 'cuda' or 'cpu'. (string, optional)
- **loss**: The loss function to be used in computations. (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional)
"""
def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args={}):
super(CoreSet, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args)
def furthest_first(self, unlabeled_embeddings, labeled_embeddings, n):
unlabeled_embeddings = unlabeled_embeddings.to(self.device)
labeled_embeddings = labeled_embeddings.to(self.device)
m = unlabeled_embeddings.shape[0]
if labeled_embeddings.shape[0] == 0:
min_dist = torch.tile(float("inf"), m)
else:
dist_ctr = torch.cdist(unlabeled_embeddings, labeled_embeddings, p=2)
min_dist = torch.min(dist_ctr, dim=1)[0]
idxs = []
for i in range(n):
idx = torch.argmax(min_dist)
idxs.append(idx.item())
dist_new_ctr = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[[idx],:])
min_dist = torch.minimum(min_dist, dist_new_ctr[:,0])
return idxs
[docs] def select(self, budget):
"""
Selects next set of points
Parameters
----------
budget: int
Number of data points to select for labeling
Returns
----------
idxs: list
List of selected data point indices with respect to unlabeled_dataset
"""
class NoLabelDataset(Dataset):
def __init__(self, wrapped_dataset):
self.wrapped_dataset = wrapped_dataset
def __getitem__(self, index):
instance = self.wrapped_dataset[index]
if type(instance) == dict:
if "labels" in instance:
del instance["labels"]
features = instance
else:
features = instance[0]
return features
def __len__(self):
return len(self.wrapped_dataset)
self.model.eval()
embedding_unlabeled = self.get_embedding(self.unlabeled_dataset)
embedding_labeled = self.get_embedding(NoLabelDataset(self.labeled_dataset))
chosen = self.furthest_first(embedding_unlabeled, embedding_labeled, budget)
return chosen