Source code for utils.split_response_dist_plots

"""Module to plot distributions of response values in each subset of a dataset generated by a split"""

import os
import numpy as np
import pandas as pd
from atomsci.ddm.pipeline import parameter_parser as parse

import matplotlib.pyplot as plt
import seaborn as sns


# ---------------------------------------------------------------------------------------------------------------------------------
def plot_split_subset_response_distrs(params):
    """Plot the distributions of the response variable(s) in each split subset of a dataset.

    One matplotlib figure is created per response column. When `params.prediction_type` is 'regression',
    the figure shows a kernel density estimate of the response values for each subset; otherwise it shows
    a bar chart of the percentage of (non-missing) response values equal to 1 in each subset.
    Figures are created but not saved, shown or closed here.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            | - dataset_key
            | - split_uuid
            | - split_strategy
            | - splitter
            | - split_valid_frac
            | - split_test_frac
            | - num_folds
            | - smiles_col
            | - response_cols
            | - prediction_type

    Returns:
        None
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)
    dset_df, split_label = get_split_labeled_dataset(params)
    is_kfold = params.split_strategy == 'k_fold_cv'
    if is_kfold:
        # Order folds numerically ('fold_10' after 'fold_9', not after 'fold_1'), put 'test' last,
        # and skip rows that have no subset label (compounds absent from the split file).
        subset_order = sorted(dset_df.split_subset.dropna().unique(), key=_subset_sort_key)
    else:
        subset_order = ['train', 'valid', 'test']

    for col in params.response_cols:
        if params.prediction_type == 'regression':
            _, ax = plt.subplots(figsize=(9, 7))
            ax = sns.kdeplot(data=dset_df, x=col, hue='split_subset', hue_order=subset_order, bw_adjust=0.7,
                             fill=True, common_norm=False, ax=ax)
            ax.set_title(f"{col} distribution by subset under {split_label}")
        else:
            pct_active = [_percent_active(dset_df, ss, col) for ss in subset_order]
            active_df = pd.DataFrame(dict(subset=subset_order, percent_active=pct_active))
            figsize = (9, 7) if is_kfold else (5, 5)
            _, ax = plt.subplots(figsize=figsize)
            # Draw explicitly on the axes just created rather than relying on the "current" axes.
            ax = sns.barplot(data=active_df, x='subset', y='percent_active', hue='subset', ax=ax)
            ax.set_title(f"Percent of {col} = 1 by subset under {split_label}")
            ax.set_xlabel('')


def _subset_sort_key(label):
    """Sort key placing 'fold_<k>' labels first in numeric order of k, then all other labels alphabetically."""
    label = str(label)
    if label.startswith('fold_'):
        try:
            fold_num = float(label[len('fold_'):])
        except ValueError:
            fold_num = np.nan
        # NaN cannot be ordered, so malformed fold labels (e.g. 'fold_nan') fall through to the second group.
        if pd.notna(fold_num):
            return (0, fold_num, label)
    return (1, 0.0, label)


def _percent_active(dset_df, subset, col):
    """Return 100 * (sum of non-missing `col` values) / (their count) for rows in `subset`, or NaN if there are none."""
    values = dset_df.loc[dset_df.split_subset == subset, col].dropna()
    if values.empty:
        return np.nan
    return 100 * values.sum() / len(values)
# ---------------------------------------------------------------------------------------------------------------------------------
def get_split_labeled_dataset(params):
    """Add a column to a dataset labeling the split subset for each row.

    Given a dataset and split parameters (including split_uuid) referenced in `params`, returns a data frame
    containing the dataset with an extra 'split_subset' column indicating the subset each data point belongs to.
    For standard 3-way splits, the labels will be 'train', 'valid' and 'test'. For a k-fold CV split, the labels
    will be 'fold_0' through 'fold_<k-1>' and 'test'. Rows whose compound ID does not appear in the split file
    get a missing (NaN) 'split_subset' value.

    The split file is read from the same location as the dataset, named
    '<dataset path without extension>_<num_folds>_fold_cv_<splitter>_<split_uuid>.csv' for k-fold CV splits, or
    '<dataset path without extension>_<split_strategy>_<splitter>_<split_uuid>.csv' otherwise. It must contain a
    'cmpd_id' column and a 'subset' column, plus a 'fold' column for k-fold CV splits.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            | - dataset_key
            | - split_uuid
            | - split_strategy
            | - splitter
            | - split_valid_frac
            | - split_test_frac
            | - num_folds
            | - smiles_col
            | - response_cols

            `id_col` is used as the compound ID column if present; otherwise 'compound_id' is assumed.

    Returns:
        A tuple (dset_df, split_label):

            | - dset_df (DataFrame): The dataset specified by `params.dataset_key`, with additional column `split_subset`.
            | - split_label (str): A short description of the split, useful for plot labeling.
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)
    # Column holding compound IDs in the dataset; fall back to the historical hard-coded name.
    id_col = getattr(params, 'id_col', None) or 'compound_id'
    dset_df = pd.read_csv(params.dataset_key, dtype={id_col: str})

    dset_base = os.path.splitext(params.dataset_key)[0]
    if params.split_strategy == 'k_fold_cv':
        split_file = f"{dset_base}_{params.num_folds}_fold_cv_{params.splitter}_{params.split_uuid}.csv"
    else:
        split_file = f"{dset_base}_{params.split_strategy}_{params.splitter}_{params.split_uuid}.csv"
    # Read split-file IDs as strings to match the dataset's string IDs. Otherwise numeric-looking IDs are
    # parsed as integers, which drops leading zeros and makes the merge fail on mismatched key dtypes.
    split_df = pd.read_csv(split_file, dtype={'cmpd_id': str}).rename(columns={'cmpd_id': id_col})
    dset_df = dset_df.merge(split_df, how='left', on=id_col)

    if params.split_strategy == 'k_fold_cv':
        # The left merge turns 'fold' into a float column (with NaN) when some compounds are absent from the
        # split file, so cast to int to get 'fold_0' rather than 'fold_0.0', and leave unmatched rows unlabeled.
        dset_df['split_subset'] = [f"fold_{int(f)}" if pd.notna(f) else np.nan for f in dset_df['fold'].values]
        dset_df.loc[dset_df.subset == 'test', 'split_subset'] = 'test'
        max_fold = dset_df['fold'].max()
        nfolds = int(max_fold) + 1 if pd.notna(max_fold) else params.num_folds
        split_label = f"{nfolds}-fold {params.splitter} cross-validation split"
    else:
        dset_df['split_subset'] = dset_df.subset.values
        split_label = f"{params.splitter} split"
    return dset_df, split_label