Source code for distil.utils.submodular

import numpy as np
import torch

import apricot

from scipy.sparse import csr_matrix

from .similarity_mat import SimilarityComputation


[docs]class SubmodularFunction(SimilarityComputation): """ Implementation of Submodular Function. This class allows you to use different submodular functions Parameters ---------- device: str Device to be used, cpu|gpu x_trn: torch tensor Data on which submodular optimization should be applied y_trn: torch tensor Labels of the data model: class Model architecture used for training N_trn: int Number of samples in dataset batch_size: int Batch size to be used for optimization if_convex: bool If convex or not submod: str Choice of submodular function - 'facility_location' | 'graph_cut' | 'saturated_coverage' | 'sum_redundancy' | 'feature_based' selection_type: str Type of selection - 'PerClass' | 'Supervised' | 'Full' """ def __init__(self, device, x_trn, y_trn, N_trn, batch_size, submod, selection_type): super(SubmodularFunction, self).__init__(device, x_trn, y_trn, N_trn, batch_size) self.submod = submod self.selection_type = selection_type
[docs] def lazy_greedy_max(self, budget): """ Data selection method using different submodular optimization functions. Parameters ---------- budget: int The number of data points to be selected Returns ---------- total_greedy_list: list List containing indices of the best datapoints """ classes, no_elements = torch.unique(self.y_trn, return_counts=True) len_unique_elements = no_elements.shape[0] per_class_bud = int(budget / len(classes)) final_per_class_bud = [] _, sorted_indices = torch.sort(no_elements, descending = True) if self.selection_type == 'PerClass': total_idxs = 0 for n_element in no_elements: final_per_class_bud.append(min(per_class_bud, torch.IntTensor.item(n_element))) total_idxs += min(per_class_bud, torch.IntTensor.item(n_element)) if total_idxs < budget: bud_difference = budget - total_idxs for i in range(len_unique_elements): available_idxs = torch.IntTensor.item(no_elements[sorted_indices[i]])-per_class_bud final_per_class_bud[sorted_indices[i]] += min(bud_difference, available_idxs) total_idxs += min(bud_difference, available_idxs) bud_difference = budget - total_idxs if bud_difference == 0: break total_greedy_list = [] for i in range(len_unique_elements): idxs = torch.where(self.y_trn == classes[i])[0] if self.submod == 'facility_location': self.compute_score(idxs) fl = apricot.functions.facilityLocation.FacilityLocationSelection(random_state=0, metric='precomputed', n_samples=final_per_class_bud[i]) elif self.submod == 'graph_cut': self.compute_score(idxs) fl = apricot.functions.graphCut.GraphCutSelection(random_state=0, metric='precomputed', n_samples=final_per_class_bud[i]) elif self.submod == 'saturated_coverage': self.compute_score(idxs) fl = apricot.functions.saturatedCoverage.SaturatedCoverageSelection(random_state=0, metric='precomputed', n_samples=final_per_class_bud[i]) elif self.submod == 'sum_redundancy': self.compute_score(idxs) fl = apricot.functions.sumRedundancy.SumRedundancySelection(random_state=0, metric='precomputed', n_samples=final_per_class_bud[i]) elif self.submod == 'feature_based': fl = apricot.functions.featureBased.FeatureBasedSelection(random_state=0, n_samples=final_per_class_bud[i]) if self.submod == 'feature_based': x_sub = fl.fit_transform(self.x_trn[idxs].numpy()) greedyList = self.get_index(self.x_trn[idxs].numpy(), x_sub) total_greedy_list.extend(idxs[greedyList]) else: sim_sub = fl.fit_transform(self.dist_mat.cpu().numpy()) greedyList = list(np.argmax(sim_sub, axis=1)) total_greedy_list.extend(idxs[greedyList]) elif self.selection_type == 'Supervised': if self.submod == 'feature_based': class_map = {} for i in range(len_unique_elements): class_map[torch.IntTensor.item(classes[i])] = i #Mapping classes from 0 to n sparse_data = torch.zeros([self.x_trn.shape[0], self.x_trn.shape[1]*len_unique_elements]) for i in range(self.x_trn.shape[0]): start_col = class_map[torch.IntTensor.item(self.y_trn[i])]*self.x_trn.shape[1] end_col = start_col+self.x_trn.shape[1] sparse_data[i, start_col:end_col] = self.x_trn[i, :] fl = apricot.functions.featureBased.FeatureBasedSelection(random_state=0, n_samples=budget) x_sub = fl.fit_transform(sparse_data.numpy()) total_greedy_list = self.get_index(sparse_data.numpy(), x_sub) else: for i in range(len(classes)): if i == 0: idxs = torch.where(self.y_trn == classes[i])[0] N = len(idxs) self.compute_score(idxs) row = idxs.repeat_interleave(N) col = idxs.repeat(N) data = self.dist_mat.cpu().numpy().flatten() else: idxs = torch.where(self.y_trn == classes[i])[0] N = len(idxs) self.compute_score(idxs) row = torch.cat((row, idxs.repeat_interleave(N)), dim=0) col = torch.cat((col, idxs.repeat(N)), dim=0) data = np.concatenate([data, self.dist_mat.cpu().numpy().flatten()], axis=0) sparse_simmat = csr_matrix((data, (row.numpy(), col.numpy())), shape=(self.N_trn, self.N_trn)) #self.dist_mat = sparse_simmat if self.submod == 'facility_location': fl = apricot.functions.facilityLocation.FacilityLocationSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'graph_cut': fl = apricot.functions.graphCut.GraphCutSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'saturated_coverage': fl = apricot.functions.saturatedCoverage.SaturatedCoverageSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'sum_redundancy': fl = apricot.functions.sumRedundancy.SumRedundancySelection(random_state=0, metric='precomputed', n_samples=budget) sim_sub = fl.fit_transform(sparse_simmat) total_greedy_list = list(np.array(np.argmax(sim_sub, axis=1)).reshape(-1)) if self.selection_type == 'Full': total_greedy_list = [] idx_end = self.x_trn.shape[0] - 1 idxs = torch.linspace(0, idx_end, self.x_trn.shape[0]).long() if self.submod == 'facility_location': self.compute_score(idxs) fl = apricot.functions.facilityLocation.FacilityLocationSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'graph_cut': self.compute_score(idxs) fl = apricot.functions.graphCut.GraphCutSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'saturated_coverage': self.compute_score(idxs) fl = apricot.functions.saturatedCoverage.SaturatedCoverageSelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'sum_redundancy': self.compute_score(idxs) fl = apricot.functions.sumRedundancy.SumRedundancySelection(random_state=0, metric='precomputed', n_samples=budget) elif self.submod == 'feature_based': fl = apricot.functions.featureBased.FeatureBasedSelection(random_state=0, n_samples=budget) if self.submod == 'feature_based': x_sub = fl.fit_transform(self.x_trn.numpy()) total_greedy_list = self.get_index(self.x_trn.numpy(), x_sub) else: sim_sub = fl.fit_transform(self.dist_mat.cpu().numpy()) total_greedy_list = list(np.argmax(sim_sub, axis=1)) return total_greedy_list