Source code for distil.active_learning_strategies.kmeans_sampling

from .strategy import Strategy
from torch.utils.data import Subset, DataLoader, Dataset
from numpy import random
import torch

# Used to calculate embeddings
class CustomTensorDataset(Dataset):
            
    def __init__(self, wrapped_tensor):
        self.wrapped_tensor = wrapped_tensor
                
    def __getitem__(self, index):
        return self.wrapped_tensor[index]
            
    def __len__(self):
        return self.wrapped_tensor.shape[0]

[docs]class KMeansSampling(Strategy): """ Implements KMeans Sampling selection strategy, the last layer embeddings are calculated for all the unlabeled data points. Then the KMeans clustering algorithm is run over these embeddings with the number of clusters equal to the budget. Then the distance is calculated for all the points from their respective centers. From each cluster, the point closest to the center is selected to be labeled for the next iteration. Since the number of centers are equal to the budget, selecting one point from each cluster satisfies the total number of data points to be selected in one iteration. Parameters ---------- labeled_dataset: torch.utils.data.Dataset The labeled training dataset unlabeled_dataset: torch.utils.data.Dataset The unlabeled pool dataset net: torch.nn.Module The deep model to use nclasses: int Number of unique values for the target args: dict Specify additional parameters - **batch_size**: Batch size to be used inside strategy class (int, optional) - **device**: The device that this strategy class should use for computation (string, optional) - **loss**: The loss that should be used for relevant computations (typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional) - **rand_seed**: Specifies a seed for the random seed generator used in initialization (int, optional) - **representation**: Specifies whether to use the last linear layer embeddings or the raw data. Must be one of 'linear' or 'raw' (string, optional) - **kmeans_args**: Specifies additional kmeans-related parameters - **tol**: Specifies the value of the Frobenius norm of the inertia tensor by which kmeans should cease (float, optional) - **max_iter**: Specifies the maximum number of iterations that kmeans should use before terminating (int, optional) - **n_init**: Specifies the number of kmeans run-throughs to use, wherein the one with the smallest inertia is selected for the selection phase (int, optional) """ def __init__(self, labeled_dataset, unlabeled_dataset, net, nclasses, args={}): super(KMeansSampling, self).__init__(labeled_dataset, unlabeled_dataset, net, nclasses, args) if 'rand_seed' in args: random.seed(args['rand_seed']) if 'representation' in args: self.representation = args['representation'] else: self.representation = 'linear' if 'kmeans_args' in args: self.kmeans_args = args['kmeans_args'] else: args['kmeans_args'] = dict() self.kmeans_args = args['kmeans_args'] if 'tol' not in self.kmeans_args: self.kmeans_args['tol'] = 1e-4 if 'max_iter' not in self.kmeans_args: self.kmeans_args['max_iter'] = 300 if 'n_init' not in self.kmeans_args: self.kmeans_args['n_init'] = 10 def _dataset_to_raw_device_tensor(self, input_dataset): loaded_dataset_tensor = next(iter(DataLoader(input_dataset, shuffle=False, batch_size=len(input_dataset)))) loaded_dataset_tensor = loaded_dataset_tensor.to(self.device) loaded_dataset_tensor = loaded_dataset_tensor.view(len(input_dataset), -1) return loaded_dataset_tensor def get_closest_distances(self, ground_set, center_tensor): # Store the minimum distances in this tensor ground_set_min_distances = torch.zeros(len(ground_set)).to(self.device) ground_set_closest_center_indices = torch.zeros(len(ground_set), dtype=torch.long).to(self.device) start_batch_idx = 0 with torch.no_grad(): while start_batch_idx != len(ground_set): end_batch_idx = min(start_batch_idx + self.args['batch_size'], len(ground_set)) batch_idx_list = list(range(start_batch_idx, end_batch_idx)) batch_subset = Subset(ground_set, batch_idx_list) if self.representation == "linear": batch_embedding_tensor = self.get_embedding(batch_subset) elif self.representation == "raw": batch_embedding_tensor = self._dataset_to_raw_device_tensor(batch_subset) else: raise ValueError("Representation must be one of 'linear', 'raw'") # Calculate the distance of each point in the ground set batch to each center in the center batch. inter_batch_distances = torch.cdist(batch_embedding_tensor, center_tensor, p=2) # Calculate the minimum distances across each row; this will reflect the distance to the closest center batch_min_distances, batch_min_idx = torch.min(inter_batch_distances, dim=1) # Assign minimum distance to the correct slice of the storage tensor ground_set_min_distances[start_batch_idx:end_batch_idx] = batch_min_distances ground_set_closest_center_indices[start_batch_idx:end_batch_idx] = batch_min_idx start_batch_idx = end_batch_idx return ground_set_min_distances, ground_set_closest_center_indices.tolist() def kmeans_plusplus(self, num_centers): # 1. Choose a random point to be the center (uniform dist) selected_points = [random.choice(len(self.unlabeled_dataset))] # Keep repeating this step until num_centers centers have been chosen while len(selected_points) < num_centers: # 2. Calculate the squared distance to the nearest center for each point selected_centers = Subset(self.unlabeled_dataset, selected_points) if self.representation == 'linear': selected_centers_tensor = self.get_embedding(selected_centers) elif self.representation == 'raw': selected_centers_tensor = self._dataset_to_raw_device_tensor(selected_centers) else: raise ValueError("Representation must be one of 'linear', 'raw'") ground_set_min_distances, _ = self.get_closest_distances(self.unlabeled_dataset, selected_centers_tensor) ground_set_min_distances = torch.pow(ground_set_min_distances, 2) # 3. Sample a random point with probability proportional to the squared distance # Note: torch.multinomial does not require that the weight tensor sum to 1 # (e.g., forms a probability distribution). It simply requires non-negative weights # and will form the distribution itself. torch.multinomial can be used as it allows # the tensor to stay on the GPU and because sampling from the multinomial distribution # assigns the probability of sampling element i with the calculated distance weight. distance_probability_distribution = ground_set_min_distances random_choice = torch.multinomial(distance_probability_distribution, 1).item() # 4. Add the chosen index to the center list selected_points.append(random_choice) return selected_points def kmeans_calculate_means(self, clusters): # Calculate dimensions of and form the storage tensor if self.representation == 'linear': num_features = self.model.get_embedding_dim() means = torch.zeros(len(clusters), num_features).to(self.device) elif self.representation == 'raw': num_features = self.unlabeled_dataset[0].view(-1).shape[0] means = torch.zeros(len(clusters), num_features).to(self.device) else: raise ValueError("Representation must be one of 'linear', 'raw'") with torch.no_grad(): # Calculate the mean of each cluster for i, cluster in enumerate(clusters): start_batch_idx = 0 # Only load those points specific to the cluster ground_set_cluster = Subset(self.unlabeled_dataset, cluster) running_cluster_sum = None while start_batch_idx != len(ground_set_cluster): end_batch_idx = min(start_batch_idx + self.args['batch_size'], len(ground_set_cluster)) batch_idx_list = list(range(start_batch_idx, end_batch_idx)) batch_subset = Subset(ground_set_cluster, batch_idx_list) # Put center batch on correct device, calculate embedding if self.representation == 'linear': batch_subset_tensor = self.get_embedding(batch_subset) elif self.representation == 'raw': batch_subset_tensor = self._dataset_to_raw_device_tensor(batch_subset) else: raise ValueError("Representation must be one of 'linear', 'raw'") if running_cluster_sum is None: running_cluster_sum = torch.sum(batch_subset_tensor, dim=0) else: running_cluster_sum += torch.sum(batch_subset_tensor, dim=0) start_batch_idx = end_batch_idx # Divide by total number of elements to get the mean running_cluster_sum /= len(ground_set_cluster) means[i] = running_cluster_sum return means def kmeans_calculate_clusters(self, center_tensor): # Calculate the closest center indices _, ground_set_closest_center_indices = self.get_closest_distances(self.unlabeled_dataset, center_tensor) # For each center, create an associated cluster and add points to them clusters = [[] for x in range(len(center_tensor))] for i, index in enumerate(ground_set_closest_center_indices): clusters[index].append(i) # Return the clusters return clusters def kmeans_clustering(self, num_centers): best_inertia = None best_centers = None # Run kmeans algorithm n_init times and choose the one with best inertia. for i in range(self.kmeans_args['n_init']): # Use kmeans++ initialization centers_subset = Subset(self.unlabeled_dataset, self.kmeans_plusplus(num_centers)) if self.representation == "linear": centers = self.get_embedding(centers_subset) else: centers = self._dataset_to_raw_device_tensor(centers_subset) # Alternate between means/assignment steps until max_iter reached for i in range(self.kmeans_args['max_iter']): old_centers = centers clusters = self.kmeans_calculate_clusters(centers) centers = self.kmeans_calculate_means(clusters) center_diff = centers - old_centers frobenius_norm = torch.linalg.norm(center_diff, ord="fro").item() # If frobenius norm of difference between centers is below a tolerance, stop this kmeans iteration. if frobenius_norm < self.kmeans_args['tol']: break # Lastly, evaluate the inertia of this kmeans solution. If it is the best one so far, keep it. ground_set_min_distances, ground_set_closest_center_indices = self.get_closest_distances(self.unlabeled_dataset, centers) inertia = torch.pow(ground_set_min_distances, 2).sum().item() if best_inertia is None or inertia < best_inertia: best_inertia = inertia best_centers = centers return best_centers
[docs] def select(self, budget): """ Selects next set of points Parameters ---------- budget: int Number of data points to select for labeling Returns ---------- idxs: list List of selected data point indices with respect to unlabeled_dataset """ self.model.eval() # See if the unlabeled dataset returns dictionary-style type instances. If so, raise an error. if type(self.unlabeled_dataset[0]) == dict and self.representation == "raw": raise ValueError("Dictionary-type input not supported with raw representation") # Get the best centers through kmeans clustering best_centers = self.kmeans_clustering(budget) # Choose the point closest to each center. ground_set_min_distances, ground_set_closest_center_indices = self.get_closest_distances(self.unlabeled_dataset, best_centers) cluster_min_distances = [None for x in range(budget)] cluster_min_indices = [None for x in range(budget)] for i, (distance, center_index) in enumerate(zip(ground_set_min_distances, ground_set_closest_center_indices)): if cluster_min_distances[center_index] is None or distance < cluster_min_distances[center_index]: cluster_min_distances[center_index] = distance cluster_min_indices[center_index] = i # Return the list of indices of points that are closest to the best centers chosen by kmeans return cluster_min_indices