Classes and helper functions for all Deep Learning models used in this library.
 
DEVICE
device(type='cuda')

Dropouts

This is the RNNDropout from fast.ai renamed as InputDropout

dropout_mask[source]

dropout_mask(x, sz, p)

Dropout mask as described in fast.ai

class InputDropout[source]

InputDropout(p=0.5) :: Module

InputDropout - same as RNNDropout described in fast.ai

x = torch.randn(2,3,5) #bs=2, seq_len(bptt)=3, x(emb width)=5
mask = dropout_mask(x, (2,1,5), 0.75)
x, mask, x*mask
(tensor([[[-7.6203e-01, -3.9588e-01, -7.1918e-01, -1.8210e+00, -1.4076e+00],
          [ 2.1205e+00, -8.1494e-01, -1.2896e+00,  8.1135e-01,  7.2622e-01],
          [ 1.0889e+00, -1.1225e+00, -5.6083e-01, -8.9040e-01, -1.8095e+00]],
 
         [[-1.1460e-01,  5.4790e-02,  7.1071e-01,  1.1865e+00, -7.4493e-01],
          [-7.9245e-02,  6.5569e-01, -1.4624e+00,  3.7764e-02,  4.5171e-01],
          [ 4.5527e-01,  4.0895e-01, -1.8539e-03,  1.4421e+00, -7.2452e-01]]]),
 tensor([[[0., 0., 4., 4., 4.]],
 
         [[0., 0., 0., 0., 0.]]]),
 tensor([[[-0.0000, -0.0000, -2.8767, -7.2838, -5.6305],
          [ 0.0000, -0.0000, -5.1585,  3.2454,  2.9049],
          [ 0.0000, -0.0000, -2.2433, -3.5616, -7.2381]],
 
         [[-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000]]]))
mask.std(), (x*mask).std(), x.std()
(tensor(1.9322), tensor(2.5112), tensor(0.9989))
mask.mean(), (x*mask).mean(), x.mean()
(tensor(1.2000), tensor(-0.9281), tensor(-0.1524))
dp = InputDropout(0.3)
tst_input = torch.randn(2,3,5)
tst_input, dp(tst_input)
(tensor([[[ 1.1242,  0.7601, -0.3914,  0.8785,  0.3521],
          [ 1.1843, -0.0696,  1.4045, -1.3105,  0.9592],
          [ 1.1924, -1.1984,  0.2293,  2.1980,  1.9559]],
 
         [[ 0.3717,  0.4463, -1.3992,  0.2004, -0.3131],
          [-0.2888, -0.3914, -0.9919, -0.6149,  0.1763],
          [-0.0667,  1.6526, -0.0421, -0.6019,  0.4115]]]),
 tensor([[[ 0.0000,  0.0000, -0.5592,  1.2551,  0.5030],
          [ 0.0000, -0.0000,  2.0065, -1.8722,  1.3703],
          [ 0.0000, -0.0000,  0.3276,  3.1401,  2.7942]],
 
         [[ 0.0000,  0.0000, -1.9989,  0.2862, -0.4473],
          [-0.0000, -0.0000, -1.4170, -0.8784,  0.2519],
          [-0.0000,  0.0000, -0.0601, -0.8598,  0.5878]]]))
tst_input.std(), dp(tst_input).std()
(tensor(0.9427), tensor(1.2440))

Linear Layers

linear_layer[source]

linear_layer(in_features, out_features, bn=False, dropout_p=0.0)

Create a single linear layer

create_linear_layers[source]

create_linear_layers(in_features_start, num_layers, bn=False, dropout_p=0.0)

Create linear layers

out, m = create_linear_layers(100, 4, bn=True)
m, out
(Sequential(
   (0): Linear(in_features=100, out_features=200, bias=True)
   (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): Dropout(p=0.0, inplace=False)
   (4): Linear(in_features=200, out_features=400, bias=True)
   (5): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (6): ReLU(inplace=True)
   (7): Dropout(p=0.0, inplace=False)
   (8): Linear(in_features=400, out_features=800, bias=True)
   (9): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (10): ReLU(inplace=True)
   (11): Dropout(p=0.0, inplace=False)
   (12): Linear(in_features=800, out_features=1600, bias=True)
   (13): BatchNorm1d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (14): ReLU(inplace=True)
   (15): Dropout(p=0.0, inplace=False)
 ),
 1600)

