Tabular Classification
Scikit-learn
English
hierarchical
healthcare
ehr
copd
clinical-risk
tabular
scikit-learn
clustering
unsupervised
Instructions to use stormid/copd-model-e with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use stormid/copd-model-e with Scikit-learn:
```python
from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
    hf_hub_download("stormid/copd-model-e", "sklearn_model.joblib")
)
# only load pickle files from sources you trust
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Validation process | |
| """ | |
| import sys | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import mlflow | |
| from matplotlib import rcParams | |
| from tableone import TableOne | |
# Set-up figures: wide default figure size and no top/right axis spines
rcParams.update({
    'figure.figsize': (20, 5),
    'axes.spines.top': False,
    'axes.spines.right': False,
})
def plot_cluster_size(df, data_type):
    """
    Produce a horizontal bar plot of the number of patients in each cluster
    and log it to MLflow
    --------
    :param df: dataframe to plot; must contain a 'cluster' column
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    # One row per patient, so the group size is the number of patients
    df.groupby('cluster').size().plot(ax=ax, kind='barh')
    title = "Patient Cohorts"
    # Enlarged title and default-size axis labels, matching the other plots
    # produced by this module
    ax.set_title(title, size=20)
    ax.set_xlabel("Number of Patients")
    ax.set_ylabel("Cluster")
    fig.tight_layout()
    mlflow.log_figure(fig, 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png')
    # Release the figure once logged; pyplot otherwise keeps it open
    plt.close(fig)
def plot_feature_hist(df, col, data_type, bins=10):
    """
    Produce overlaid per-cluster histograms of a chosen feature and log the
    figure to MLflow
    --------
    :param df: dataframe to plot; must contain a 'cluster' column and `col`
    :param col: feature column to plot
    :param data_type: type of data - train, test, val, rec, sup
    :param bins: number of histogram bins (default 10); a single set of bin
        edges is shared by every cluster
    """
    # Compute bin edges from the whole column so every cluster is binned
    # identically; if each group picks its own range the overlaid bars cannot
    # be compared
    edges = np.histogram_bin_edges(df[col].dropna().to_numpy(dtype=float), bins=bins)
    fig, ax = plt.subplots()
    # One translucent histogram per cluster, drawn on the same axes
    df.groupby('cluster')[col].plot(ax=ax, kind='hist', bins=edges, alpha=0.5)
    ax.set_xlabel(col)
    title = col + ' Histogram'
    ax.set_title(title, size=20)
    ax.legend()
    fig.tight_layout()
    mlflow.log_figure(fig, 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png')
    # Release the figure once logged; pyplot otherwise keeps it open
    plt.close(fig)
def plot_feature_bar(data, col, typ, data_type):
    """
    Produce a horizontal bar plot of a categorical feature broken down by
    cluster and log it to MLflow
    --------
    :param data: dataframe to plot; must contain a 'cluster' column and `col`
    :param col: feature column to plot
    :param typ: 'count' plots the number of patients per category; any other
        value (e.g. 'percentage') plots the percentage of each cluster
    :param data_type: type of data - train, test, val, rec, sup
    """
    cluster_sizes = data.groupby('cluster').size()
    # Cluster x category table of patient counts. Categories absent from a
    # cluster become 0, so every cluster always gets the same set of bars
    # (a groupby/apply result changes shape when categories are missing);
    # clusters whose `col` values are all missing are kept as all-zero rows
    counts = pd.crosstab(data['cluster'], data[col]).reindex(
        cluster_sizes.index, fill_value=0)
    if typ == 'count':
        to_plot = counts
        x_label = "Number"
    else:
        # Denominator is the full cluster size, including rows where `col`
        # is missing
        to_plot = counts.mul(100).div(cluster_sizes, axis=0)
        x_label = "Percentage"
    fig, ax = plt.subplots()
    to_plot.plot(ax=ax, kind='barh')
    title = "Patient Cohorts"
    ax.set_title(title, size=20)
    ax.set_xlabel(x_label + " of patients")
    ax.set_ylabel("Cluster")
    fig.tight_layout()
    mlflow.log_figure(fig, 'fig/' + '_'.join((title.replace(' ', '_'), col, data_type + '.png')))
    # Release the figure once logged; pyplot otherwise keeps it open
    plt.close(fig)
def plot_cluster_bar(data, typ, data_type):
    """
    Produce a vertical bar plot of per-cluster percentages and log it to MLflow
    --------
    :param data: dataframe indexed by cluster with one column per plotted
        quantity; values are expected to be percentages (the y-axis is fixed
        to 0-100)
    :param typ: name of the plot (e.g. 'events', 'survival', 'therapies');
        used as the title and in the logged file name
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    data.plot(ax=ax, kind='bar')
    ax.set_title(typ, size=20)
    ax.set_xlabel("Cluster")
    ax.set_ylabel("Percentage")
    ax.set_ylim(0, 100)
    fig.tight_layout()
    mlflow.log_figure(fig, 'fig/' + typ + '_' + data_type + '.png')
    # Release the figure once logged; pyplot otherwise keeps it open
    plt.close(fig)
def plot_events(df, data_type):
    """
    Plot, per cluster, the percentage of patients flagged with each event in
    the metric table
    --------
    :param df: metric table with 'SafeHavenID', 'cluster' and one column per
        event
    :param data_type: type of data - train, test, val, rec, sup
    """
    metrics = df.drop('SafeHavenID', axis=1).set_index('cluster')
    # A patient counts towards an event when that column's value equals 1;
    # the mean of the boolean mask per cluster is the flagged fraction
    flagged = metrics.eq(1)
    events = flagged.groupby('cluster').mean() * 100
    plot_cluster_bar(events, 'events', data_type)
def process_deceased_metrics(col):
    """
    Summarise a time-to-death column as alive/deceased percentages
    -------
    :param col: series of time-to-death labels for one group; '12+' is
        treated as alive and any other non-null label is counted as deceased
    :return: one-row dataframe with 'alive' and 'deceased' percentages
    """
    # The labels are strings ('1', '3', '6', '12', '12+'), so an ordering test
    # such as col < '12+' is lexicographic and misses '3' and '6'. Test for
    # inequality instead; count() still ignores missing values.
    n_deceased = 100 * col[col != '12+'].count() / len(col)
    return pd.DataFrame({'alive': [100 - n_deceased], 'deceased': [n_deceased]})
def plot_deceased(df, data_type):
    """
    Plot, per cluster, the percentage of patients alive vs deceased as
    computed by process_deceased_metrics from the 'time_to_death' column
    --------
    :param df: metric table with 'cluster' and 'time_to_death' columns
    :param data_type: type of data - train, test, val, rec, sup
    """
    per_cluster = df.groupby('cluster')['time_to_death'].apply(process_deceased_metrics)
    # apply stacks each cluster's one-row frame under a (cluster, row)
    # MultiIndex; dropping the inner level leaves one row per cluster
    survival = per_cluster.droplevel(-1)
    plot_cluster_bar(survival, 'survival', data_type)
def plot_therapies(df_year, results, data_type):
    """
    Plot, per cluster, the percentage of patients with a non-zero value in
    each inhaler column, plus the percentage with all inhaler columns zero
    --------
    :param df_year: unscaled data for current year; must contain
        'SafeHavenID' and the single/double/triple inhaler columns
    :param results: cluster results and safehaven id
    :param data_type: type of data - train, test, val, rec, sup
    """
    inhaler_cols = ['single_inhaler', 'double_inhaler', 'triple_inhaler']
    # Attach cluster labels to each patient's inhaler data
    therapies = df_year[['SafeHavenID'] + inhaler_cols]
    res_therapies = pd.merge(therapies, results, on='SafeHavenID', how='inner')
    inhals = res_therapies[['cluster'] + inhaler_cols].set_index('cluster')
    cluster_sizes = inhals.groupby(level='cluster').size()
    # Percentage of each cluster with a prescription (> 0) of each inhaler type
    in_res = inhals.gt(0).groupby(level='cluster').sum().mul(100).div(cluster_sizes, axis=0)
    # Percentage of each cluster with no inhaler prescription of any type
    no_inhaler = (inhals.eq(0).all(axis=1)
                  .groupby(level='cluster').sum().mul(100).div(cluster_sizes))
    # Rename columns for plotting, e.g. 'single_inhaler' -> 'single'
    in_res.columns = [c.split('_')[0] for c in in_res.columns]
    # Assigned by cluster index, not by position
    in_res['no_inhaler'] = no_inhaler
    plot_cluster_bar(in_res, 'therapies', data_type)
def _log_html_table(table, path):
    """
    Write a TableOne summary table to HTML and log the file as an MLflow
    artifact
    --------
    :param table: TableOne instance to export
    :param path: destination path of the HTML file
    """
    table.to_html(path)
    mlflow.log_artifact(path)


def main():
    """
    Produce validation plots and TableOne summaries for one clustered data
    split and log them to an existing MLflow run.

    Usage: <script> <data_type> <run_name> <run_id>
      data_type - type of data - train, test, val, rec, sup
      run_name  - prefix of the saved column/cluster files in model_data_path
      run_id    - id of the MLflow run to log into
    """
    # Get datatype, run name and MLflow run id from cmd line
    if len(sys.argv) < 4:
        sys.exit('Usage: <script> <data_type> <run_name> <run_id>')
    data_type, run_name, run_id = sys.argv[1:4]

    # Load in config items (path is relative to the working directory)
    with open('../../../config.json', encoding='utf-8') as json_config_file:
        config = json.load(json_config_file)
    data_path = config['model_data_path']

    # Set MLFlow parameters
    model_type = 'hierarchical'
    experiment_name = 'Model E - Date Specific: ' + model_type
    mlflow.set_tracking_uri('http://127.0.0.1:5000/')
    mlflow.set_experiment(experiment_name)

    # The context manager always ends the run (marking it FAILED if a step
    # below raises) instead of leaving it open
    with mlflow.start_run(run_id=run_id):
        # Read in unscaled data, results and column names used to train model.
        # NOTE: np.load(allow_pickle=True) and pd.read_pickle can execute
        # arbitrary code; only point model_data_path at trusted files.
        columns = np.load(data_path + run_name + '_cols.npy', allow_pickle=True)
        df_clusters = pd.read_pickle(data_path + '_'.join((run_name, data_type, 'clusters.pkl')))
        df_reduced = df_clusters[list(columns) + ['cluster']]

        # Number of patients
        plot_cluster_size(df_reduced, data_type)

        # Generate mean/std table
        t1_year = TableOne(df_reduced, categorical=[], groupby='cluster', pval=True)
        _log_html_table(t1_year, data_path + 't1_year_' + run_name + '_' + data_type + '.html')

        # Histogram feature plots
        plot_feature_hist(df_clusters, 'age', data_type)
        plot_feature_hist(df_clusters, 'albumin_med_2yr', data_type)

        # Bar plots (percentage of each cluster per category)
        df_clusters['sex'] = df_clusters['sex_bin'].map({0: 'Male', 1: 'Female'})
        plot_feature_bar(df_clusters, 'sex', 'percentage', data_type)
        plot_feature_bar(df_clusters, 'simd_decile', 'percentage', data_type)

        # Metrics for following 12 months
        df_events = pd.read_pickle(data_path + 'metric_table_events.pkl')
        df_counts = pd.read_pickle(data_path + 'metric_table_counts.pkl')
        df_next = pd.read_pickle(data_path + 'metric_table_next.pkl')

        # Merge cluster number with SafeHavenID and metrics; patients missing
        # from a metric table get 0 (events/counts) or '12+' (time to next)
        clusters = df_clusters[['SafeHavenID', 'cluster']]
        df_events = clusters.merge(df_events, on='SafeHavenID', how='left').fillna(0)
        df_counts = clusters.merge(df_counts, on='SafeHavenID', how='left').fillna(0)
        df_next = clusters.merge(df_next, on='SafeHavenID', how='left').fillna('12+')

        # After each merge, columns 0-1 are SafeHavenID and cluster and the
        # remaining columns are metrics. Per-column TableOne settings are
        # built from the actual metric columns, so none is silently left
        # without settings and each column gets its own order list.

        # Generate TableOne for events
        cat_cols = df_events.columns[2:]
        df_events[cat_cols] = df_events[cat_cols].astype('int')
        event_limit = {c: 1 for c in cat_cols}
        event_order = {c: [1, 0] for c in cat_cols}
        t1_events = TableOne(df_events[df_events.columns[1:]], groupby='cluster',
                             limit=event_limit, order=event_order)
        _log_html_table(t1_events,
                        data_path + '_'.join(('t1', data_type, 'events', run_name + '.html')))

        # Generate TableOne for event counts
        count_cols = df_counts.columns[2:]
        df_counts[count_cols] = df_counts[count_cols].astype('int')
        t1_counts = TableOne(df_counts[df_counts.columns[1:]], categorical=[], groupby='cluster')
        _log_html_table(t1_counts,
                        data_path + '_'.join(('t1', data_type, 'counts', run_name + '.html')))

        # Generate TableOne for time to next events
        next_cols = df_next.columns[2:]
        next_event_order = {c: ['1', '3', '6', '12', '12+'] for c in next_cols}
        t1_next = TableOne(df_next[df_next.columns[1:]], groupby='cluster',
                           order=next_event_order)
        _log_html_table(t1_next,
                        data_path + '_'.join(('t1', data_type, 'next', run_name + '.html')))

        # Plot metrics
        plot_events(df_events, data_type)
        plot_deceased(df_next, data_type)
        plot_therapies(df_clusters, clusters, data_type)
# Only run the validation pipeline when executed as a script, not on import
if __name__ == '__main__':
    main()