Utilities

Predefined Training Loop

class distil.utils.train_helper.AddIndexDataset(wrapped_dataset)[source]

Bases: Dataset

class distil.utils.train_helper.data_train(training_dataset, net, args)[source]

Bases: object

Provides a configurable training loop for active learning (AL).

Parameters
  • training_dataset (torch.utils.data.Dataset) – The training dataset to use

  • net (torch.nn.Module) – The model to train

  • args (dict) –

    Additional arguments to control the training loop:

    • batch_size (int, optional) – The size of each training batch

    • islogs (bool, optional) – Whether to return training metadata

    • optimizer (string, optional) – The choice of optimizer. Must be one of ‘sgd’ or ‘adam’

    • isverbose (bool, optional) – Whether to print more messages about the training

    • isreset (bool, optional) – Whether to reset the model before training

    • max_accuracy (float, optional) – The training accuracy cutoff by which to stop training

    • min_diff_acc (float, optional) – The minimum difference in accuracy to measure in the window of monitored accuracies. If all differences are less than the minimum, stop training

    • window_size (int, optional) – The size of the window for monitoring accuracies. If all differences are less than ‘min_diff_acc’, then stop training

    • criterion (typing.Callable, optional) – The criterion to use for training

    • device (string, optional) – The device to use for training
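As an illustration, an args dictionary built from the keys listed above might look like the following. Every value is a placeholder chosen for the example, not a documented default. The helper is_saturated is hypothetical: it sketches one reading of the window_size / min_diff_acc stopping rule (signed differences between later and earlier accuracies) and is not claimed to be the library's own check_saturation.

```python
# Example `args` dictionary. All values are illustrative assumptions,
# not defaults taken from the library.
args = {
    "batch_size": 64,
    "islogs": False,
    "optimizer": "sgd",       # must be 'sgd' or 'adam'
    "isverbose": False,
    "isreset": True,
    "max_accuracy": 0.99,
    "min_diff_acc": 0.001,
    "window_size": 5,
    "device": "cpu",
    # "criterion" is omitted here; it would be a callable loss function.
}


def is_saturated(acc_monitor, min_diff_acc):
    """Hypothetical helper: return True when no later accuracy in the
    window exceeds an earlier one by at least `min_diff_acc`."""
    for i in range(len(acc_monitor)):
        for j in range(i + 1, len(acc_monitor)):
            if acc_monitor[j] - acc_monitor[i] >= min_diff_acc:
                return False
    return True


# Accuracies from the last `window_size` epochs (made-up numbers).
plateau = [0.9000, 0.9002, 0.9003, 0.9001, 0.9004]
improving = [0.80, 0.83, 0.86, 0.88, 0.90]

print(is_saturated(plateau, args["min_diff_acc"]))    # training would stop
print(is_saturated(improving, args["min_diff_acc"]))  # training would continue
```

With such a dictionary, the trainer would be constructed as data_train(training_dataset, net, args), following the signature documented above.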

check_saturation(acc_monitor)[source]

get_acc_on_set(test_dataset)[source]

Calculates and returns the model's accuracy on the given test dataset.

Parameters

test_dataset (torch.utils.data.Dataset) – The dataset to test

Returns

accFinal – The fraction of data points whose predictions by the current model match their targets

Return type

float
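The returned value can be pictured with a small, dependency-free sketch. The function name accuracy and the plain lists standing in for model predictions and dataset targets are assumptions made for this example only.

```python
def accuracy(predictions, targets):
    """Fraction of predictions that equal their targets."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        return 0.0  # convention chosen for this sketch: empty set -> 0.0
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return correct / len(targets)


# Predicted class ids versus true class ids for four data points.
preds = [0, 2, 1, 1]
labels = [0, 1, 1, 1]
acc = accuracy(preds, labels)
print(acc)  # three of four predictions match
```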

train(gradient_weights=None)[source]

Initiates the training loop.

Parameters

gradient_weights (list, optional) – The weight of each data point’s effect on the loss gradient. If None, regular (unweighted) training is performed; otherwise, weighted training is performed.

Returns

model – The trained model. If ‘islogs’ is set to True, the training logs are returned as well.

Return type

torch.nn.Module
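One way to picture the effect of gradient_weights is to combine per-sample losses with per-sample weights. The sketch below assumes a weighted mean normalized by the sum of the weights. The source does not state how the library normalizes, so this is an illustration of the idea only, and weighted_loss is a hypothetical name.

```python
def weighted_loss(per_sample_losses, gradient_weights=None):
    """Combine per-sample losses into one scalar.

    With gradient_weights=None every point counts equally (plain mean).
    Otherwise each loss is scaled by its weight and the result is divided
    by the total weight (an assumption of this sketch).
    """
    if gradient_weights is None:
        gradient_weights = [1.0] * len(per_sample_losses)
    if len(gradient_weights) != len(per_sample_losses):
        raise ValueError("need exactly one weight per data point")
    total_weight = sum(gradient_weights)
    if total_weight == 0:
        raise ValueError("weights must not sum to zero")
    weighted_sum = sum(w * l for w, l in zip(gradient_weights, per_sample_losses))
    return weighted_sum / total_weight


losses = [0.5, 1.5, 2.0]
plain = weighted_loss(losses)                      # regular training
weighted = weighted_loss(losses, [2.0, 1.0, 1.0])  # first point counts double
print(plain, weighted)
```

Because the first point has the smallest loss and the largest weight, the weighted value is lower than the plain mean.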

update_data(new_training_dataset)[source]

Updates the training dataset with the provided new training dataset

Parameters

new_training_dataset (torch.utils.data.Dataset) – The new training dataset

update_index(idxs_lb)[source]

distil.utils.train_helper.init_weights(m)[source]

Utilities

class distil.utils.utils.ConcatWithTargets(dataset1, dataset2)[source]

Bases: Dataset

Concatenation of two datasets, along with their targets.

class distil.utils.utils.LabeledToUnlabeledDataset(wrapped_dataset)[source]

Bases: Dataset

Remove labels from a labeled dataset.
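The wrapper pattern described here can be sketched without any dependencies. UnlabeledView is a hypothetical stand-in written for this example, and it assumes the wrapped dataset yields (data, label) pairs; a plain list plays the role of the dataset.

```python
class UnlabeledView:
    """Hypothetical stand-in for a label-stripping dataset wrapper."""

    def __init__(self, wrapped_dataset):
        self.wrapped_dataset = wrapped_dataset

    def __getitem__(self, index):
        # Assumes each item of the wrapped dataset is a (data, label) pair.
        data, _label = self.wrapped_dataset[index]
        return data

    def __len__(self):
        return len(self.wrapped_dataset)


# A tiny labeled "dataset": feature vectors paired with class ids.
labeled = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]
unlabeled = UnlabeledView(labeled)
print(len(unlabeled), unlabeled[1])
```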

class distil.utils.utils.SubsetWithTargets(dataset, indices, labels)[source]

Bases: Dataset

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset

  • labels (sequence) – Targets for the selected indices. Must be the same length as indices
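The relationship between dataset, indices, and labels can be sketched as follows. SubsetWithNewTargets is a hypothetical, dependency-free illustration. It assumes the labels passed in replace whatever labels the whole dataset holds at the selected positions, which the description above suggests but does not state outright.

```python
class SubsetWithNewTargets:
    """Hypothetical sketch: a view of `dataset` restricted to `indices`,
    paired with externally supplied `labels`."""

    def __init__(self, dataset, indices, labels):
        if len(indices) != len(labels):
            raise ValueError("labels must be the same length as indices")
        self.dataset = dataset
        self.indices = list(indices)
        self.labels = list(labels)

    def __getitem__(self, idx):
        # Look up the data in the whole dataset, ignore its original label,
        # and pair it with the label supplied for this position.
        data, _original_label = self.dataset[self.indices[idx]]
        return data, self.labels[idx]

    def __len__(self):
        return len(self.indices)


whole = [("a", 0), ("b", 1), ("c", 0), ("d", 1)]
subset = SubsetWithNewTargets(whole, indices=[3, 1], labels=[7, 9])
print(len(subset), subset[0], subset[1])
```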

class distil.utils.utils.SubsetWithTargetsSingleChannel(dataset, indices, labels)[source]

Bases: Dataset

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset

  • labels (sequence) – Targets for the selected indices. Must be the same length as indices