Source code for distil.active_learning_strategies.batch_bald

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from .strategy import Strategy
from ..utils.batch_bald.consistent_mc_dropout import ConsistentMCDropout
from ..utils.batch_bald.batchbald import get_batchbald_batch

def dict_to(dictionary, device):
    
    # Move every tensor-like value in the batch dictionary to the target device
    if isinstance(dictionary, dict):
        for key in dictionary:
            value = dictionary[key]
            if hasattr(value, "to"):
                dictionary[key] = value.to(device=device)
    
    return dictionary
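
# A minimal illustration of dict_to on a dict-style batch (the keys below are hypothetical):
#
#     batch = {"input_ids": torch.zeros(4, 16, dtype=torch.long), "meta": "raw"}
#     batch = dict_to(batch, torch.device("cpu"))  # tensor values are moved; "meta" has no .to() and is left as-is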

class BatchBALDDropout(Strategy):
    
    """
    Implementation of the BatchBALD Strategy :footcite:`kirsch2019batchbald`, which refines the
    original BALD acquisition to the batch setting using a new acquisition function.
    This class extends :class:`active_learning_strategies.strategy.Strategy`
    to include an MC sampling technique based on the sampling techniques used in their paper.
    
    Parameters
    ----------
    labeled_dataset: torch.utils.data.Dataset
        The labeled training dataset
    unlabeled_dataset: torch.utils.data.Dataset
        The unlabeled pool dataset
    net: torch.nn.Module
        The deep model to use
    nclasses: int
        Number of unique values for the target
    args: dict
        Specify additional parameters
        
        - **batch_size**: The batch size used internally for torch.utils.data.DataLoader objects. (int, optional)
        - **device**: The device to be used for computation. PyTorch constructs are transferred to this device. Usually is one of 'cuda' or 'cpu'. (string, optional)
        - **loss**: The loss function to be used in computations. (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional)
        - **n_drop**: Number of dropout runs to use to generate MC samples (int, optional)
        - **n_samples**: Number of samples to use in computing joint entropy (int, optional)
        - **mod_inject**: Name of the model attribute holding the final linear layer, into which MC dropout is injected. Defaults to 'linear'. (string, optional)
    """
    
    def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args={}):
        
        """
        Constructor method
        """
        
        if 'n_drop' in args:
            self.n_drop = args['n_drop']
        else:
            self.n_drop = 40
        
        if 'n_samples' in args:
            self.n_samples = args['n_samples']
        else:
            self.n_samples = 1000
        
        if 'mod_inject' in args:
            self.mod_inject = args['mod_inject']
        else:
            self.mod_inject = 'linear'
        
        # Forward the supplied args to the base Strategy so that options such as
        # batch_size and device are honored.
        super(BatchBALDDropout, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args=args)
    
    def do_MC_dropout_before_linear(self, unlabeled_dataset, n_drop):
        
        # Procure a loader on the supplied dataset
        loader_te = DataLoader(unlabeled_dataset, shuffle=False, batch_size=self.args['batch_size'])
        
        # Check that there is a linear layer attribute in the supplied model
        try:
            getattr(self.model, self.mod_inject)
        except AttributeError:
            raise ValueError(f"Model does not have attribute {self.mod_inject} as the last layer")
        
        # Store the linear layer in a temporary variable
        lin_layer_temp = getattr(self.model, self.mod_inject)
        
        # Inject dropout into the model by using the ConsistentMCDropout module from the BatchBALD repo
        dropout_module = ConsistentMCDropout()
        dropout_injection = torch.nn.Sequential(dropout_module, lin_layer_temp)
        setattr(self.model, self.mod_inject, dropout_injection)
        
        # Make sure that the model is in eval mode
        self.model.eval()
        
        # For safety, explicitly set the dropout module to be in evaluation mode
        dropout_module.train(mode=False)
        
        # Create a tensor that will store the probabilities
        probs = torch.zeros([n_drop, len(unlabeled_dataset), self.target_classes]).to(self.device)
        with torch.no_grad():
            for i in range(n_drop):
                evaluated_points = 0
                
                # In the original BatchBALD code, inference samples were predicted in a single forward pass via an additional forward parameter.
                # Hence, only 1 mask needed to be generated during eval time for consistent MC sampling (as there was only 1 pass). Here,
                # our models do not assume this forward parameter. Hence, we must have a different generated mask for each PASS of the
                # dataset. Note, however, that the mask is CONSISTENT within a pass, which is functionally equivalent to the original
                # BatchBALD code.
                dropout_module.reset_mask()
                
                for x in loader_te:
                    if isinstance(x, dict):
                        x = dict_to(x, self.device)
                        out = self.model(**x)
                    else:
                        x = x.to(self.device)
                        out = self.model(x)
                    idxs = [iter_index for iter_index in range(evaluated_points, evaluated_points + out.shape[0])]
                    probs[i][idxs] = F.softmax(out, dim=1)
                    evaluated_points += out.shape[0]
        
        # Transpose the probs to match the BatchBALD interface
        probs = probs.permute(1, 0, 2)
        
        # Restore the linear layer
        setattr(self.model, self.mod_inject, lin_layer_temp)
        
        # Return the MC samples for each data instance
        return probs
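
    # Shape note: after the permute above, do_MC_dropout_before_linear returns a tensor laid out as
    # [num_unlabeled_points, n_drop, num_classes], i.e. one row of MC dropout samples per data point.
    # select() takes the elementwise log of this tensor before handing it to get_batchbald_batch,
    # which consumes log-probabilities in this points-by-samples-by-classes layout.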
    def select(self, budget):
        
        """
        Selects next set of points
        
        Parameters
        ----------
        budget: int
            Number of data points to select for labeling
        
        Returns
        ----------
        idxs: list
            List of selected data point indices with respect to unlabeled_dataset
        """
        
        # Move the model to the correct device
        self.model = self.model.to(self.device)
        
        # Get the MC samples from the dropout-injected model
        probs = self.do_MC_dropout_before_linear(self.unlabeled_dataset, self.n_drop)
        
        # Compute the log probabilities to match the BatchBALD interface
        log_probs = torch.log(probs)
        
        # Use the BatchBALD interface to select the new points
        candidate_batchbald_batch = get_batchbald_batch(log_probs, budget, self.n_samples, device=self.device)
        
        # Return the selected indices
        return candidate_batchbald_batch.indices
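
The sketch below shows one way this strategy can be driven end to end. It is illustrative only: the two-layer model, the toy datasets, and all sizes are hypothetical; the only hard requirement it demonstrates is that the model's final layer be reachable under the attribute named by mod_inject ('linear' by default).

# A minimal usage sketch (hypothetical model, datasets, and sizes).
# When used outside this module: from distil.active_learning_strategies.batch_bald import BatchBALDDropout

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset

class TwoLayerNet(nn.Module):
    def __init__(self, in_dim=20, hidden=64, num_classes=3):
        super().__init__()
        self.feature = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.linear = nn.Linear(hidden, num_classes)  # last layer, matches mod_inject='linear'

    def forward(self, x):
        return self.linear(self.feature(x))

class UnlabeledPool(Dataset):
    # Yields bare feature tensors, the non-dict batch format handled in do_MC_dropout_before_linear
    def __init__(self, X):
        self.X = X
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        return self.X[i]

labeled_dataset = TensorDataset(torch.randn(100, 20), torch.randint(0, 3, (100,)))
unlabeled_dataset = UnlabeledPool(torch.randn(500, 20))

strategy = BatchBALDDropout(
    labeled_dataset, unlabeled_dataset, TwoLayerNet(), nclasses=3,
    args={'batch_size': 64, 'device': 'cpu', 'n_drop': 10, 'n_samples': 500},
)
selected_idxs = strategy.select(budget=20)  # indices into unlabeled_dataset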