Classes and functions for training and predicting.

Helpers

save_to_checkpoint

save_to_checkpoint(epoch_index, model, optimizer, path)

Save the epoch index and the model and optimizer state_dicts to a checkpoint at path

load_from_checkpoint

load_from_checkpoint(model, path, optimizer=None, for_inference=False)

Load from a checkpoint: the model, optimizer and epoch_index when resuming training, or only the model when for_inference=True
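A minimal sketch of how these two helpers could be written with PyTorch's `torch.save` and `torch.load`. The dictionary keys (`epoch_index`, `model_state_dict`, `optimizer_state_dict`) and the return values are assumptions for illustration, not necessarily what this library stores or returns.

```python
import torch


def save_to_checkpoint(epoch_index, model, optimizer, path):
    # Assumed layout: one dict holding everything needed to resume training
    torch.save({'epoch_index': epoch_index,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()}, path)


def load_from_checkpoint(model, path, optimizer=None, for_inference=False):
    # map_location='cpu' loads tensors onto the CPU; move the model to a GPU afterwards if needed
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if for_inference:
        model.eval()  # disable dropout and use running batch-norm statistics
        return model
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Assumed return value: the caller can resume training at epoch_index + 1
    return model, optimizer, checkpoint['epoch_index']
```

Restoring the optimizer state as well as the model matters when resuming: optimizers such as Adam keep per-parameter running statistics, and the learning rate stored in the optimizer's param_groups is restored too.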

get_loss_fn

get_loss_fn(pos_wts)

Return an nn.BCEWithLogitsLoss with the given positive-class weights (pos_wts)
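The `pos_weight` argument of `nn.BCEWithLogitsLoss` scales the loss term for positive examples, with one entry per label, which counteracts class imbalance. Below is a sketch of get_loss_fn, plus a hypothetical helper `pos_weights_from_labels` that uses the common negatives-to-positives ratio. How pos_wts is actually computed is not stated here, so that ratio is an assumption.

```python
import torch
import torch.nn as nn


def get_loss_fn(pos_wts):
    # One positive weight per label column, broadcast over the batch dimension
    return nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(pos_wts, dtype=torch.float32))


def pos_weights_from_labels(y):
    # Hypothetical helper: per-label ratio of negative to positive examples
    y = torch.as_tensor(y, dtype=torch.float32)
    n_pos = y.sum(dim=0)
    n_neg = y.shape[0] - n_pos
    # clamp avoids division by zero for a label with no positive examples
    return n_neg / n_pos.clamp(min=1)
```

With labels `[[1, 0], [0, 0], [0, 1], [0, 1]]`, the first label has 3 negatives per positive and gets weight 3.0, while the balanced second label gets 1.0.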

class RunHistory

RunHistory(labels)

Class to hold training and prediction run histories

import pandas as pd

class RunHistory:
    '''Class to hold training and prediction run histories'''
    def __init__(self, labels):
        # Create a separate DataFrame / list for each attribute so that updating one never affects another
        self.train, self.valid, self.test = (pd.DataFrame(columns=['loss', *labels]) for _ in range(3))
        self.y_train, self.yhat_train, self.y_valid, self.yhat_valid, self.y_test, self.yhat_test = ([] for _ in range(6))
        self.prediction_summary = pd.DataFrame()

Note

  • y_* and yhat_* hold the actual ground-truth and predicted values, respectively, from
    • the last epoch of fit() for train and valid,
    • the last run of predict() for test
  • train, valid and test hold the loss and accuracy values calculated at the end of each epoch
    • for test, each run consists of a single epoch
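The note implies one row per epoch in train, valid and test, with a `loss` column followed by one column per label. The hypothetical helper `record_epoch` shows one way to append such a row in place; the label names and numbers in the usage lines are placeholders, not results.

```python
import pandas as pd


def record_epoch(history_df, loss, scores):
    # Hypothetical helper: append one row (loss, then one score per label) in column order
    history_df.loc[len(history_df)] = [loss] + [scores[c] for c in history_df.columns[1:]]


labels = ['label_a', 'label_b']                     # placeholder label names
train = pd.DataFrame(columns=['loss', *labels])
record_epoch(train, 0.69, {'label_a': 0.71, 'label_b': 0.64})  # placeholder values
record_epoch(train, 0.52, {'label_a': 0.78, 'label_b': 0.70})
```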

fit & predict

BCEWithLogitsLoss & torch.sigmoid
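These two usually go together in a common PyTorch pattern, assumed here. The model outputs raw logits, `nn.BCEWithLogitsLoss` applies the sigmoid inside the loss in a numerically stable form, and `torch.sigmoid` is applied explicitly only when probabilities are needed, for example for predictions and AUROC. The check below shows that the two routes give the same loss on random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 3)                       # raw model outputs: no sigmoid layer in the model
targets = torch.randint(0, 2, (8, 3)).float()    # multi-label ground truth (0/1 per label)

# BCEWithLogitsLoss applies the sigmoid internally
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
# Same value, but less stable for large |logits|: sigmoid first, then plain BCELoss
loss_manual = nn.BCELoss()(torch.sigmoid(logits), targets)

# At prediction time, sigmoid turns logits into per-label probabilities
probs = torch.sigmoid(logits)
```

Keeping the sigmoid out of the model avoids applying it twice, which would happen if the model ended in a sigmoid and the loss were BCEWithLogitsLoss.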

train

train(model, train_dl, train_loss_fn, optimizer, lazy=True)

Train model on the training data in train_dl
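A minimal sketch of one training epoch in the usual PyTorch pattern (zero_grad, forward, backward, step). It is hypothetical: the name `train_one_epoch` and its return value (mean loss, ground truth, sigmoid probabilities) are assumptions, and the `lazy` flag is omitted because its meaning isn't stated here.

```python
import torch


def train_one_epoch(model, train_dl, train_loss_fn, optimizer):
    model.train()  # enable dropout and batch-norm statistic updates
    total_loss, ys, yhats = 0.0, [], []
    for x, y in train_dl:
        optimizer.zero_grad()              # clear gradients from the previous batch
        logits = model(x)
        loss = train_loss_fn(logits, y)    # e.g. BCEWithLogitsLoss on raw logits
        loss.backward()
        optimizer.step()
        # Weight each batch-mean loss by batch size so a short last batch isn't over-counted
        total_loss += loss.item() * len(x)
        ys.append(y.detach())
        yhats.append(torch.sigmoid(logits).detach())
    ys, yhats = torch.cat(ys), torch.cat(yhats)
    return total_loss / len(ys), ys, yhats
```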

evaluate

evaluate(model, eval_dl, eval_loss_fn, lazy=True)

Evaluate model on eval_dl; used both for validation during training and for prediction
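Evaluation differs from training in two ways. `model.eval()` switches dropout and batch-norm layers to inference behaviour, and `torch.no_grad()` skips gradient tracking. This is a hypothetical sketch: `evaluate_epoch` and its return value are assumptions, and `lazy` is omitted.

```python
import torch


def evaluate_epoch(model, eval_dl, eval_loss_fn):
    model.eval()  # inference behaviour for dropout / batch norm
    total_loss, ys, yhats = 0.0, [], []
    with torch.no_grad():  # no gradients needed: no backward pass here
        for x, y in eval_dl:
            logits = model(x)
            # Size-weighted sum, so the final mean equals the per-sample mean loss
            total_loss += eval_loss_fn(logits, y).item() * len(x)
            ys.append(y)
            yhats.append(torch.sigmoid(logits))
    ys, yhats = torch.cat(ys), torch.cat(yhats)
    return total_loss / len(ys), ys, yhats
```

Because each batch loss is weighted by its size, the returned loss matches a single full-dataset pass even when the last batch is smaller, as long as the loss uses mean reduction.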

fit

fit(epochs, history, model, train_loss_fn, valid_loss_fn, optimizer, accuracy_fn, train_dl, valid_dl, lazy=True, to_chkpt_path=None, from_chkpt_path=None, verbosity=0.75)

Fit model for the given number of epochs and return the results in history
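Below is an outline of the epoch loop that fit presumably runs, with the per-epoch work abstracted into two hypothetical callables. It illustrates two behaviours stated or implied above. First, y/yhat are overwritten every epoch, so only the last epoch's values remain, as the note on RunHistory says. Second, a start_epoch greater than 0 lets training resume from a checkpoint's epoch_index. All names here are illustrative, not the library's API.

```python
def fit_loop(epochs, train_step, valid_step, start_epoch=0, on_epoch_end=None):
    # Hypothetical outline: train_step() and valid_step() each run one epoch
    # and return (loss, y, yhat)
    history = {'train_loss': [], 'valid_loss': []}
    for epoch_index in range(start_epoch, epochs):
        train_loss, history['y_train'], history['yhat_train'] = train_step()
        valid_loss, history['y_valid'], history['yhat_valid'] = valid_step()
        # Losses accumulate per epoch; y/yhat are overwritten, so the last epoch wins
        history['train_loss'].append(train_loss)
        history['valid_loss'].append(valid_loss)
        if on_epoch_end is not None:
            on_epoch_end(epoch_index)  # e.g. save a checkpoint for this epoch
    return history
```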

predict

predict(history, model, test_loss_fn, accuracy_fn, test_dl, chkpt_path, lazy=True)

Predict on test_dl and return the results in history
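Both fit and predict take an accuracy_fn. The plotting and summary helpers report AUROC scores, so one plausible accuracy_fn computes AUROC per label. The sketch below is an assumption, not this library's function. It uses the rank interpretation of AUROC: the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counting half.

```python
import numpy as np


def auroc(y_true, y_score):
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        return float('nan')  # AUROC is undefined when only one class is present
    # Compare every positive with every negative (O(P*N) memory; fine for moderate sizes)
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def per_label_auroc(y, yhat, labels):
    # Hypothetical accuracy_fn: one AUROC per label column
    y, yhat = np.asarray(y), np.asarray(yhat)
    return {label: auroc(y[:, i], yhat[:, i]) for i, label in enumerate(labels)}
```

For example, with ground truth `[0, 0, 1, 1]` and scores `[0.1, 0.4, 0.35, 0.8]`, three of the four positive/negative pairs are ranked correctly, giving an AUROC of 0.75.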

Plotting

plot_loss

plot_loss(history_df, title='Loss', axis=None)

Plot loss

plot_losses

plot_losses(train_history, valid_history)

Plot multiple losses (train and valid) side by side
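A sketch of the two loss plots with matplotlib, assuming history_df has one row per epoch and a `loss` column, as in RunHistory. The `_sketch` names are hypothetical. The `axis=None` pattern lets the single-plot function draw either on its own new figure or on one panel of a side-by-side figure.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for scripts; omit this line in a notebook
import matplotlib.pyplot as plt


def plot_loss_sketch(history_df, title='Loss', axis=None):
    # Draw on the given axis, or create a new figure if none was passed
    if axis is None:
        _, axis = plt.subplots()
    losses = history_df['loss'].to_numpy(dtype=float)
    axis.plot(range(1, len(losses) + 1), losses, marker='o')  # epochs numbered from 1
    axis.set(title=title, xlabel='Epoch', ylabel='Loss')
    return axis


def plot_losses_sketch(train_history, valid_history):
    # Two panels side by side, each filled by the single-plot function
    fig, (ax_train, ax_valid) = plt.subplots(1, 2, figsize=(10, 4))
    plot_loss_sketch(train_history, 'Train loss', ax_train)
    plot_loss_sketch(valid_history, 'Valid loss', ax_valid)
    return fig
```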

plot_aurocs

plot_aurocs(history_df, title='AUROC Scores', axis=None)

Plot AUROC scores

plot_train_valid_aurocs

plot_train_valid_aurocs(train_history, valid_history)

Plot train and valid AUROC scores side by side

Summarize

plot_fit_results

plot_fit_results(history, labels)

Draw all plots after fit: ROC curves, losses and AUROCs

summarize_prediction

summarize_prediction(history, labels, plot=True)

Summarize a prediction run: plot ROC curves, calculate the AUROC, the optimal threshold and the 95% CI for the AUROC, and return the results in history
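The docstring doesn't say how the optimal threshold or the 95% CI is obtained. This sketch assumes two common choices. The threshold comes from Youden's J statistic (the threshold that maximizes TPR - FPR), and the CI comes from a percentile bootstrap. DeLong's method is another common option for the CI. It is a NumPy-only sketch for a single label, and all function names are hypothetical.

```python
import numpy as np


def roc_points(y_true, y_score):
    # Sort by descending score; each distinct score is a candidate threshold
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(-y_score, kind='mergesort')
    y_sorted, s_sorted = y_true[order], y_score[order]
    # Index of the last sample in each run of equal scores
    last = np.r_[np.flatnonzero(np.diff(s_sorted)), len(s_sorted) - 1]
    tps = np.cumsum(y_sorted)[last]   # true positives with score >= threshold
    fps = (last + 1) - tps            # false positives with score >= threshold
    tpr = np.r_[0.0, tps / y_true.sum()]
    fpr = np.r_[0.0, fps / (~y_true).sum()]
    thresholds = np.r_[np.inf, s_sorted[last]]
    return fpr, tpr, thresholds


def auc_trapezoid(fpr, tpr):
    # Area under the ROC curve by the trapezoidal rule
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


def youden_threshold(y_true, y_score):
    # Threshold maximizing Youden's J = TPR - FPR (first one on ties)
    fpr, tpr, thresholds = roc_points(y_true, y_score)
    return float(thresholds[np.argmax(tpr - fpr)])


def bootstrap_auroc_ci(y_true, y_score, n_boot=1000, alpha=0.05, seed=0):
    # Percentile bootstrap: resample cases with replacement, recompute AUROC each time
    rng = np.random.default_rng(seed)
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    aucs = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), len(y_true))
        if y_true[idx].min() == y_true[idx].max():
            continue  # single-class resample: AUROC undefined, skip it
        aucs.append(auc_trapezoid(*roc_points(y_true[idx], y_score[idx])[:2]))
    lower, upper = np.percentile(aucs, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(lower), float(upper)
```

For a multi-label model, these functions would be applied once per label column.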

count_parameters

count_parameters(model, printout=False)

Return the number of parameters in model
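In PyTorch, counting parameters is usually a sum of `numel()` over `model.parameters()`. The docstring doesn't say whether count_parameters counts only trainable parameters, so this hypothetical sketch makes that a flag.

```python
import torch.nn as nn


def count_parameters_sketch(model: nn.Module, trainable_only: bool = True) -> int:
    # numel() is the number of elements in each parameter tensor;
    # frozen parameters (requires_grad=False) are skipped unless trainable_only=False
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)
```

For example, `nn.Linear(4, 2)` has 4 × 2 weights plus 2 biases, so it has 10 parameters.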