Source code for utils.model_retrain

#!/usr/bin/env python

# Purpose:
#
#  Script that takes an existing model_metadata.json file, or a directory to scan for
#  model_metadata.json files, retrains the corresponding models, and saves them as DC 2.3 models.
#
# usage: model_retrain.py [-h] -i INPUT [-o OUTPUT]
#
# optional arguments:
#   -h, --help                  show this help message and exit
#   -i INPUT, --input INPUT     input directory/file
#   -o OUTPUT, --output OUTPUT  output result directory
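#
# Example invocations (a sketch; the paths below are hypothetical placeholders):
#   python model_retrain.py -i /path/to/model_metadata.json -o /path/to/retrain_output
#   python model_retrain.py -i /path/to/saved_models_dir -o /path/to/retrain_output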

import argparse
from datetime import timedelta
from datetime import datetime
import glob
import json
from pathlib import Path
import os
import sys
import tempfile
import time

import tarfile
import logging
import pandas as pd

logging.basicConfig()

logger = logging.getLogger(__name__)
#logger.setLevel(logging.DEBUG)

import atomsci.ddm.pipeline.model_pipeline as mp
import atomsci.ddm.pipeline.parameter_parser as parse
import atomsci.ddm.pipeline.model_tracker as mt
import atomsci.ddm.utils.datastore_functions as dsf
from atomsci.ddm.pipeline import compare_models as cmp
import atomsci.ddm.utils.file_utils as futils

# Raise the open file descriptor limit (RLIMIT_NOFILE) to 65536
import resource
resource.setrlimit(resource.RLIMIT_NOFILE, (65536, 65536))

mlmt_supported = True
try:
    from atomsci.clients import MLMTClient
except (ModuleNotFoundError, ImportError):
    logger.warning("Model tracker client not supported in your environment; will save models in filesystem only.")
    mlmt_supported = False


def train_model(input, output, dskey='', production=False):
    """Retrain a model saved in a model_metadata.json file

    Args:
        input (str): path to model_metadata.json file

        output (str): path to output directory

        dskey (str): new dataset key if file location has changed

        production (bool): retrain the model using production mode

    Returns:
        the ModelPipeline object containing the retrained model
    """
    # Train model
    # -----------
    # Read parameter JSON file
    with open(input) as f:
        config = json.loads(f.read())

    # set a new dataset key if necessary
    if dskey != '':
        config['dataset_key'] = dskey

    # Parse parameters
    params = parse.wrapper(config)

    params.result_dir = output
    # otherwise this will have the same uuid as the source model
    params.model_uuid = None
    # use the same split
    params.previously_split = True
    params.split_uuid = config['splitting_parameters']['split_uuid']
    # use production mode to train
    params.production = production
    if params.production and 'nn_specific' in config:
        params.max_epochs = config['nn_specific']['best_epoch'] + 1

    # change save mode if retraining outside the model tracker environment
    if not mlmt_supported:
        params.save_results = False

    logger.debug("model params %s" % str(params))
    logger.debug(params.__dict__.items())

    # Create model pipeline
    model = mp.ModelPipeline(params)

    # Train model
    model.train_model()

    return model
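# Example usage sketch for train_model (the paths are hypothetical placeholders):
#   pipeline = train_model('/path/to/model_metadata.json', '/tmp/retrain_output',
#                          dskey='/new/path/to/dataset.csv', production=False)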
def train_model_from_tar(input, output, dskey='', production=False):
    """Retrain a model saved in a tar.gz file

    Args:
        input (str): path to a tar.gz file

        output (str): path to output directory

        dskey (str): new dataset key if file location has changed

        production (bool): retrain the model using production mode

    Returns:
        the ModelPipeline object containing the retrained model
    """
    tmpdir = tempfile.mkdtemp()

    with tarfile.open(input, mode='r:gz') as tar:
        futils.safe_extract(tar, path=tmpdir)

    # make metadata path
    metadata_path = os.path.join(tmpdir, 'model_metadata.json')

    return train_model(metadata_path, output, dskey=dskey, production=production)
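# Example usage sketch for train_model_from_tar (the path is a hypothetical placeholder):
#   pipeline = train_model_from_tar('/path/to/saved_model.tar.gz', '/tmp/retrain_output')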
def train_model_from_tracker(model_uuid, output_dir, production=False):
    """Retrain a model saved in the model tracker, but save it to output_dir and don't insert it into the model tracker

    Args:
        model_uuid (str): model tracker model_uuid

        output_dir (str): path to output directory

        production (bool): retrain the model using production mode

    Returns:
        the ModelPipeline object containing the retrained model
    """
    if not mlmt_supported:
        logger.debug("Model tracker not supported in your environment; can load models from filesystem only.")
        return None

    mlmt_client = dsf.initialize_model_tracker()

    collection_name = mt.get_model_collection_by_uuid(model_uuid, mlmt_client=mlmt_client)

    # get metadata from tracker
    config = mt.get_metadata_by_uuid(model_uuid)

    # check if datastore dataset
    try:
        result = dsf.retrieve_dataset_by_datasetkey(config['training_dataset']['dataset_key'],
                                                    bucket=config['training_dataset']['bucket'])
        if result is not None:
            config['datastore'] = True
    except Exception:
        pass

    # Parse parameters
    params = parse.wrapper(config)

    params.result_dir = output_dir
    # otherwise this will have the same uuid as the source model
    params.model_uuid = None
    # use the same split
    params.previously_split = True
    params.split_uuid = config['splitting_parameters']['split_uuid']
    # specify collection
    params.collection_name = collection_name
    # use production mode to train
    params.production = production
    if params.production and 'nn_specific' in config:
        params.max_epochs = config['nn_specific']['best_epoch'] + 1

    logger.debug("model params %s" % str(params))

    # Create model pipeline
    model = mp.ModelPipeline(params)

    # Train model
    model.train_model()

    return model
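# Example usage sketch for train_model_from_tracker (the uuid and path are hypothetical
# placeholders; requires the model tracker client to be available in your environment):
#   pipeline = train_model_from_tracker('xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx', '/tmp/retrain_output')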
def train_models_from_dataset_keys(input, output, pred_type='regression', production=False):
    """Retrain a list of models from an input file

    Args:
        input (str): path to an Excel or csv file. The required columns are 'dataset_key'
            and 'bucket' (public, private_file or Filesystem).

        output (str): path to output directory

        pred_type (str, optional): the model prediction type. Defaults to 'regression'.

        production (bool): retrain the models using production mode

    Returns:
        None
    """
    df = pd.DataFrame()

    # parse the input file
    logger.debug("Parsing %s file." % input)
    try:
        df = pd.read_excel(input)
    except Exception:
        try:
            df = pd.read_csv(input)
        except Exception:
            raise Exception('Unable to parse input %s. Only Excel or csv files are accepted.' % input)

    # extract the public bucket, then dataset keys
    public_list = df.loc[df['bucket'] == 'public']
    dataset_keys = public_list['dataset_key'].tolist()
    logger.debug('Found %d public dataset keys' % len(dataset_keys))

    client = MLMTClient()
    collections = client.get_collection_names()
    bucket = 'public'

    # find the collections that contain models trained on these dataset keys
    colls_w_dset = []
    for dset in dataset_keys:
        for coll in collections:
            datasets = cmp.get_collection_datasets(coll)
            if (dset, bucket) in datasets:
                colls_w_dset.append(coll)
    logger.debug('Found the dataset_keys in %d collections' % len(colls_w_dset))

    logger.debug("Train the model using prediction type %s." % pred_type)
    metric_type = 'r2_score'
    if pred_type == 'classification':
        metric_type = 'roc_auc_score'

    try:
        # find the best models
        best_mods = cmp.get_best_models_info(col_names=colls_w_dset, bucket=bucket, pred_type=pred_type,
                                             result_dir=None, PK_pipeline=False, output_dir=output,
                                             shortlist_key=None, input_dset_keys=dataset_keys,
                                             save_results=False, subset='valid', metric_type=metric_type,
                                             selection_type='max', other_filters={})
        # retrain with uuid
        for model_uuid in best_mods.model_uuid.sort_values():
            try:
                logger.debug('Training %s in %s' % (model_uuid, output))
                train_model_from_tracker(model_uuid, output, production=production)
            except Exception:
                logger.error('Error retraining model_uuid %s' % model_uuid)
    except Exception as e:
        raise Exception('Error: %s' % str(e))
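# Example usage sketch for train_models_from_dataset_keys (the file name is a hypothetical
# placeholder; the csv must contain 'dataset_key' and 'bucket' columns):
#   train_models_from_dataset_keys('/path/to/dataset_keys.csv', '/tmp/retrain_output',
#                                  pred_type='classification')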
#----------------
# main
#----------------
def main(argv):
    start_time = time.time()

    # input file/dir (required)
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True, help='input directory, file or model_uuid')
    parser.add_argument('-o', '--output', help='output result directory')
    parser.add_argument('-dk', '--dataset_key', default='',
                        help='Sometimes dataset keys get moved. Specify the new location of the dataset. '
                             'Only works when passing in one model at a time.')
    parser.add_argument('-pd_type', '--pred_type', default='regression',
                        help='Specify the prediction type used for model retraining. The default is regression.')
    parser.add_argument('-prod', '--production', action='store_true', default=False,
                        help='Retrain the model in production mode')

    args = parser.parse_args()

    input = args.input
    output = args.output

    # if not specified, default to a temp dir
    if not (output and output.strip()):
        output = tempfile.mkdtemp()

    # 1 check if it's a directory
    if os.path.isdir(input):
        # loop over all model_metadata.json files under the directory
        for path in Path(input).rglob('model_metadata.json'):
            train_model(path.absolute(), output, production=args.production)
    elif os.path.isfile(input):
        # 2 if it's a file, check if it's a json, a tar.gz, or a file that contains a list of dataset keys
        if input.endswith('.json'):
            train_model(input, output, dskey=args.dataset_key, production=args.production)
        elif input.endswith('.tar.gz'):
            train_model_from_tar(input, output, dskey=args.dataset_key, production=args.production)
        else:
            train_models_from_dataset_keys(input, output, pred_type=args.pred_type, production=args.production)
    else:
        try:
            # 3 try to process 'input' as a model uuid
            train_model_from_tracker(input, output, production=args.production)
        except Exception:
            raise Exception('Unrecognized input %s' % input)

    elapsed_time_secs = time.time() - start_time
    logger.info("Execution took: %s secs" % timedelta(seconds=round(elapsed_time_secs)))
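# Example of calling main directly (a sketch; the arguments are hypothetical placeholders):
#   main(['-i', '/path/to/saved_models_dir', '-o', '/tmp/retrain_output', '-prod'])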
if __name__ == "__main__":
    main(sys.argv[1:])