Utilities
Predefined Training Loop
- class distil.utils.train_helper.data_train(training_dataset, net, args)
Bases: object
Provides a configurable training loop for active learning (AL).
- Parameters
training_dataset (torch.utils.data.Dataset) – The training dataset to use
net (torch.nn.Module) – The model to train
args (dict) –
Additional arguments that control the training loop:
  - batch_size (int, optional) – The size of each training batch.
  - islogs (bool, optional) – Whether to return training metadata (logs).
  - optimizer (string, optional) – The choice of optimizer. Must be one of ‘sgd’ or ‘adam’.
  - isverbose (bool, optional) – Whether to print additional messages during training.
  - isreset (bool, optional) – Whether to reset the model before training.
  - max_accuracy (float, optional) – The training accuracy at which training stops.
  - min_diff_acc (float, optional) – The minimum difference in accuracy measured within the window of monitored accuracies. If all differences are less than this minimum, training stops.
  - window_size (int, optional) – The number of monitored accuracies in the window. If all differences within the window are less than ‘min_diff_acc’, training stops.
  - criterion (typing.Callable, optional) – The criterion (loss function) to use for training.
  - device (string, optional) – The device to use for training.
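The stopping arguments above combine an accuracy cutoff with a plateau check. The following is a minimal pure-Python sketch of that rule, assuming training stops when the latest training accuracy reaches `max_accuracy`, or when every difference between consecutive accuracies in the last `window_size` epochs is below `min_diff_acc`. The text above does not say whether differences are taken between consecutive values or against the latest one, so that choice is an assumption, and the helper name `should_stop` is hypothetical rather than part of DISTIL.

```python
from typing import Sequence


def should_stop(acc_history: Sequence[float],
                max_accuracy: float,
                window_size: int,
                min_diff_acc: float) -> bool:
    """Hypothetical early-stopping rule modeled on the documented args.

    acc_history holds one training accuracy per epoch, oldest first.
    """
    if not acc_history:
        return False
    # Criterion 1: the latest training accuracy has reached the cutoff.
    if acc_history[-1] >= max_accuracy:
        return True
    # Criterion 2: accuracy has plateaued. Needs a full window of at least two
    # values, otherwise there are no differences to compare.
    if window_size < 2 or len(acc_history) < window_size:
        return False
    window = acc_history[-window_size:]
    # Absolute changes between consecutive accuracies inside the window.
    diffs = [abs(b - a) for a, b in zip(window, window[1:])]
    return all(d < min_diff_acc for d in diffs)


# An args dict using keys documented above; the values are illustrative only.
args = {"batch_size": 64, "optimizer": "sgd", "islogs": False,
        "max_accuracy": 0.99, "window_size": 5, "min_diff_acc": 0.001}

history = [0.60, 0.80, 0.85]
print(should_stop(history, args["max_accuracy"],
                  args["window_size"], args["min_diff_acc"]))  # False
```

Here training continues because the latest accuracy (0.85) is below the cutoff and fewer than `window_size` epochs have been observed.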
- get_acc_on_set(test_dataset)
Calculates and returns the current model's accuracy on the given test dataset.
- Parameters
test_dataset (torch.utils.data.Dataset) – The dataset to test
- Returns
accFinal – The fraction of data points whose predictions by the current model match their targets
- Return type
float
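The returned accuracy is the fraction of points whose prediction matches the target. A minimal sketch of that arithmetic, assuming predictions are already class indices (for example, the argmax of the model's outputs); `fraction_correct` is a hypothetical helper, not part of DISTIL.

```python
from typing import Sequence


def fraction_correct(predictions: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of positions where the predicted class equals the target."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("cannot compute accuracy on an empty dataset")
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return correct / len(targets)


# 3 of 4 predictions match their targets.
print(fraction_correct([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
```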
- train(gradient_weights=None)
Initiates the training loop.
- Parameters
gradient_weights (list, optional) – The weight of each data point’s effect on the loss gradient. If None, regular (unweighted) training is performed; otherwise, weighted training is performed.
- Returns
model – The trained model. If ‘islogs’ is set to True, the training logs are returned as well.
- Return type
torch.nn.Module
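When `gradient_weights` is given, each data point's contribution to the loss, and therefore to the gradient, is scaled. A minimal pure-Python sketch, assuming per-sample cross-entropy losses are multiplied by their weights and then averaged over the batch. Whether DISTIL averages, sums, or normalizes by the total weight is not stated here, so the reduction is an assumption, and `weighted_cross_entropy` is a hypothetical name.

```python
import math
from typing import Optional, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def weighted_cross_entropy(batch_logits: Sequence[Sequence[float]],
                           targets: Sequence[int],
                           weights: Optional[Sequence[float]] = None) -> float:
    """Mean of per-sample losses, each scaled by its weight (assumed reduction)."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    if weights is None:
        # Regular training: every point contributes equally.
        weights = [1.0] * len(losses)
    if len(weights) != len(losses):
        raise ValueError("one weight is required per data point")
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)
```

In PyTorch the same idea is commonly expressed by constructing the criterion with `reduction='none'` and multiplying the resulting per-sample loss tensor by a weight tensor before reducing it.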
Utilities
- class distil.utils.utils.ConcatWithTargets(dataset1, dataset2)
Bases: Dataset
Concatenation of two labeled datasets (dataset1 followed by dataset2), keeping their targets.
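A minimal sketch of concatenation with targets, assuming each dataset returns `(input, target)` pairs and that indices past the end of `dataset1` map into `dataset2`. It uses plain Python sequences so it runs without PyTorch; a real map-style dataset would subclass `torch.utils.data.Dataset` and implement the same `__len__` and `__getitem__`. `ConcatSketch` is a hypothetical name, not DISTIL's implementation.

```python
from typing import Any, Sequence, Tuple

Pair = Tuple[Any, Any]


class ConcatSketch:
    """Hypothetical stand-in for concatenating two labeled datasets."""

    def __init__(self, dataset1: Sequence[Pair], dataset2: Sequence[Pair]):
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __len__(self) -> int:
        return len(self.dataset1) + len(self.dataset2)

    def __getitem__(self, index: int) -> Pair:
        if index < 0 or index >= len(self):
            raise IndexError(index)
        # Indices [0, len(dataset1)) come from the first dataset...
        if index < len(self.dataset1):
            return self.dataset1[index]
        # ...and the remaining indices are shifted into the second dataset.
        return self.dataset2[index - len(self.dataset1)]


labeled = [("a", 0), ("b", 1)]
newly_labeled = [("c", 1)]
combined = ConcatSketch(labeled, newly_labeled)
print(len(combined), combined[2])  # 3 ('c', 1)
```

In an active learning round, a typical use is merging the existing labeled set with newly labeled points before retraining.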
- class distil.utils.utils.LabeledToUnlabeledDataset(wrapped_dataset)
Bases: Dataset
Wraps a labeled dataset and removes its labels, so that only the inputs are returned.
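A minimal sketch of such a wrapper, assuming the wrapped dataset returns `(input, label)` pairs. Plain sequences stand in for PyTorch datasets, and `UnlabeledViewSketch` is a hypothetical name rather than DISTIL's code.

```python
from typing import Any, Sequence, Tuple


class UnlabeledViewSketch:
    """Hypothetical wrapper that drops the label from each (input, label) pair."""

    def __init__(self, wrapped_dataset: Sequence[Tuple[Any, Any]]):
        self.wrapped_dataset = wrapped_dataset

    def __len__(self) -> int:
        return len(self.wrapped_dataset)

    def __getitem__(self, index: int) -> Any:
        data, _label = self.wrapped_dataset[index]
        return data  # only the input is exposed


pool = UnlabeledViewSketch([("x0", 0), ("x1", 1)])
print(len(pool), pool[1])  # 2 x1
```

Keeping the wrapper lazy (it stores a reference rather than copying) means the underlying data is not duplicated.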
- class distil.utils.utils.SubsetWithTargets(dataset, indices, labels)
Bases: Dataset
Subset of a dataset at specified indices.
- Parameters
dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
labels (sequence) – The targets for the selected indices; must be the same length as indices.
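A minimal sketch of a subset that attaches caller-supplied labels, assuming the full dataset returns `(input, label)` pairs and that the supplied label replaces whatever label is stored there. Plain sequences stand in for PyTorch datasets, and `SubsetWithTargetsSketch` is a hypothetical name, not DISTIL's implementation.

```python
from typing import Any, Sequence, Tuple


class SubsetWithTargetsSketch:
    """Hypothetical subset that pairs selected inputs with supplied labels."""

    def __init__(self, dataset: Sequence[Tuple[Any, Any]],
                 indices: Sequence[int], labels: Sequence[Any]):
        if len(indices) != len(labels):
            raise ValueError("labels must have the same length as indices")
        self.dataset = dataset
        self.indices = list(indices)
        self.targets = list(labels)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Tuple[Any, Any]:
        # Look up the i-th selected point in the full dataset and attach the
        # label supplied for it, ignoring the label stored in the dataset.
        data, _old_label = self.dataset[self.indices[i]]
        return data, self.targets[i]


base = [("x0", -1), ("x1", -1), ("x2", -1)]
subset = SubsetWithTargetsSketch(base, indices=[2, 0], labels=[7, 3])
print(subset[0], subset[1])  # ('x2', 7) ('x0', 3)
```

For example, the indices could be the points selected by a query strategy and the labels the answers obtained for them.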
- class distil.utils.utils.SubsetWithTargetsSingleChannel(dataset, indices, labels)
Bases: Dataset
Subset of a dataset at specified indices.
- Parameters
dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
labels (sequence) – The targets for the selected indices; must be the same length as indices.