"""Module to interface model pipeline to model tracker service."""
import os
import tempfile
import sys
import pandas as pd
import json
import tarfile
import logging
logger = logging.getLogger('ATOM')
from atomsci.ddm.utils import datastore_functions as dsf
from atomsci.ddm.pipeline import parameter_parser as parse
from atomsci.ddm.pipeline import transformations as trans
import atomsci.ddm.utils.file_utils as futils
mlmt_supported = True
try:
from atomsci.clients import MLMTClient
except (ModuleNotFoundError, ImportError):
logger.debug("Model tracker client not supported in your environment; will save models in filesystem only.")
mlmt_supported = False
class DatastoreInsertionException(Exception):
    """Raised when a model file cannot be uploaded to the datastore."""
    pass
class MLMTClientInstantiationException(Exception):
    """Raised when the model tracker client cannot be instantiated."""
    pass
# *********************************************************************************************************************************
def save_model(pipeline, collection_name='model_tracker', log=True):
"""Save the model.
Save the model files to the datastore and save the model metadata dict to the Mongo database.
Args:
pipeline (ModelPipeline object): the pipeline to use
collection_name (str): the name of the Mongo DB collection to use
log (bool): True if logs should be printed, default False
use_personal_client (bool): True if personal client should be used (i.e. for testing), default False
Returns:
None if insertion was successful, raises DatastoreInsertionException, MLMTClientInstantiationException
or MongoInsertionException otherwise
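
    Example:
        A minimal usage sketch; assumes `pipe` is a trained ModelPipeline whose
        create_model_metadata() has already been called, and that a model tracker
        client is available in the environment (the collection name is hypothetical)::

            from atomsci.ddm.pipeline import model_tracker as mt

            mt.save_model(pipe, collection_name='my_model_collection')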
"""
    if pipeline is None:
        raise ValueError('pipeline cannot be None.')
if not mlmt_supported:
logger.error("Model tracker not supported in your environment; can save models in filesystem only.")
return
# ModelPipeline.create_model_metadata() should be called before the call to save_model.
# Get the metadata dictionary from the model pipeline.
metadata_dict = pipeline.model_metadata
    model_uuid = metadata_dict.get('model_uuid')
if model_uuid is None:
raise ValueError("model_uuid is missing from pipeline metadata.")
#### Part 1: Save the model tarball in the datastore ####
model = pipeline.model_wrapper
# Put tar file in a temporary directory that will automatically be destroyed when we're done
with tempfile.TemporaryDirectory() as tmp_dir:
tarball_path = os.path.join(tmp_dir, f"model_{model_uuid}.tar.gz")
save_model_tarball(pipeline.params.output_dir, tarball_path)
title = f"{model_uuid} model tarball"
ds_key = f"model_{model_uuid}_tarball"
uploaded_results = dsf.upload_file_to_DS(
bucket=pipeline.params.model_bucket, title=title, description=title, tags=[],
key_values={'model_uuid' : model_uuid, 'file_category': 'ml_model'}, filepath=tmp_dir,
filename=tarball_path, dataset_key=ds_key, client=pipeline.ds_client,
return_metadata=True)
if uploaded_results is None:
            raise DatastoreInsertionException(f'Unable to upload title={title} to datastore.')
# Get the dataset_oid for actual metadata file stored in datastore.
model_dataset_oid = uploaded_results['dataset_oid']
    # By adding dataset_oid to the dict, we can immediately find the datastore file associated with a model.
metadata_dict['model_parameters']['model_dataset_oid'] = model_dataset_oid
#### Part 2: Save the model metadata in the model tracker ####
mlmt_client = dsf.initialize_model_tracker()
mlmt_client.save_metadata(collection_name=collection_name,
model_uuid=metadata_dict['model_uuid'],
model_metadata=metadata_dict)
if log:
logger.info('Successfully inserted into the database with model_uuid %s.' % model_uuid)
# *********************************************************************************************************************************
def get_model_collection_by_uuid(model_uuid, mlmt_client=None):
"""Retrieve model collection given a uuid.
Args:
model_uuid (str): model uuid
mlmt_client: Ignored
Returns:
Matching collection name
Raises:
ValueError if there is no collection containing a model with the given uuid.
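
    Example:
        A minimal sketch; the uuid value is a hypothetical placeholder::

            from atomsci.ddm.pipeline import model_tracker as mt

            collection = mt.get_model_collection_by_uuid('some-model-uuid')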
"""
if not mlmt_supported:
logger.error("Model tracker not supported in your environment; can load models from filesystem only.")
return None
    if mlmt_client is None:
        mlmt_client = dsf.initialize_model_tracker()
collections = mlmt_client.collections.get_collection_names().result()
for col in collections:
if not col.startswith('old_'):
if mlmt_client.count_models(collection_name=col, model_uuid=model_uuid) > 0:
return col
raise ValueError('Collection not found for uuid: ' + model_uuid)
# *********************************************************************************************************************************
def get_model_training_data_by_uuid(uuid):
"""Retrieve data used to train, validate, and test a model given the uuid
Args:
uuid (str): model uuid
Returns:
a tuple of datafraes containint training data, validation data, and test data including the compound ID, RDKIT SMILES, and response value
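
    Example:
        A minimal sketch; the uuid is a hypothetical placeholder::

            from atomsci.ddm.pipeline import model_tracker as mt

            train_df, valid_df, test_df = mt.get_model_training_data_by_uuid('some-model-uuid')
            print(train_df.shape, valid_df.shape, test_df.shape)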
"""
if not mlmt_supported:
logger.error("Model tracker not supported in your environment; can load models from filesystem only.")
return None
model_meta = get_metadata_by_uuid(uuid)
    response_cols = model_meta['training_dataset']['response_cols']
smiles_col = model_meta['training_dataset']['smiles_col']
id_col = model_meta['training_dataset']['id_col']
full_data = dsf.retrieve_dataset_by_dataset_oid(model_meta['training_dataset']['dataset_oid'])
# Pull split data and merge into initial dataset
split_meta = dsf.search_datasets_by_key_value('split_dataset_uuid', model_meta['splitting_parameters']['split_uuid'])
split_oid = split_meta['dataset_oid'].values[0]
split_data = dsf.retrieve_dataset_by_dataset_oid(split_oid)
    split_data = split_data.rename(columns={'cmpd_id': 'compound_id'})
full_data = pd.merge(full_data, split_data, how='inner',
left_on=[id_col], right_on=['compound_id'])
    train_data = full_data[full_data['subset'] == 'train'][['compound_id', smiles_col, id_col, *response_cols]].reset_index(drop=True)
    valid_data = full_data[full_data['subset'] == 'valid'][['compound_id', smiles_col, id_col, *response_cols]].reset_index(drop=True)
    test_data = full_data[full_data['subset'] == 'test'][['compound_id', smiles_col, id_col, *response_cols]].reset_index(drop=True)
return train_data, valid_data, test_data
# *********************************************************************************************************************************
def save_model_tarball(output_dir, model_tarball_path):
"""Save the model parameters, metadata and transformers as a portable gzipped tar archive.
Args:
output_dir (str): Output directory from model training
model_tarball_path (str): Path of tarball file to be created
Returns:
None
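
    Example:
        A minimal sketch; the paths are hypothetical::

            from atomsci.ddm.pipeline import model_tracker as mt

            mt.save_model_tarball('/path/to/training/output_dir', '/tmp/my_model.tar.gz')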
"""
with tarfile.open(model_tarball_path, mode='w:gz') as tarball:
for filename in ['best_model', 'model_metadata.json', 'model_metrics.json']:
tarball.add(f"{output_dir}/{filename}", arcname=f"./{filename}")
if os.path.exists(f"{output_dir}/transformers.pkl"):
tarball.add(f"{output_dir}/transformers.pkl", arcname='./transformers.pkl')
logger.info(f"Wrote model tarball to {model_tarball_path}")
# *********************************************************************************************************************************
def export_model(model_uuid, collection, model_dir, alt_bucket='CRADA'):
"""Export the metadata (parameters) and other files needed to recreate a model
from the model tracker database to a gzipped tar archive.
Args:
model_uuid (str): Model unique identifier
collection (str): Name of the collection holding the model in the database.
model_dir (str): Path to directory where the model metadata and parameter files will be written. The directory will
be created if it doesn't already exist. Subsequently, the directory contents will be packed into a gzipped tar archive
named model_dir.tar.gz.
alt_bucket (str): Alternate datastore bucket to search for model tarball and transformer objects.
Returns:
none
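
    Example:
        A minimal sketch; the uuid, collection name, and directory are hypothetical::

            from atomsci.ddm.pipeline import model_tracker as mt

            mt.export_model('some-model-uuid', 'my_model_collection', '/tmp/exported_model',
                            alt_bucket='CRADA')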
"""
if not mlmt_supported:
logger.info("Model tracker not supported in your environment; can load models from filesystem only.")
return
ds_client = dsf.config_client()
metadata_dict = get_metadata_by_uuid(model_uuid, collection_name=collection)
output_dir = model_dir
model_dir = f"{output_dir}/best_model"
# Convert metadata if it's in the old camelcase format (shouldn't exist anymore)
if 'ModelMetadata' in metadata_dict:
# Convert old style metadata
metadata_dict = convert_metadata(metadata_dict)
if 'model_parameters' in metadata_dict:
model_parameters = metadata_dict['model_parameters']
else:
raise Exception("Bad metadata for model UUID %s" % model_uuid)
model_params = parse.wrapper(metadata_dict)
# Override selected model training parameters
# Check that buckets where model tarball and transformers were saved still exist. If not, try alt_bucket.
trans_bucket_differs = (model_params.transformer_bucket != model_params.model_bucket)
model_bucket_meta = ds_client.ds_buckets.get_buckets(buckets=[model_params.model_bucket]).result()
if len(model_bucket_meta) == 0:
model_params.model_bucket = alt_bucket
if trans_bucket_differs:
trans_bucket_meta = ds_client.ds_buckets.get_buckets(buckets=[model_params.transformer_bucket]).result()
if len(trans_bucket_meta) == 0:
model_params.transformer_bucket = alt_bucket
    elif len(model_bucket_meta) == 0:
        model_params.transformer_bucket = alt_bucket
# Get the tarball containing the saved model from the datastore, and extract it into output_dir or model_dir.
extract_dir = extract_datastore_model_tarball(model_uuid, model_params.model_bucket, output_dir, model_dir)
if extract_dir == model_dir:
# Download the transformers pickle file if there is one
if trans.transformers_needed(model_params):
try:
if model_params.transformer_key is None:
transformer_key = 'transformers_%s.pkl' % model_uuid
else:
transformer_key = model_params.transformer_key
                trans_fp = ds_client.open_bucket_dataset(model_params.transformer_bucket, transformer_key, mode='b')
                trans_data = trans_fp.read()
                trans_fp.close()
                # Write the transformers into the export directory and point the saved parameters at the local copy
                trans_path = f"{output_dir}/transformers.pkl"
                with open(trans_path, mode='wb') as trans_out:
                    trans_out.write(trans_data)
                del model_parameters['transformer_oid']
                model_parameters['transformer_key'] = 'transformers.pkl'
            except Exception:
                logger.error("Transformers expected but not found in datastore in bucket %s with key\n%s"
                             % (model_params.transformer_bucket, transformer_key))
                raise
# Save the metadata params
model_parameters['save_results'] = False
meta_path = f"{output_dir}/model_metadata.json"
with open(meta_path, 'w') as meta_out:
json.dump(metadata_dict, meta_out, indent=4)
# Save the metrics to model_metrics.json
if 'training_metrics' in metadata_dict:
model_metrics = metadata_dict['training_metrics']
metrics_path = f"{output_dir}/model_metrics.json"
with open(metrics_path, 'w') as metrics_out:
json.dump(model_metrics, metrics_out, indent=4)
else:
logger.info(f"No metrics saved for model {model_uuid}")
# Create a new tarball containing both the metadata and the parameters from the retrieved model tarball
new_tarpath = "%s.tar.gz" % output_dir
tarball = tarfile.open(new_tarpath, mode='w:gz')
tarball.add(output_dir, arcname='.')
tarball.close()
logger.info("Wrote model files to %s" % new_tarpath)
# *********************************************************************************************************************************