Source code for distil.active_learning_strategies.gradmatch_active

import numpy as np
import torch

from .strategy import Strategy
from scipy.optimize import nnls

class GradMatchActive(Strategy):
    
    """
    This is an implementation of an active learning variant of GradMatch from the paper GRAD-MATCH: A 
    Gradient Matching Based Data Subset Selection for Efficient Learning :footcite:`killamsetty2021grad`. 
    This algorithm solves a fixed-weight version of the error term present in the paper by a greedy selection 
    algorithm akin to the original GradMatch's Orthogonal Matching Pursuit. The gradients computed are on the 
    hypothesized labels of the loss function and are matched to either the full gradient of these hypothesized 
    examples or a supplied validation gradient. The indices returned are the ones selected by this algorithm.
    
    .. math::
        Err(X_t, L, L_T, \\theta_t) = \\left|\\left| \\sum_{i \\in X_t} \\nabla_\\theta L_T^i (\\theta_t) - \\frac{k}{N} \\nabla_\\theta L(\\theta_t) \\right|\\right|
    
    where,
    
    - Each gradient is computed with respect to the last layer's parameters
    - :math:`\\theta_t` are the model parameters at selection round :math:`t`
    - :math:`X_t` is the queried set of points to label at selection round :math:`t`
    - :math:`k` is the budget
    - :math:`N` is the number of points contributing to the full gradient :math:`\\nabla_\\theta L(\\theta_t)`
    - :math:`\\nabla_\\theta L(\\theta_t)` is either the complete hypothesized gradient or a validation gradient
    - :math:`\\sum_{i \\in X_t} \\nabla_\\theta L_T^i (\\theta_t)` is the subset's hypothesized gradient with :math:`|X_t| = k`
    
    Parameters
    ----------
    labeled_dataset: torch.utils.data.Dataset
        The labeled training dataset
    unlabeled_dataset: torch.utils.data.Dataset
        The unlabeled pool dataset
    net: torch.nn.Module
        The deep model to use
    nclasses: int
        Number of unique values for the target
    args: dict
        Specify additional parameters
        
        - **batch_size**: The batch size used internally for torch.utils.data.DataLoader objects. (int, optional)
        - **device**: The device to be used for computation. PyTorch constructs are transferred to this device. Usually is one of 'cuda' or 'cpu'. (string, optional)
        - **loss**: The loss function to be used in computations. (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional)
        - **grad_embedding**: The type of gradient embedding that should be used. (string, optional)
        - **omp_reg**: The regularization constant to use in the GradMatch objective. (optional)
    validation_dataset: torch.utils.data.Dataset, optional
        The validation dataset to use in the GradMatch objective
    """
    
    def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args={}, validation_dataset=None):
        
        # Run super constructor
        super(GradMatchActive, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args)
        
        if 'grad_embedding' not in args:
            self.grad_embedding = 'bias_linear'
        else:
            self.grad_embedding = args['grad_embedding']
        
        if 'omp_reg' not in args:
            self.omp_reg = 0
        else:
            self.omp_reg = args['omp_reg']
        
        self.validation_dataset = validation_dataset
    def fixed_weight_greedy_parallel(self, A, b, val_set_size, nnz):
        """
        Approximately solves :math:`\\min_x ||Ax - b||_2 \\text{ s.t. } x_i \\in \\{0,1\\}, ||x||_0 = nnz`
        
        Parameters
        ----------
        A: torch.Tensor
            Design matrix of size (d, n)
        b: torch.Tensor
            Measurement vector of length d
        val_set_size: int
            Number of gradient vectors summed to form b; used to rescale b before matching
        nnz: int
            Maximum number of nonzero coefficients (if None, set to n)
        
        Returns
        -------
        x: torch.Tensor
            Vector of length n
        """
        d, n = A.shape
        x = torch.zeros(n, device=self.device)  # ,dtype=torch.float64)
        
        # Calculate b * n / val_set_size
        b_k_val = (n / val_set_size) * b
        
        # Stores memoized version of Ax, where x is the vector containing elements from {0,1}
        memoized_Ax = torch.zeros(d, device=self.device)
        
        for i in range(nnz):
            # Calculate residual
            resid = memoized_Ax - b_k_val
            
            # Calculate columnwise difference with resid.
            # resid[:,None] promotes shape from (d,) to (d,1) so it broadcasts across the columns of A.
            gain_norms = (A + resid[:,None]).norm(dim=0)
            
            # Choose smallest norm from among columns of A not already chosen.
            # Such a vector represents the best greedy choice in matching to 
            # the residual.
            zero_x = torch.nonzero(x == 0)
            argmin = torch.argmin(gain_norms[zero_x])
            actual_index = zero_x[argmin].item()
            
            # Add this column to the memoized Ax. Set this weight to 1.
            x[actual_index] = 1
            memoized_Ax = memoized_Ax + A[:,actual_index]
            
            # One condition for stopping is to check if the gain_norm is larger than the residual norm.
            # However, we opt to select the full number of points specified while following this 
            # greedy behavior in order to fill the AL budget.
            """
            if gain_norms[actual_index] < resid_norm:
                x[actual_index] = 1
                memoized_Ax = memoized_Ax + A[:,actual_index]
            else:
                break
            """
        
        return x
    
    # NOTE: Standard Algorithm, e.g. Tropp, ``Greed is Good: Algorithmic Results for Sparse Approximation," IEEE Trans. Info. Theory, 2004.
    def orthogonalmp_reg_parallel(self, A, b, tol=1E-4, nnz=None, positive=False, lam=1):
        '''approximately solves min_x |x|_0 s.t. Ax=b using Orthogonal Matching Pursuit
        
        Args:
            A: design matrix of size (d, n)
            b: measurement vector of length d
            tol: solver tolerance
            nnz: maximum number of nonzero coefficients (if None, set to n)
            positive: only allow positive nonzero coefficients
        
        Returns:
            vector of length n
        '''
        AT = torch.transpose(A, 0, 1)
        d, n = A.shape
        if nnz is None:
            nnz = n
        x = torch.zeros(n, device=self.device)  # ,dtype=torch.float64)
        resid = b.detach().clone()
        normb = b.norm().item()
        indices = []
        
        for i in range(nnz):
            if resid.norm().item() / normb < tol:
                break
            projections = torch.matmul(AT, resid)
            
            if positive:
                index = torch.argmax(projections)
            else:
                index = torch.argmax(torch.abs(projections))
            
            if index in indices:
                break
            indices.append(index)
            
            if len(indices) == 1:
                A_i = A[:, index]
                x_i = projections[index] / torch.dot(A_i, A_i).view(-1)
                A_i = A[:, index].view(1, -1)
            else:
                A_i = torch.cat((A_i, A[:, index].view(1, -1)), dim=0)
                temp = torch.matmul(A_i, torch.transpose(A_i, 0, 1)) + lam * torch.eye(A_i.shape[0], device=self.device)
                
                # If constrained to be positive, use nnls. Otherwise, use torch's lstsq method.
                if positive:
                    x_i, _ = nnls(temp.cpu().numpy(), torch.matmul(A_i, b).view(-1, 1).cpu().numpy()[:,0])
                    x_i = torch.Tensor(x_i).cuda()
                else:
                    x_i, _ = torch.lstsq(torch.matmul(A_i, b).view(-1, 1), temp)
            
            resid = b - torch.matmul(torch.transpose(A_i, 0, 1), x_i).view(-1)
        
        x_i = x_i.view(-1)
        for i, index in enumerate(indices):
            try:
                x[index] += x_i[i]
            except IndexError:
                x[index] += x_i
        
        return x
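    # Note on how the two solvers above are used by select(): both receive the same inputs, namely A, the
    # transposed matrix of per-point hypothesized gradient embeddings with shape (embedding_dim, n_unlabeled),
    # and b, the target gradient (validation gradient or full hypothesized gradient) of length embedding_dim.
    # fixed_weight_greedy_parallel returns a 0/1 indicator vector with exactly 'nnz' ones, while
    # orthogonalmp_reg_parallel returns a real-valued weight vector that may contain fewer than 'nnz'
    # non-zero entries.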
    def select(self, budget, use_weights=False):
        """
        Selects next set of points
        
        Parameters
        ----------
        budget: int
            Number of data points to select for labeling
        use_weights: bool
            Whether to use the fixed-weight version (False) or the OMP version (True)
        
        Returns
        -------
        idxs: list
            List of selected data point indices with respect to unlabeled_dataset. If use_weights is True, 
            a tensor of the corresponding OMP weights (gammas) is also returned.
        """
        self.model.eval()
        
        # Compute hypothesized gradients.
        hypothesized_gradients = self.get_grad_embedding(self.unlabeled_dataset, True, grad_embedding_type=self.grad_embedding)
        
        # If there is a validation set, set the target gradient vector to be the summed validation gradient.
        # Otherwise, set the target gradient vector to be the summed hypothesized gradient.
        if self.validation_dataset is not None:
            target_gradient = self.get_grad_embedding(self.validation_dataset, False, grad_embedding_type=self.grad_embedding).sum(dim=0)
            num_target_gradient_summed_vectors = len(self.validation_dataset)
        else:
            target_gradient = hypothesized_gradients.sum(dim=0)
            num_target_gradient_summed_vectors = len(self.unlabeled_dataset)
        
        # Needed for solvers
        transposed_hypothesized_gradients = torch.transpose(hypothesized_gradients, 0, 1)
        
        # Use different solvers depending on the use of weights
        if use_weights:
            # We solve the weighted objective; hence, the OMP solver should be used.
            weight_vector = self.orthogonalmp_reg_parallel(transposed_hypothesized_gradients, target_gradient, nnz=budget, positive=True, lam=self.omp_reg)
            non_negative_weight_indices = torch.nonzero(weight_vector).view(-1)
            non_negative_weights = weight_vector[non_negative_weight_indices].tolist()
            non_negative_weight_indices = non_negative_weight_indices.tolist()
            
            # The OMP solver might not return 'budget' non-zero weights. Pick the rest randomly and assign weights of 1.
            diff = budget - len(non_negative_weight_indices)
            
            if diff > 0:
                remain_list = set(np.arange(len(self.unlabeled_dataset))).difference(set(non_negative_weight_indices))
                new_idxs = np.random.choice(list(remain_list), size=diff, replace=False)
                non_negative_weight_indices.extend(new_idxs)
                non_negative_weights.extend([1 for _ in range(diff)])
            
            idxs = non_negative_weight_indices
            gammas = torch.tensor(non_negative_weights).to(self.device)
            return idxs, gammas
        else:
            # We solve the fixed-weight objective
            weight_vector = self.fixed_weight_greedy_parallel(transposed_hypothesized_gradients, target_gradient, num_target_gradient_summed_vectors, nnz=budget)
            non_negative_weight_indices = torch.nonzero(weight_vector).view(-1).tolist()
            
            # Unlike the OMP branch, this solver ensures there are exactly 'budget' non-zero weights.
            return non_negative_weight_indices
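As a rough usage sketch (not part of the module source), the strategy can be instantiated and queried as below. The dataset objects, model, and numeric values are hypothetical placeholders: the labeled, unlabeled, and validation datasets should be torch.utils.data.Dataset instances, and the model must be compatible with distil's Strategy base class (i.e., one from which the last-layer gradient embeddings used above can be computed).

# Hypothetical usage sketch. 'train_set', 'unlabeled_set', 'val_set', and 'model' are placeholders
# that must be supplied by the caller; only the GradMatchActive interface shown in this module is assumed.
import torch

strategy = GradMatchActive(
    labeled_dataset=train_set,        # torch.utils.data.Dataset with labels
    unlabeled_dataset=unlabeled_set,  # torch.utils.data.Dataset forming the unlabeled pool
    net=model,                        # torch.nn.Module usable with distil's Strategy base class
    nclasses=10,                      # example value: number of target classes
    args={'batch_size': 64, 'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
    validation_dataset=val_set,       # optional; omit to match against the full hypothesized gradient
)

# Fixed-weight variant: returns exactly 'budget' indices into the unlabeled pool.
selected_idxs = strategy.select(budget=100)

# OMP variant: returns indices along with their matching weights (gammas).
selected_idxs, gammas = strategy.select(budget=100, use_weights=True)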