init_lstm[source]

init_lstm(m, initrange, zero_bn=False)

Initialize LSTM

class EHR_LSTM[source]

EHR_LSTM(demograph_dims, rec_dims, demograph_wd, rec_wd, num_labels, lstm_layers=4, linear_layers=4, initrange=0.3, bn=False, input_drp=0.3, lstm_drp=0.3, linear_drp=0.3, zero_bn=False) :: Module

Based on LSTM described in this paper - https://arxiv.org/abs/1801.07860

Load Data

SYNTHEA_DATAGEN_DATES['1K']
'03-15-2021'
preprocess_ehr_dataset(PATH_1K, SYNTHEA_DATAGEN_DATES['1K'], conditions_dict=CONDITIONS, age_start=240, age_stop=360, age_in_months=True)
Since data is pre-cleaned, skipping Cleaning, Splitting and Vocab-creation
------------------- Creating patient lists -------------------
702 total patients completed, saved patient list to /home/vinod/.lemonpie/datasets/synthea/1K/processed/months_240_to_360/train
234 total patients completed, saved patient list to /home/vinod/.lemonpie/datasets/synthea/1K/processed/months_240_to_360/valid
235 total patients completed, saved patient list to /home/vinod/.lemonpie/datasets/synthea/1K/processed/months_240_to_360/test
CONDITIONS.keys()
dict_keys(['diabetes', 'stroke', 'alzheimers', 'coronary_heart', 'lung_cancer', 'breast_cancer', 'rheumatoid_arthritis', 'epilepsy'])
labels = ['diabetes', 'stroke', 'alzheimers', 'coronary_heart', 'breast_cancer', 'epilepsy']
ehr_1K_data = EHRData(PATH_1K, labels, age_start=240, age_stop=360, age_in_months=True)
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = get_all_emb_dims(EhrVocabList.load(PATH_1K))
train_dl, valid_dl, train_pos_wts, valid_pos_wts = ehr_1K_data.get_data()

Inspect Data

ehr_1K_data.splits.train
PatientList (702 items)
base path:/home/vinod/.lemonpie/datasets/synthea/1K; split:train; age span:120 months
age_start:240; age_stop:360; age_type:months
ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:27a8b7b6-007d-4036-82a7-80a9ab670dcb, birthdate:2005-04-13, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:532696f2-0b76-4eb0-9aea-a74e2fb1bed2, birthdate:1967-05-18, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:8641e13a-c832-4d97-811a-b735d0abb45e, birthdate:1982-10-06, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:7f874045-4062-405d-8c23-abb12d0af23e, birthdate:1972-05-20, [('diabetes', False), ('stroke', False)].., device:cpu
ptid:0b6a83ae-fcb1-4b75-9ffa-d52898167d66, birthdate:1989-08-05, [('diabetes', False), ('stroke', False)].., device:cpu...]

Inspect a single patient

pt = ehr_1K_data.splits.train[3] 
pt.obs_nums, pt.obs_offsts
(tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0, 108, 112, 121,   0,   5,   8,  14,  18,
          28,  32,  38,  44, 467,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         108, 112, 121,   5,   8,  14,  19,  28,  32,  40,  44, 467,   0,   5,
           8,  14,  19,  28,  33,  39,  45, 467,   0,   0,   0,   0,   0,   0,
           0,   0,   5,   8,  14,  19,  28,  33,  38,  44, 467, 108, 112, 121,
           0,   0,   5,   9,  14,  19,  29,  33,  40,  45,  54,  60,  65,  71,
          73,  79,  85,  90,  93,  98, 103, 467,   0,   0,   0,   0,   0,   0,
           0,   0, 108, 112, 121,   0,   0,   5,   8,  15,  19,  28,  33,  38,
          45, 467,   0,   0,  49,   0,   0,   0,   0,   0, 108, 112, 121,   0,
           0,   5,   7,  15,  19,  28,  33,  38,  42, 467,   0,   0,   0,   0,
           0,   0,   0,   0, 108, 112, 121,   0,   0,   5,   8,  15,  19,  28,
          33,  38,  45, 467,   0,   0,   0,   0,   0,   0,   0,   0]),
 tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  51,  52,  61,  62,  63,  64,  65,
          66,  67,  68,  69,  70,  73,  82,  83,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 112, 113, 114, 134, 135, 136, 137, 138, 139, 140, 141,
         142, 145, 146, 147, 156, 157, 158, 159, 160, 161, 162, 163, 164, 167,
         168, 169, 178, 179, 180, 181, 182, 183, 184, 185, 186, 189, 190, 191,
         200, 201, 202, 203, 204, 205, 206, 207]))
