# Source code for utils.split_response_dist_plots

"""Module to plot distributions of response values in each subset of a dataset generated by a split"""

import os
import numpy as np
import pandas as pd
from atomsci.ddm.pipeline import parameter_parser as parse

import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats


# ---------------------------------------------------------------------------------------------------------------------------------
def plot_split_subset_response_distrs(params, axes=None, plot_size=7):
    """Plot the distributions of the response variable(s) in each split subset of a dataset.

    For regression datasets (`params.prediction_type == 'regression'`) a kernel density plot
    is drawn per response column, with one curve per split subset. For any other prediction
    type, a bar plot of the percentage of non-missing values equal to 1 in each subset is
    drawn instead. For 'train_valid_test' splits, the plot title also reports the Wasserstein
    distances of the validation and test subsets from the training subset.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            | - dataset_key
            | - split_uuid
            | - split_strategy
            | - splitter
            | - split_valid_frac
            | - split_test_frac
            | - num_folds
            | - smiles_col
            | - response_cols

            The function also reads `params.prediction_type`.

        axes (matplotlib.Axes): Axes to draw plots in, if provided. May be a single Axes
            object, a list of Axes or an array of Axes; one Axes is used per response column,
            in order.

        plot_size (float): Height of plots; ignored if axes is provided

    Returns:
        None
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)

    # Save current matplotlib color cycle and switch to 'colorblind' palette
    old_palette = sns.color_palette()
    sns.set_palette('colorblind')
    try:
        dset_df, split_label = get_split_labeled_dataset(params)
        is_three_way = params.split_strategy == 'train_valid_test'
        if params.split_strategy == 'k_fold_cv':
            subset_order = sorted(set(dset_df.split_subset.values))
        else:
            subset_order = ['train', 'valid', 'test']

        # The distances are only shown in titles of 3-way splits, so skip the
        # (dataset re-reading) computation for every other split strategy.
        wdist_df = compute_split_subset_wasserstein_distances(params) if is_three_way else None

        num_cols = len(params.response_cols)
        if axes is None:
            _, axes = plt.subplots(1, num_cols, figsize=(plot_size * num_cols, plot_size))
        # Normalize a single Axes, a list of Axes or an N-d array of Axes to a flat array
        axes = np.atleast_1d(axes).ravel()

        for colnum, col in enumerate(params.response_cols):
            ax = axes[colnum]
            if params.prediction_type == 'regression':
                title = f"{col} distribution by subset under {split_label}"
                ax = sns.kdeplot(data=dset_df, x=col, hue='split_subset', hue_order=subset_order,
                                 bw_adjust=0.7, fill=True, common_norm=False, ax=ax)
            else:
                title = f"Percent of {col} = 1 by subset under {split_label}"
                pct_active = []
                for ss in subset_order:
                    ss_df = dset_df[dset_df.split_subset == ss]
                    nactive = np.nansum(ss_df[col].values)
                    nvalues = ss_df[col].notna().sum()
                    # A subset with no measured values has no defined percentage
                    pct_active.append(100 * nactive / nvalues if nvalues > 0 else np.nan)
                active_df = pd.DataFrame(dict(subset=subset_order, percent_active=pct_active))
                ax = sns.barplot(data=active_df, x='subset', y='percent_active', hue='subset', ax=ax)
                # The bar labels already name the subsets
                ax.set_xlabel('')

            if is_three_way:
                col_dists = wdist_df[wdist_df.response_col == col]
                tvv_wdist = col_dists[col_dists.split_subset == 'valid'].distance.values[0]
                tvt_wdist = col_dists[col_dists.split_subset == 'test'].distance.values[0]
                # Put the distances on their own title line for both plot types
                title += f"\nWasserstein distances: valid = {tvv_wdist:.3f}, test = {tvt_wdist:.3f}"
            ax.set_title(title)
    finally:
        # Restore previous matplotlib color cycle, even if plotting failed
        sns.set_palette(old_palette)
# ---------------------------------------------------------------------------------------------------------------------------------
def _non_missing_response_values(dset_df, subset, col):
    """Return the non-NaN values of column `col` for the rows labeled with split subset `subset`."""
    vals = dset_df[dset_df.split_subset == subset][col].values
    return vals[~np.isnan(vals)]


def compute_split_subset_wasserstein_distances(params):
    """Compute the Wasserstein ("earth-moving") distance between the distributions of the response
    variable(s) in the validation and test sets to that of the training set. In the case of a k-fold
    CV split, the reference subset is the first label in sorted order and every other subset label
    is compared against it.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            | - dataset_key
            | - split_uuid
            | - split_strategy
            | - splitter
            | - split_valid_frac
            | - split_test_frac
            | - num_folds
            | - smiles_col
            | - response_cols

    Returns:
        (DataFrame): A table with columns `response_col`, `split_subset` and `distance`, holding
        the Wasserstein distance relative to the reference subset for each other split subset and
        each response variable. The distance is NaN when either subset has no non-missing values
        for the response variable.
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)
    dset_df, _ = get_split_labeled_dataset(params)
    if params.split_strategy == 'k_fold_cv':
        subset_order = sorted(set(dset_df.split_subset.values))
    else:
        subset_order = ['train', 'valid', 'test']

    reference_subset = subset_order[0]
    response_vars = []
    subsets = []
    distances = []
    for col in params.response_cols:
        reference_vals = _non_missing_response_values(dset_df, reference_subset, col)
        for subset in subset_order[1:]:
            subset_vals = _non_missing_response_values(dset_df, subset, col)
            if reference_vals.size == 0 or subset_vals.size == 0:
                # wasserstein_distance() rejects empty distributions; report the
                # distance as undefined rather than failing for the whole dataset.
                dist = np.nan
            else:
                dist = stats.wasserstein_distance(reference_vals, subset_vals)
            response_vars.append(col)
            subsets.append(subset)
            distances.append(dist)

    return pd.DataFrame(dict(response_col=response_vars, split_subset=subsets, distance=distances))
