Source code for utils.model_version_utils

"""model_version_utils.py

Misc utilities to get the AMPL version(s) used to train one or more models and check them
for compatibility with the currently running version of AMPL.:

To check the model version

 usage: model_version_utils.py [-h] -i INPUT

 optional arguments:
   -h, --help            show this help message and exit

  -i INPUT, --input INPUT     input directory/file (required)

"""

import argparse
import tarfile
import json
import os
from pathlib import Path
import sys
import logging
import re

logging.basicConfig()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

try:
    from importlib import metadata
except ImportError:
    import importlib_metadata as metadata # python<=3.7


# ampl versions compatible groups
comp_dict = { '1.2': 'group1', '1.3': 'group1', '1.4': 'group2', '1.5': 'group3', '1.6': 'group3', '1.7': 'group3', '1.8': 'group3' }
version_pattern = re.compile(r"[\d.]+")

[docs] def get_ampl_version(): """Get the running ampl version Returns: the AMPL version """ return metadata.version("atomsci-ampl")
[docs] def get_ampl_version_from_dir(dirname): """Get the AMPL versions for all the models stored under the given directory and its subdirectories, recursively. Args: dirname (str): directory Returns: list of AMPL versions """ versions = [] # loop for path in Path(dirname).rglob('*.tar.gz'): try: version = get_ampl_version_from_model(path.absolute()) versions.append('{}, {}'.format(path.absolute(), version)) except (json.decoder.JSONDecodeError, FileNotFoundError) as e: logger.exception("Exception message: {}".format(e)) pass return '\n'.join(versions)
[docs] def get_ampl_version_from_model(filename): """Get the AMPL version from the tar file's model_metadata.json Args: filename (str): tar file Returns: the AMPL version number """ with tarfile.open(filename, mode='r:gz') as tar: try: meta_info = tar.getmember('./model_metadata.json') except KeyError: print(f"{filename} is not an AMPL model tarball") return None with tar.extractfile(meta_info) as meta_fd: metadata_dict = json.loads(meta_fd.read()) version = metadata_dict.get("model_parameters").get("ampl_version", 'probably 1.0.0') logger.info('{}, {}'.format(filename, version)) return version
[docs] def get_major_version(full_version): return '.'.join(full_version.split('.')[:2])
[docs] def get_ampl_version_from_json(metadata_path): """Parse model_metadata.json to get the AMPL version Args: filename (str): tar file Returns: the AMPL version number """ with open(metadata_path, 'r') as data_file: metadata_dict = json.load(data_file) version = metadata_dict.get("model_parameters").get("ampl_version", 'probably 1.0.0') return version
[docs] def validate_version(input): valid = re.fullmatch(version_pattern, input) if valid is None: raise ValueError("Input {} is not valid version format.".format(input)) return True
[docs] def check_version_compatible(input, ignore_check=False): """Compare the input file's version against the running AMPL version to see if they are compatible Args: filename (str): file or version number Returns: True if the input model version matches the compatible AMPL version group """ # get the versions. only compare by the major releases model_ampl_version = "" # if the input is a tar file, extract it to get the version string if (os.path.isfile(input)): model_ampl_version = get_major_version(get_ampl_version_from_model(input).strip()) else: # if the input is not a file. try to parse string like '1.5.0' validate_version(input) model_ampl_version = get_major_version(input) ampl_version = get_major_version(get_ampl_version()) logger.info('Version compatible check: {} version = "{}", AMPL version = "{}"'.format(input, model_ampl_version, ampl_version)) match = (comp_dict.get(ampl_version, ampl_version)==comp_dict.get(model_ampl_version, model_ampl_version)) # raise an exception if not match and we don't want to ignore if not match: if not ignore_check: my_error = ValueError('Version compatible check: {} version: "{}" not matching AMPL compatible version group: "{}"'.format(input, model_ampl_version, ampl_version)) raise my_error return match
#---------------- # main #----------------
[docs] def main(argv): # input file/dir (required) parser = argparse.ArgumentParser() parser.add_argument('-i', '--input', required=True, help='input model directory/file') args = parser.parse_args() finput = args.input # check if it's a directory if os.path.isdir(finput): get_ampl_version_from_dir(finput) elif os.path.isfile(finput): get_ampl_version_from_model(finput)
if __name__ == "__main__": main(sys.argv[1:])