len(pt.img_nums), len(pt.img_offsts)
(120, 120)
len(pt.obs_nums), len(pt.obs_offsts)
(208, 120)
pt
ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11, [('diabetes', False), ('stroke', False)].., device:cpu
len(train_dl), len(valid_dl)
(11, 2)
train_pos_wts, valid_pos_wts
(tensor([15., 22., 58., 17., 63., 46.]),
 tensor([16., 32., 32., 20., 28., 46.]))
demograph_dims, demograph_dims_wd
([(33, 8),
  (14, 7),
  (124, 11),
  (5, 5),
  (7, 6),
  (4, 5),
  (4, 5),
  (243, 14),
  (208, 13),
  (3, 5),
  (181, 13)],
 92)
rec_dims, rec_dims_wd
([(536, 17),
  (26, 8),
  (50, 9),
  (226, 13),
  (11, 6),
  (137, 12),
  (184, 13),
  (20, 7)],
 85)

Inspect Model

model = EHR_LSTM(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, num_labels=len(labels))
model
EHR_LSTM(
  (embs): ModuleList(
    (0): Embedding(33, 8)
    (1): Embedding(14, 7)
    (2): Embedding(124, 11)
    (3): Embedding(5, 5)
    (4): Embedding(7, 6)
    (5): Embedding(4, 5)
    (6): Embedding(4, 5)
    (7): Embedding(243, 14)
    (8): Embedding(208, 13)
    (9): Embedding(3, 5)
    (10): Embedding(181, 13)
  )
  (embgs): ModuleList(
    (0): EmbeddingBag(536, 17, mode=mean)
    (1): EmbeddingBag(26, 8, mode=mean)
    (2): EmbeddingBag(50, 9, mode=mean)
    (3): EmbeddingBag(226, 13, mode=mean)
    (4): EmbeddingBag(11, 6, mode=mean)
    (5): EmbeddingBag(137, 12, mode=mean)
    (6): EmbeddingBag(184, 13, mode=mean)
    (7): EmbeddingBag(20, 7, mode=mean)
  )
  (input_dp): InputDropout()
  (lstm): LSTM(85, 85, num_layers=4, batch_first=True, dropout=0.3)
  (lin): Sequential(
    (0): Linear(in_features=178, out_features=356, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=356, out_features=712, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=712, out_features=1424, bias=True)
    (7): ReLU(inplace=True)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=1424, out_features=2848, bias=True)
    (10): ReLU(inplace=True)
    (11): Dropout(p=0.3, inplace=False)
  )
  (lin_o): Linear(in_features=2848, out_features=6, bias=True)
)
#     print(f'{name}::\n{param}')

Test fit()

train_loss_fn, valid_loss_fn = get_loss_fn(train_pos_wts), get_loss_fn(valid_pos_wts)
model = EHR_LSTM(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, len(labels)).to(DEVICE) #put on GPU before instantiating optim
optimizer = torch.optim.Adagrad(model.parameters())
h = RunHistory(labels)
%time h = fit(2, h, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
              train_dl, valid_dl, to_chkpt_path=MODEL_STORE, from_chkpt_path=None, verbosity=1)