# ---------------------------------------------------------------------------------------------------------------------------------
def get_split_labeled_dataset(params):
    """Add a column to a dataset labeling the split subset for each row.

    Given a dataset and split parameters (including split_uuid) referenced in `params`, returns a
    data frame containing the dataset with an extra 'split_subset' column indicating the subset each
    data point belongs to. For standard 3-way splits, the labels will be 'train', 'valid' and 'test'.
    For a k-fold CV split, the labels will be 'fold_0' through 'fold_<k-1>' and 'test'.

    The split file is expected next to the dataset file, named
    `<dataset root>_<num_folds>_fold_cv_<splitter>_<split_uuid>.csv` for k-fold CV splits and
    `<dataset root>_<split_strategy>_<splitter>_<split_uuid>.csv` otherwise. It is read with the
    columns `cmpd_id`, `subset` and (for k-fold CV splits) `fold`.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            | - dataset_key
            | - id_col
            | - split_uuid
            | - split_strategy
            | - splitter
            | - split_valid_frac
            | - split_test_frac
            | - num_folds
            | - smiles_col
            | - response_cols

    Returns:
        A tuple (dset_df, split_label):

        | - dset_df (DataFrame): The dataset specified by `params.dataset_key`, with the split file
        |   columns merged in and an additional column `split_subset`.
        | - split_label (str): A short description of the split, useful for plot labeling.
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)

    id_col = params.id_col
    # Read IDs as strings so that numeric-looking IDs match between the two files
    dset_df = pd.read_csv(params.dataset_key, dtype={id_col: str})

    dset_root = os.path.splitext(params.dataset_key)[0]
    is_kfold = params.split_strategy == 'k_fold_cv'
    if is_kfold:
        split_file = f"{dset_root}_{params.num_folds}_fold_cv_{params.splitter}_{params.split_uuid}.csv"
    else:
        split_file = f"{dset_root}_{params.split_strategy}_{params.splitter}_{params.split_uuid}.csv"

    # Join on the dataset's own ID column rather than assuming it is named 'compound_id'
    split_df = pd.read_csv(split_file, dtype={'cmpd_id': str}).rename(columns={'cmpd_id': id_col})
    dset_df = dset_df.merge(split_df, how='left', on=id_col)

    if is_kfold:
        dset_df['split_subset'] = [f"fold_{f}" for f in dset_df.fold.values]
        # Test set rows also carry a fold value; the 'test' label takes precedence
        dset_df.loc[dset_df.subset == 'test', 'split_subset'] = 'test'
        nfolds = max(dset_df.fold.values) + 1
        split_label = f"{nfolds}-fold {params.splitter} cross-validation split"
    else:
        dset_df['split_subset'] = dset_df.subset.values
        if params.splitter == 'multitaskscaffold':
            split_label = 'MTSS split'
        else:
            split_label = f"{params.splitter} split"
    return dset_df, split_label