"""Module to plot distributions of response values in each subset of a dataset generated by a split"""
import os
import numpy as np
import pandas as pd
from atomsci.ddm.pipeline import parameter_parser as parse
import matplotlib.pyplot as plt
import seaborn as sns
# ---------------------------------------------------------------------------------------------------------------------------------
# [docs]
def plot_split_subset_response_distrs(params):
    """Plot the distributions of the response variable(s) in each split subset of a dataset.

    One figure is created per response column. For regression datasets the figure is a kernel
    density plot of the response values, colored by split subset. For any other prediction type
    the figure is a bar plot of the percentage of active (response == 1) compounds in each subset.
    Figures are created with `plt.subplots` and are neither saved nor explicitly shown here.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            - dataset_key
            - split_uuid
            - split_strategy
            - splitter
            - split_valid_frac
            - split_test_frac
            - num_folds
            - smiles_col
            - response_cols
            - prediction_type

    Returns:
        None
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)
    dset_df, split_label = get_split_labeled_dataset(params)

    is_kfold = params.split_strategy == 'k_fold_cv'
    if is_kfold:
        # Fold labels depend on the data, so derive them from the labeled dataset. Rows without a
        # subset label (NaN) are dropped first; mixing NaN with strings would make sorted() fail.
        subset_order = sorted(dset_df['split_subset'].dropna().unique())
    else:
        subset_order = ['train', 'valid', 'test']

    for col in params.response_cols:
        if params.prediction_type == 'regression':
            fig, ax = plt.subplots(figsize=(9, 7))
            ax = sns.kdeplot(data=dset_df, x=col, hue='split_subset', hue_order=subset_order,
                             bw_adjust=0.7, fill=True, common_norm=False, ax=ax)
            ax.set_title(f"{col} distribution by subset under {split_label}")
        else:
            pct_active = []
            for subset in subset_order:
                # Only rows with a measured response count toward the percentage; missing
                # values would otherwise turn the whole sum into NaN.
                measured = dset_df.loc[dset_df['split_subset'] == subset, col].dropna()
                if len(measured) > 0:
                    pct_active.append(100 * measured.sum() / len(measured))
                else:
                    # Empty subset: report NaN rather than dividing by zero.
                    pct_active.append(np.nan)
            active_df = pd.DataFrame(dict(subset=subset_order, percent_active=pct_active))

            # k-fold splits have more bars, so they get the wider figure.
            fig, ax = plt.subplots(figsize=(9, 7) if is_kfold else (5, 5))
            # Draw on the axes just created instead of relying on pyplot's current-axes state.
            ax = sns.barplot(data=active_df, x='subset', y='percent_active', hue='subset', ax=ax)
            ax.set_title(f"Percent of {col} = 1 by subset under {split_label}")
            ax.set_xlabel('')
# ---------------------------------------------------------------------------------------------------------------------------------
# [docs]
def get_split_labeled_dataset(params):
    """Add a column to a dataset labeling the split subset for each row.

    Given a dataset and split parameters (including split_uuid) referenced in `params`, returns a
    data frame containing the dataset with an extra 'split_subset' column indicating the subset
    each data point belongs to. For standard 3-way splits, the labels are taken from the split
    file's `subset` column ('train', 'valid' and 'test'). For a k-fold CV split, the labels will
    be 'fold_0' through 'fold_<k-1>' and 'test'.

    The split file is expected next to the dataset file, named
    `<dataset basename>_<num_folds>_fold_cv_<splitter>_<split_uuid>.csv` for k-fold CV splits and
    `<dataset basename>_<split_strategy>_<splitter>_<split_uuid>.csv` otherwise. The merge also
    brings the split file's other columns (e.g. `subset`, `fold`) into the returned data frame.

    Args:
        params (argparse.Namespace or dict): Structure containing dataset and split parameters.
            The following parameters are required, if not set to default values:

            - dataset_key
            - split_uuid
            - split_strategy
            - splitter
            - split_valid_frac
            - split_test_frac
            - num_folds
            - smiles_col
            - response_cols

    Returns:
        A tuple (dset_df, split_label):

        - dset_df (DataFrame): The dataset specified by `params.dataset_key`, with additional
          column `split_subset`.
        - split_label (str): A short description of the split, useful for plot labeling.
    """
    if isinstance(params, dict):
        params = parse.wrapper(params)

    is_kfold = params.split_strategy == 'k_fold_cv'
    dset_df = pd.read_csv(params.dataset_key, dtype={'compound_id': str})

    dataset_base = os.path.splitext(params.dataset_key)[0]
    if is_kfold:
        split_file = f"{dataset_base}_{params.num_folds}_fold_cv_{params.splitter}_{params.split_uuid}.csv"
    else:
        split_file = f"{dataset_base}_{params.split_strategy}_{params.splitter}_{params.split_uuid}.csv"

    # Read the split file's ID column as strings, like the dataset's compound_id column above.
    # Otherwise numeric-looking IDs are parsed as integers and the merge fails on mismatched
    # key dtypes.
    split_df = pd.read_csv(split_file, dtype={'cmpd_id': str, 'compound_id': str})
    split_df = split_df.rename(columns={'cmpd_id': 'compound_id'})

    # Left merge keeps every dataset row, in its original order.
    dset_df = dset_df.merge(split_df, how='left', on='compound_id')

    if is_kfold:
        dset_df['split_subset'] = [f"fold_{fold}" for fold in dset_df['fold'].values]
        # Held-out test rows also carry a fold value; their label is overridden here.
        dset_df.loc[dset_df['subset'] == 'test', 'split_subset'] = 'test'
        nfolds = max(dset_df['fold'].values) + 1
        split_label = f"{nfolds}-fold {params.splitter} cross-validation split"
    else:
        dset_df['split_subset'] = dset_df['subset'].values
        split_label = f"{params.splitter} split"

    return dset_df, split_label