epoch |     train loss |     train aurocs                  valid loss |     valid aurocs    
----------------------------------------------------------------------------------------------------
    0 |          3.688 | [0.600 0.678 0.625 0.513]              1.294 | [0.665 0.699 0.980 0.843]
    1 |          1.239 | [0.725 0.867 0.909 0.700]              1.092 | [0.686 0.701 0.979 0.831]
Checkpointed to "/home/vinod/.lemonpie/models/checkpoint.tar"
CPU times: user 5.62 s, sys: 39.6 ms, total: 5.66 s
Wall time: 5.68 s
%time h = fit(3, h, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
              train_dl, valid_dl, to_chkpt_path=MODEL_STORE, from_chkpt_path=MODEL_STORE, verbosity=1)
From "/home/vinod/.lemonpie/models/checkpoint.tar", loading model ...
loading optimizer and epoch_index ...
epoch |     train loss |     train aurocs                  valid loss |     valid aurocs    
----------------------------------------------------------------------------------------------------
    2 |          1.046 | [0.763 0.864 0.912 0.727]              1.057 | [0.685 0.707 0.984 0.817]
    3 |          0.995 | [0.795 0.897 0.936 0.794]              1.166 | [0.696 0.713 0.979 0.804]
    4 |          0.850 | [0.799 0.922 0.951 0.748]              1.138 | [0.625 0.714 0.977 0.799]
Checkpointed to "/home/vinod/.lemonpie/models/checkpoint.tar"
CPU times: user 8.41 s, sys: 48 ms, total: 8.46 s
Wall time: 8.49 s

Test predict()

test_dl, test_pos_wts = ehr_1K_data.get_test_data()
len(test_dl), test_pos_wts
(2, tensor([ 11.,  20.,  38.,  20., 116., 116.]))
test_loss_fn = get_loss_fn(test_pos_wts)
h = predict(h, model, test_loss_fn, auroc_score, test_dl, chkpt_path=MODEL_STORE)
From "/home/vinod/.lemonpie/models/checkpoint.tar", loading model ...
test loss = 1.0411735773086548
test aurocs = [0.713938 0.920455 0.911936 0.854708 0.948498 0.39485 ]

Test plotting + results

plot_fit_results(h, labels)
h = summarize_prediction(h, labels)
Prediction Summary ...
                auroc_score  optimal_threshold     auroc_95_ci
diabetes           0.713938           0.633300  (0.611, 0.802)
stroke             0.920455           0.784929   (0.841, 0.98)
alzheimers         0.911936           0.634823  (0.837, 0.967)
coronary_heart     0.854708           0.557420  (0.756, 0.934)
breast_cancer      0.948498           0.778223    (0.876, 1.0)
epilepsy           0.394850           0.715678  (0.128, 0.679)
h.prediction_summary
auroc_score optimal_threshold auroc_95_ci
diabetes 0.713938 0.633300 (0.611, 0.802)
stroke 0.920455 0.784929 (0.841, 0.98)
alzheimers 0.911936 0.634823 (0.837, 0.967)
coronary_heart 0.854708 0.557420 (0.756, 0.934)
breast_cancer 0.948498 0.778223 (0.876, 1.0)
epilepsy 0.394850 0.715678 (0.128, 0.679)

Every label must have both classes (positive and negative) present in each split; otherwise the AUROC score cannot be calculated, resulting in this error - ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

ehr_1K_data.load_splits()
ehr_1K_data.splits.get_label_counts(list(CONDITIONS.keys()))
train valid test total
diabetes 43 14 19 76
stroke 30 7 11 48
alzheimers 12 7 6 25
coronary_heart 39 11 11 61
lung_cancer 12 0 2 14
breast_cancer 11 8 2 21
rheumatoid_arthritis 2 0 0 2
epilepsy 15 5 2 22

In this small 1K dataset, 'lung_cancer' and 'rheumatoid_arthritis' have no positive examples in some splits (see the prevalence counts above), so only a single class is present in those splits, which would result in the above failure when fit is run.

In large datasets the possibility of this is very low, but it's something to watch out for.

Sizes (+ Conv Arithmetic)

