First we import basics.

- This creates a `settings.yaml` file from a template in a new directory (`~/.lemonpie`).
- Then it loads the global variables needed everywhere, such as:
  - `DEVICE`, set to the GPU if one exists, else the CPU
  - default paths for data (`DATA_STORE`), logs (`LOG_STORE`), models (`MODEL_STORE`) & experiments (`EXPERIMENT_STORE`)
  - and some other variables used in pre-processing

We also import `fastai.imports` for all other required external libs.
from lemonpie.basics import *
from fastai.imports import *
DEVICE
DATA_STORE
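If you want to inspect or tweak these defaults, you can peek at the generated settings file directly. This is just a sketch; it assumes the file lands at `~/.lemonpie/settings.yaml` as described above.

import yaml
from pathlib import Path

# Assumption: the settings file generated above lives at ~/.lemonpie/settings.yaml
settings_path = Path.home() / '.lemonpie' / 'settings.yaml'
if settings_path.exists():
    with open(settings_path) as f:
        settings = yaml.safe_load(f)
    print(list(settings.keys()))  # top-level keys, e.g. store paths and pre-processing defaults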
Next, we download Synthea's 1,000-patient CSV dataset into our data store.

- For more details about the small dataset we are downloading, read the details on the Synthea website.
- The resulting directory structure must be `{DATA_STORE}/synthea/1K/raw_original`.
- We already have a global variable called `PATH_1K` for convenience.
PATH_1K # global variable
So we create the directory structure for `PATH_1K` in our `DATA_STORE`.
Path.mkdir(Path(PATH_1K), parents=True, exist_ok=True)
Next, we download the data
import requests

synthea_url = 'https://storage.googleapis.com/synthea-public/synthea_sample_data_csv_apr2020.zip'
data_file = Path(f'{PATH_1K}/data.zip')

if not data_file.exists():
    print(f'Downloading from {synthea_url}')
    data = requests.get(synthea_url)   # only fetch when the file is not already there
    with open(data_file, 'wb') as f:
        f.write(data.content)
else:
    print('File exists so skipping download')
print('Done!')
And unzip
from zipfile import ZipFile
with ZipFile(f'{PATH_1K}/data.zip', 'r') as zipObj:
    zipObj.extractall(PATH_1K)
The Synthea zip extracts into a `csv` directory, but the library requires it to be named `raw_original`, so we just rename it:
os.listdir(PATH_1K)
os.rename(f'{PATH_1K}/csv', f'{PATH_1K}/raw_original')
os.listdir(PATH_1K)
os.listdir(f'{PATH_1K}/raw_original')
- Before we pre-process the dataset, we need to decide which conditions will be populated in the pre-processed patients.
- Then, when we train the models, the labels we train on will be a subset (or the full set) of these pre-processed conditions.
- An initial set of conditions is provided in the `CONDITIONS` dictionary that was created when we imported basics and generated the initial settings file above.
CONDITIONS
Next, run pre-processing.
from lemonpie.preprocessing.transform import *
preprocess_ehr_dataset(PATH_1K, today=pd.Timestamp.today(), conditions_dict=CONDITIONS, from_raw_data=True)
By default, pre-processing generates patient data from 0 to 20 years of age; this can be changed by passing a different age span (in years or months) to this function.
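If you need a different span, the call would look something like the sketch below. The keyword names (`age_start`, `age_stop`, `age_in_months`) are assumptions here, so check the `preprocess_ehr_dataset` signature for the actual parameter names.

# Hypothetical keyword names for the age span -- verify against preprocess_ehr_dataset's signature.
preprocess_ehr_dataset(
    PATH_1K,
    today=pd.Timestamp.today(),
    conditions_dict=CONDITIONS,
    from_raw_data=True,
    age_start=0, age_stop=10, age_in_months=False,  # assumed parameters; 0-10 years as an example
)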
- Before we run the models, we need to decide which labels we want to train the models on.
- And these labels must be a subset of the conditions we used when pre-processing the dataset (as mentioned above).
- Say we pick the following subset (see the quick check after the code below):
labels = ['diabetes', 'stroke', 'alzheimers', 'coronary_heart', 'breast_cancer', 'epilepsy']
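Since the training labels must come from the pre-processed conditions, a quick sanity check is worth running. This is a plain Python sketch, not a library call.

# Every chosen label must be one of the conditions used during pre-processing.
assert set(labels) <= set(CONDITIONS.keys()), 'labels must be a subset of the pre-processed conditions'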
Next, create the data object
- This provides data management tools such as data loaders.
from lemonpie.data import *
ehr_1K_data = EHRData(PATH_1K, labels)
Load vocabs and their dimensions
- These were created in the pre-processing step above
from lemonpie.preprocessing.vocab import *
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = get_all_emb_dims(EhrVocabList.load(PATH_1K))
Get DataLoaders
train_dl, valid_dl, train_pos_wts, valid_pos_wts = ehr_1K_data.get_data()
Loss functions
from lemonpie.learn import *
train_loss_fn, valid_loss_fn = get_loss_fn(train_pos_wts), get_loss_fn(valid_pos_wts)
from lemonpie.models import *
model = EHR_LSTM(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, len(labels)).to(DEVICE)
Optimizer
optimizer = torch.optim.Adagrad(model.parameters())
Then run fit
h = RunHistory(labels)
from lemonpie.metrics import *
%time h = fit(5, h, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
              train_dl, valid_dl, to_chkpt_path=MODEL_STORE, from_chkpt_path=None, verbosity=1)
plot_fit_results(h, labels)
Run inference on the test set
test_dl, test_pos_wts = ehr_1K_data.get_test_data()
test_loss_fn = get_loss_fn(test_pos_wts)
h = predict(h, model, test_loss_fn, auroc_score, test_dl, chkpt_path=MODEL_STORE)
h = summarize_prediction(h, labels)
h.prediction_summary
The way to find out whether any label ends up with only a single class in a split (which would make `fit` fail) is to get prevalence counts after creating the data object. The following example uses the data object we created above.
ehr_1K_data.load_splits()
ehr_1K_data.splits.get_label_counts(list(CONDITIONS.keys()))
In this small 1K dataset, 'lung_cancer' and 'rheumatoid_arthritis' have only a single class in some splits (e.g. no lung_cancer patients in the validation set), as seen in the prevalence counts above, and would result in the failure described above when `fit` is run.
In large datasets the chance of this is very low, but it's something to watch out for.
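For this 1K dataset, one simple guard (a plain Python sketch, using only the two condition names flagged above) is to drop those conditions before choosing training labels:

# Conditions that have only a single class in some split of this small dataset
single_class = {'lung_cancer', 'rheumatoid_arthritis'}
safe_labels = [c for c in CONDITIONS.keys() if c not in single_class]
safe_labels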
# Same workflow with the CNN model
model = EHR_CNN(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, num_labels=len(labels)).to(DEVICE)
optimizer = torch.optim.Adagrad(model.parameters())  # fresh optimizer for the new model's parameters
h2 = RunHistory(labels)
h2 = fit(5, h2, model, train_loss_fn, valid_loss_fn, optimizer, auroc_score, \
         train_dl, valid_dl, to_chkpt_path=None, from_chkpt_path=None, verbosity=0.5)