Source code for distil.active_learning_strategies.glister

  
from .strategy import Strategy
import numpy as np

import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader, ConcatDataset, Dataset

import math

def dict_to(dictionary, device):
    """
    Move every transferable value of a dictionary to ``device``, in place.

    A value is considered transferable when it exposes a ``to`` attribute
    (as torch tensors and modules do); it is replaced by the result of
    ``value.to(device=device)``. All other values are left untouched.

    Parameters
    ----------
    dictionary: dict
        Mapping whose values should be moved. Dict subclasses are accepted.
        Any non-dict argument is returned unchanged.
    device: object
        Target device, forwarded as the ``device`` keyword of ``value.to``.

    Returns
    ----------
    dictionary: dict
        The very same object that was passed in (updated in place).
    """
    # isinstance (rather than an exact type comparison) so that dict
    # subclasses such as OrderedDict are handled as well.
    if isinstance(dictionary, dict):
        # Only existing keys are re-assigned, so the dictionary size does not
        # change while it is being iterated.
        for key, value in dictionary.items():
            if hasattr(value, "to"):
                dictionary[key] = value.to(device=device)

    return dictionary

class GLISTER(Strategy):
    """
    This is an implementation of GLISTER-ACTIVE from the paper GLISTER: Generalization based Data Subset Selection for
    Efficient and Robust Learning :footcite:`killamsetty2020glister`. GLISTER methods try to solve a bi-level
    optimisation problem.

    .. math::
        \\overbrace{\\underset{{S \\subseteq {\\mathcal U}, |S| \\leq k}}{\\operatorname{argmin\\hspace{0.7mm}}} L_V(\\underbrace{\\underset{\\theta}{\\operatorname{argmin\\hspace{0.7mm}}} L_T( \\theta, S)}_{inner-level}, {\\mathcal V})}^{outer-level}

    In the above equation, :math:`\\mathcal{U}` denotes the data without labels i.e. `unlabeled_x`,
    :math:`\\mathcal{V}` denotes the validation set that guides the subset selection process, :math:`L_T` denotes the
    training loss, :math:`L_V` denotes the validation loss, :math:`S` denotes the data subset selected at each round,
    and :math:`k` is the `budget`.
    Since solving the complete inner-optimization is expensive, GLISTER-ONLINE adopts an online one-step meta
    approximation where we approximate the solution to the inner problem by taking a single gradient step.
    The optimization problem after the approximation is as follows:

    .. math::
        \\overbrace{\\underset{{S \\subseteq {\\mathcal U}, |S| \\leq k}}{\\operatorname{argmin\\hspace{0.7mm}}} L_V(\\underbrace{\\theta - \\eta \\nabla_{\\theta}L_T(\\theta, S)}_{inner-level}, {\\mathcal V})}^{outer-level}

    In the above equation, :math:`\\eta` denotes the step-size used for the one-step gradient update.

    Parameters
    ----------
    labeled_dataset: torch.utils.data.Dataset
        The labeled training dataset
    unlabeled_dataset: torch.utils.data.Dataset
        The unlabeled pool dataset
    net: torch.nn.Module
        The deep model to use
    nclasses: int
        Number of unique values for the target
    args: dict
        Specify additional parameters

        - **batch_size**: The batch size used internally for torch.utils.data.DataLoader objects. (int, optional)
        - **device**: The device to be used for computation. PyTorch constructs are transferred to this device. Usually is one of 'cuda' or 'cpu'. (string, optional)
        - **loss**: The loss function to be used in computations. (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional)
        - **lr**: The learning rate used for training (float)
    validation_dataset: torch.utils.data.Dataset
        The validation dataset to be used in the GLISTER objective
    typeOf: str, optional
        Determines the type of regulariser to be used. Default is **'none'**.
        For a random regulariser use **'Rand'**.
        To use the Facility Location set function as a regulariser use **'FacLoc'**.
        To use the Diversity set function as a regulariser use **'Diversity'**.
    lam: float, optional
        Determines the amount of regularisation to be applied. Mandatory if `typeOf` is not `'none'`; set to
        `None` by default.
        For the random regulariser, values should be between 0 and 1 as it determines the fraction of points
        replaced by random points.
        For both 'Diversity' and 'FacLoc', `lam` determines the weightage given to them while computing the gain.
    kernel_batch_size: int, optional
        For the 'Diversity' and 'FacLoc' regulariser versions, a similarity kernel is to be computed, which entails
        creating a 3d torch tensor of dimensions kernel_batch_size*kernel_batch_size*feature dimension.
        kernel_batch_size should be such that one can exploit the benefits of tensorization while honouring the
        resource constraints.
    """

    def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args=None, validation_dataset=None,
                 typeOf='none', lam=None, kernel_batch_size=200):
        # A fresh dict per instance: a shared mutable default would be handed
        # to the base class for every instance created without explicit args.
        if args is None:
            args = {}

        super(GLISTER, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args)

        self.validation_dataset = validation_dataset
        self.typeOf = typeOf
        self.lam = lam
        self.kernel_batch_size = kernel_batch_size

    def distance(self, x, y, exp=2):
        """
        Pairwise kernel between the rows of ``x`` (n x d) and ``y`` (m x d).

        For 'FacLoc' the entry (i, j) is ``sum_k (x_ik - y_jk) ** exp``; for
        'Diversity' it is ``exp(-0.5 * sum_k (x_ik - y_jk) ** exp)``.

        Raises
        ----------
        ValueError
            If ``typeOf`` is neither 'FacLoc' nor 'Diversity'.
        """
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)

        # Broadcast both operands to (n, m, d) so that all pairs are formed.
        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)

        pairwise = torch.pow(x - y, exp).sum(2)

        if self.typeOf == "FacLoc":
            return pairwise
        if self.typeOf == "Diversity":
            return torch.exp((-1 * pairwise) / 2)

        raise ValueError("distance is only defined for the 'FacLoc' and 'Diversity' regularisers")

    def _compute_similarity_kernel(self):
        """
        Fill ``self.sim_mat`` with the pairwise kernel of ``self.grads_per_elem``,
        computed in tiles of ``kernel_batch_size`` rows/columns. For 'FacLoc',
        ``self.min_dist`` is additionally initialised to the largest kernel entry.
        """
        step = self.kernel_batch_size
        new_N = len(self.grads_per_elem)

        # (offset, chunk) pairs; the last chunk may be shorter than `step`.
        chunks = [(chunk_id * step, self.grads_per_elem[chunk_id * step:(chunk_id + 1) * step])
                  for chunk_id in range(math.ceil(new_N / step))]

        with torch.no_grad():
            self.sim_mat = torch.zeros([new_N, new_N], dtype=torch.float32).to(self.device)

            for row_start, g_i in chunks:
                for col_start, g_j in chunks:
                    self.sim_mat[row_start:row_start + g_i.size(0),
                                 col_start:col_start + g_j.size(0)] = self.distance(g_i, g_j)

            if self.typeOf == "FacLoc":
                const = torch.max(self.sim_mat).item()
                self.min_dist = (torch.ones(new_N, dtype=torch.float32) * const).to(self.device)

    def _compute_per_element_grads(self):
        """
        Cache the per-element gradient embeddings of the unlabeled pool
        (``self.grads_per_elem``) and the summed gradient embedding of the
        labeled set as a single row (``self.prev_grads_sum``).
        """
        self.grads_per_elem = self.get_grad_embedding(self.unlabeled_dataset, True)
        self.prev_grads_sum = torch.sum(self.get_grad_embedding(self.labeled_dataset, False), dim=0).view(1, -1)

    def _procure_labels(self, input_dataset):
        """
        Collect the labels of ``input_dataset`` into a single tensor, in dataset order.

        Dictionary-type items are expected to carry their label under the
        "labels" key; otherwise items are expected to be (input, label) pairs.
        """
        loader = DataLoader(input_dataset, shuffle=False, batch_size=self.args['batch_size'])

        # If the input is a dictionary type, procure the labels by indexing the batch
        if isinstance(input_dataset[0], dict):
            label_batches = [dict_batch["labels"] for dict_batch in loader]
        else:
            label_batches = [batch_labels for _, batch_labels in loader]

        return torch.cat(label_batches) if label_batches else None

    def _learning_rate(self):
        """Return ``args['lr']``; raise ValueError when it was not supplied."""
        try:
            return self.args['lr']
        except KeyError as err:
            raise ValueError("Please pass learning rate used during the training") from err

    def _per_sample_loss_grads(self, logits, labels, embeddings, embDim):
        """
        Per-sample gradient rows built from softmax(logits) - one_hot(labels):
        the first ``target_classes`` columns are that difference, the remaining
        columns are the difference expanded and multiplied with the embeddings.
        """
        scores = F.softmax(logits, dim=1)
        one_hot_label = torch.zeros(labels.shape[0], self.target_classes).to(self.device)
        one_hot_label.scatter_(1, labels.view(-1, 1), 1)

        l0_grads = scores - one_hot_label
        l0_expand = torch.repeat_interleave(l0_grads, embDim, dim=1)
        l1_grads = l0_expand * embeddings.repeat(1, self.target_classes)

        return torch.cat((l0_grads, l1_grads), dim=1)

    def _update_grads_val(self, grads_currX=None, first_init=False):
        """
        Maintain ``self.grads_val_curr``, the mean gradient on the guiding set
        after a one-step update of the last layer.

        With ``first_init=True`` the guiding set is built (the validation
        dataset if given, otherwise the pseudo-labelled pool concatenated with
        the labeled set), its logits/embeddings are cached in ``self.out`` /
        ``self.emb`` with one step along ``self.prev_grads_sum`` applied, and
        the mean gradient is computed. Otherwise, when ``grads_currX`` is
        given, the cached logits take one more step along ``grads_currX`` and
        the mean gradient is recomputed.

        Raises
        ----------
        ValueError
            If ``args`` carries no 'lr' entry.
        """
        embDim = self.model.get_embedding_dim()
        n_classes = self.target_classes

        if first_init:
            lr = self._learning_rate()

            if self.validation_dataset is not None:
                guiding_dataset = self.validation_dataset
            else:
                class AddLabelDataset(Dataset):
                    """Pairs every item of an unlabeled dataset with a given label."""

                    def __init__(self, wrapped_unlabeled_dataset, added_labels):
                        self.wrapped_unlabeled_dataset = wrapped_unlabeled_dataset
                        self.added_labels = added_labels

                    def __getitem__(self, index):
                        unlabeled_data = self.wrapped_unlabeled_dataset[index]
                        label = self.added_labels[index]
                        return unlabeled_data, label

                    def __len__(self):
                        return len(self.wrapped_unlabeled_dataset)

                class AddLabelDictDataset(Dataset):
                    """Adds a "labels" entry to every dictionary item of an unlabeled dataset."""

                    def __init__(self, wrapped_unlabeled_dataset, added_labels):
                        self.wrapped_unlabeled_dataset = wrapped_unlabeled_dataset
                        self.added_labels = added_labels

                    def __getitem__(self, index):
                        # Shallow copy: the wrapped dataset's own item must not
                        # gain a "labels" key as a side effect.
                        new_labeled_data = dict(self.wrapped_unlabeled_dataset[index])
                        new_labeled_data["labels"] = self.added_labels[index]
                        return new_labeled_data

                    def __len__(self):
                        return len(self.wrapped_unlabeled_dataset)

                # Prepare the "new" dataset differently, depending on the type of the input
                predicted_y = self.predict(self.unlabeled_dataset).cpu()  # Bring to CPU as the loaders used require it
                if isinstance(self.unlabeled_dataset[0], dict):
                    pseudolabeled_dataset = AddLabelDictDataset(self.unlabeled_dataset, predicted_y)
                else:
                    pseudolabeled_dataset = AddLabelDataset(self.unlabeled_dataset, predicted_y)

                self.new_dataset = ConcatDataset([pseudolabeled_dataset, self.labeled_dataset])
                guiding_dataset = self.new_dataset

            loader = DataLoader(guiding_dataset, shuffle=False, batch_size=self.args['batch_size'])
            self.out = torch.zeros(len(guiding_dataset), n_classes).to(self.device)
            self.emb = torch.zeros(len(guiding_dataset), embDim).to(self.device)
            self.grads_val_curr = torch.zeros(n_classes * (1 + embDim), 1).to(self.device)

            prev_grads = self.prev_grads_sum[0]
            evaluated_points = 0

            with torch.no_grad():
                for loaded_instance in loader:
                    if isinstance(loaded_instance, dict):
                        # Per our convention, we expect labels in dictionary-type inputs to be in "labels" field
                        y = loaded_instance["labels"]
                        del loaded_instance["labels"]
                        x = loaded_instance
                    else:
                        x = loaded_instance[0]
                        y = loaded_instance[1]

                    # Rows of self.out / self.emb that belong to this batch.
                    rows = slice(evaluated_points, evaluated_points + y.shape[0])

                    if isinstance(x, dict):
                        x = dict_to(x, self.device)
                        init_out, init_l1 = self.model(**x, last=True)
                    else:
                        x = x.to(self.device)
                        init_out, init_l1 = self.model(x, last=True)

                    y = y.to(self.device)
                    self.emb[rows] = init_l1

                    # One gradient step on the last layer: the first n_classes
                    # entries of the gradient row are used as the per-class
                    # offset, followed by one embDim-sized block per class.
                    for j in range(n_classes):
                        weight_grad = prev_grads[(j * embDim) + n_classes:((j + 1) * embDim) + n_classes].view(-1, 1)
                        self.out[rows, j] = init_out[:, j] - (
                            lr * (torch.matmul(init_l1, weight_grad) + prev_grads[j])).view(-1)

                    batch_grads = self._per_sample_loss_grads(self.out[rows], y, init_l1, embDim)
                    self.grads_val_curr += batch_grads.sum(dim=0).view(-1, 1)
                    evaluated_points += y.shape[0]

            if self.validation_dataset is not None:
                self.grads_val_curr /= len(self.validation_dataset)
                self.Y_Val = self._procure_labels(self.validation_dataset)
                self.Y_Val = self.Y_Val.to(self.device)
            else:
                # The sum above runs over the whole concatenated dataset
                # (pseudo-labelled pool + labeled set), so the mean has to be
                # taken over its full length - consistent with the `.mean`
                # used by the incremental update below.
                self.grads_val_curr /= len(self.new_dataset)
                self.Y_new = self._procure_labels(self.new_dataset)
                self.Y_new = self.Y_new.to(self.device)

        elif grads_currX is not None:
            # update params:
            lr = self._learning_rate()
            step_grads = grads_currX[0]

            with torch.no_grad():
                for j in range(n_classes):
                    weight_grad = step_grads[(j * embDim) + n_classes:((j + 1) * embDim) + n_classes].view(-1, 1)
                    self.out[:, j] = self.out[:, j] - (
                        lr * (torch.matmul(self.emb, weight_grad) + step_grads[j])).view(-1)

                labels = self.Y_Val if self.validation_dataset is not None else self.Y_new
                all_grads = self._per_sample_loss_grads(self.out, labels, self.emb, embDim)
                self.grads_val_curr = all_grads.mean(dim=0).view(-1, 1)

    def eval_taylor_modular(self, grads, greedySet=None, remset=None):
        """
        Score candidate points: the product of their gradient embeddings with
        ``self.grads_val_curr``, plus the regulariser term selected by ``typeOf``.

        Parameters
        ----------
        grads: torch.Tensor
            Gradient embeddings of the candidates, one row per candidate
        greedySet: list, optional
            Indices selected so far (used by the 'Diversity' regulariser)
        remset: list, optional
            Indices of the candidates (used by 'FacLoc' and 'Diversity')

        Returns
        ----------
        gains: torch.Tensor
            Column vector with one gain per candidate
        """
        with torch.no_grad():
            gains = torch.matmul(grads, self.grads_val_curr)

            if self.typeOf == "FacLoc":
                coverage = (self.min_dist - torch.min(self.min_dist, self.sim_mat[remset])).sum(1)
                gains = gains + self.lam * coverage.view(-1, 1).to(self.device)
            elif self.typeOf == "Diversity" and greedySet:
                # Nothing to penalise while the greedy set is still empty (or not given).
                similarity = self.sim_mat[remset][:, greedySet].sum(1)
                gains = gains - self.lam * similarity.view(-1, 1).to(self.device)

        return gains

    def select(self, budget):
        """
        Selects next set of points

        Parameters
        ----------
        budget: int
            Number of data points to select for labeling

        Returns
        ----------
        idxs: list
            List of selected data point indices with respect to unlabeled_dataset

        Raises
        ----------
        ValueError
            If a regulariser is requested without a suitable `lam`.
        """
        # Validate the regulariser settings before any expensive computation.
        if self.typeOf == 'Rand':
            if self.lam is None:
                raise ValueError("Please pass a appropriate lambda value for random regularisation")
            if not (self.lam > 0 and self.lam < 1):
                raise ValueError("Lambda value should be between 0 and 1")
        elif self.typeOf == "FacLoc" and self.lam is None:
            raise ValueError("Please pass a appropriate lambda value for Facility Location based regularisation")
        elif self.typeOf == "Diversity" and self.lam is None:
            raise ValueError("Please pass a appropriate lambda value for Diversity based regularisation")

        # More points than the pool holds cannot be selected.
        budget = min(budget, len(self.unlabeled_dataset))

        self.model.eval()
        self._compute_per_element_grads()
        self._update_grads_val(first_init=True)

        greedySet = list()
        remainSet = list(range(len(self.unlabeled_dataset)))

        if self.typeOf == 'Rand':
            # Number of greedily chosen points; the random top-up below fills
            # the rest so that exactly `budget` points are returned.
            greedy_budget = min(math.ceil((1 - self.lam) * budget), budget)
        else:
            greedy_budget = budget

        if self.typeOf == "FacLoc" or self.typeOf == "Diversity":
            self._compute_similarity_kernel()

        while len(greedySet) < greedy_budget:
            if self.typeOf == "Diversity":
                gains = self.eval_taylor_modular(self.grads_per_elem[remainSet], greedySet, remainSet)
            elif self.typeOf == "FacLoc":
                gains = self.eval_taylor_modular(self.grads_per_elem[remainSet], remset=remainSet)
            else:
                gains = self.eval_taylor_modular(self.grads_per_elem[remainSet])

            # `gains` is aligned with remainSet, so the argmax position maps
            # straight back to a pool index.
            bestId = remainSet.pop(torch.argmax(gains).item())
            greedySet.append(bestId)

            self._update_grads_val(self.grads_per_elem[bestId].view(1, -1))

            if self.typeOf == "FacLoc":
                self.min_dist = torch.min(self.min_dist, self.sim_mat[bestId])

        if self.typeOf == 'Rand':
            n_random = budget - len(greedySet)
            if n_random > 0:
                greedySet.extend(np.random.choice(remainSet, size=n_random, replace=False).tolist())

        return greedySet