bs = 64 #batch size # of patients
wd = 85 #rec_emb_width (concat of 1 yr of records)
ht = 240 #num of time steps of pt data, e.g. 240 months = 20 years (seq_len or bptt in lstm)

Each pt is a 240 by 85 matrix (ht by wd)

  • 240 time steps (ht) on axis 0 (height)
  • 85 embedding features (wd) on axis 1 (width)
tst_pts = torch.randn(bs,ht,wd)
tst_pts.shape
torch.Size([64, 240, 85])

But nn.Conv2d expects a 4-D input and produces a 4-D output ...

  • Input :: $(N, C_{in}, H_{in}, W_{in})$
  • Output:: $(N, C_{out}, H_{out}, W_{out})$

So need to reshape to insert $C_{in}$ (which is 1 in this case) after bs

tst_pts = tst_pts.reshape(bs,1,ht,wd)
tst_pts.shape
torch.Size([64, 1, 240, 85])
m = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=2, kernel_size=(5,5), padding=2), nn.ReLU(),
    nn.Conv2d(2,4,kernel_size=(3,3), padding=1), nn.ReLU(),
    nn.Conv2d(4,8,kernel_size=(3,3), stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(8,16,kernel_size=(3,3), stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16,32,kernel_size=(3,3), stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveMaxPool2d((4,4)),
    nn.Flatten()
)
out = m(tst_pts)
out.shape
torch.Size([64, 512])
  • AdaptivePool ensures output before Flatten is bs x 32 x 4 x 4
  • And thus Flatten will always flatten it to bs x 512
  • So can use 512 safely - no matter the size of the input (which will change based on vocab dims)
for name, param in m.named_parameters():
    print(name)
0.weight
0.bias
2.weight
2.bias
4.weight
4.bias
6.weight
6.bias
8.weight
8.bias

init_cnn[source]

init_cnn(m, initrange, zero_bn=False)

Initialize CNN as described in fast.ai

conv_layer[source]

conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=1, bn=False)

Create a single conv layer - as described in fast.ai

class EHR_CNN[source]

EHR_CNN(demograph_dims, rec_dims, demograph_wd, rec_wd, num_labels, linear_layers=4, initrange=0.3, bn=False, input_drp=0.3, linear_drp=0.3, zero_bn=False) :: Module

Based on the model described in the Deepr paper - https://arxiv.org/abs/1607.07519

Load Data + Inspect

ehr_1K_data = EHRData(PATH_1K, labels, age_start=240, age_stop=360, age_in_months=True, lazy_load_gpu=False) #entire dataset on GPU
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = get_all_emb_dims(EhrVocabList.load(PATH_1K))
train_dl, valid_dl, train_pos_wts, valid_pos_wts = ehr_1K_data.get_data()

Inspect Model

model = EHR_CNN(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, len(labels))
model
EHR_CNN(
  (embs): ModuleList(
    (0): Embedding(33, 8)
    (1): Embedding(14, 7)
    (2): Embedding(124, 11)
    (3): Embedding(5, 5)
    (4): Embedding(7, 6)
    (5): Embedding(4, 5)
    (6): Embedding(4, 5)
    (7): Embedding(243, 14)
    (8): Embedding(208, 13)
    (9): Embedding(3, 5)
    (10): Embedding(181, 13)
  )
  (embgs): ModuleList(
    (0): EmbeddingBag(536, 17, mode=mean)
    (1): EmbeddingBag(26, 8, mode=mean)
    (2): EmbeddingBag(50, 9, mode=mean)
    (3): EmbeddingBag(226, 13, mode=mean)
    (4): EmbeddingBag(11, 6, mode=mean)
    (5): EmbeddingBag(137, 12, mode=mean)
    (6): EmbeddingBag(184, 13, mode=mean)
    (7): EmbeddingBag(20, 7, mode=mean)
  )
  (input_dp): InputDropout()
  (lin): Sequential(
    (0): Linear(in_features=605, out_features=1210, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1210, out_features=2420, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=2420, out_features=4840, bias=True)
    (7): ReLU(inplace=True)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=4840, out_features=9680, bias=True)
    (10): ReLU(inplace=True)
    (11): Dropout(p=0.3, inplace=False)
  )
  (lin_o): Linear(in_features=9680, out_features=6, bias=True)
  (cnn): Sequential(
    (0): Conv2d(1, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(4, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): AdaptiveMaxPool2d(output_size=(4, 4))
    (11): Flatten(start_dim=1, end_dim=-1)
  )
)
#     print(f'{name}::\n{param}')

Test fit()

train_loss_fn, valid_loss_fn = get_loss_fn(train_pos_wts), get_loss_fn(valid_pos_wts)
model = EHR_CNN(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, len(labels)).to(DEVICE)
optimizer = torch.optim.Adagrad(model.parameters())
h = RunHistory(labels)
%time h = fit(3, h, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
              train_dl, valid_dl, to_chkpt_path=MODEL_STORE, from_chkpt_path=None, verbosity=1)
epoch |     train loss |     train aurocs                  valid loss |     valid aurocs    
----------------------------------------------------------------------------------------------------
    0 |        220.830 | [0.502 0.510 0.532 0.554]              1.192 | [0.677 0.751 0.982 0.788]
    1 |          1.291 | [0.661 0.816 0.801 0.710]              1.232 | [0.675 0.761 0.986 0.443]
    2 |          1.144 | [0.736 0.853 0.858 0.632]              1.139 | [0.665 0.755 0.982 0.818]
Checkpointed to "/home/vinod/.lemonpie/models/checkpoint.tar"
CPU times: user 6.36 s, sys: 539 ms, total: 6.9 s
Wall time: 7.16 s
%time h = fit(2, h, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
              train_dl, valid_dl, to_chkpt_path=MODEL_STORE, from_chkpt_path=MODEL_STORE, verbosity=1)
From "/home/vinod/.lemonpie/models/checkpoint.tar", loading model ...
loading optimizer and epoch_index ...
epoch |     train loss |     train aurocs                  valid loss |     valid aurocs    
----------------------------------------------------------------------------------------------------
    3 |          1.060 | [0.693 0.918 0.937 0.765]              1.280 | [0.690 0.773 0.986 0.806]
    4 |          1.185 | [0.758 0.897 0.904 0.798]              1.187 | [0.683 0.769 0.976 0.796]
Checkpointed to "/home/vinod/.lemonpie/models/checkpoint.tar"
CPU times: user 4.52 s, sys: 713 ms, total: 5.24 s
Wall time: 5.4 s

Test predict()

test_dl, test_pos_wts = ehr_1K_data.get_test_data()
len(test_dl), test_pos_wts
(2, tensor([ 11.,  20.,  38.,  20., 116., 116.]))
test_loss_fn = get_loss_fn(test_pos_wts)
h = predict(h, model, test_loss_fn, auroc_score, test_dl, chkpt_path=MODEL_STORE)
From "/home/vinod/.lemonpie/models/checkpoint.tar", loading model ...
test loss = 1.0468118786811829
test aurocs = [0.742446 0.883523 0.909753 0.814123 0.88412  0.416309]

Test plotting + results

plot_fit_results(h, labels)
h = summarize_prediction(h, labels)
Prediction Summary ...
                auroc_score  optimal_threshold     auroc_95_ci
diabetes           0.742446           0.454398  (0.628, 0.836)
stroke             0.883523           0.531821   (0.761, 0.97)
alzheimers         0.909753           0.679645   (0.83, 0.964)
coronary_heart     0.814123           0.516481  (0.683, 0.915)
breast_cancer      0.884120           0.626253   (0.825, 0.94)
epilepsy           0.416309           0.633332  (0.305, 0.528)
h.prediction_summary
auroc_score optimal_threshold auroc_95_ci
diabetes 0.742446 0.454398 (0.628, 0.836)
stroke 0.883523 0.531821 (0.761, 0.97)
alzheimers 0.909753 0.679645 (0.83, 0.964)
coronary_heart 0.814123 0.516481 (0.683, 0.915)
breast_cancer 0.884120 0.626253 (0.825, 0.94)
epilepsy 0.416309 0.633332 (0.305, 0.528)