"""model_version_utils.py
Misc utilities to get the AMPL version(s) used to train one or more models and check them
for compatibility with the currently running version of AMPL.

To check the version(s) of a model tarball, or of every model tarball under a directory:

usage: model_version_utils.py [-h] -i INPUT

optional arguments:
  -h, --help            show this help message and exit
  -i INPUT, --input INPUT
                        input model directory/file (required)
"""
import argparse
import traceback
import tarfile
import json
import os
from pathlib import Path
import sys
import tarfile
import logging
import pandas as pd
import pdb
import re
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

try:
    from importlib import metadata
except ImportError:
    import importlib_metadata as metadata  # python<=3.7

# AMPL versions (major.minor) that are compatible with each other map to the same group;
# check_version_compatible falls back to the version itself for versions not listed here.
comp_dict = {'1.2': 'group1', '1.3': 'group1', '1.4': 'group2', '1.5': 'group3', '1.6': 'group3'}

# A dotted numeric version such as '2', '1.5' or '1.5.0': one or more digit groups separated
# by single dots, so strings like '...', '1..5' or '1.5.' are rejected.
version_pattern = re.compile(r"\d+(?:\.\d+)*")
def get_ampl_version():
    """Look up the version of the AMPL package installed in this environment.

    Returns:
        str: the version string recorded in the ``atomsci-ampl`` distribution metadata.
    """
    dist_name = "atomsci-ampl"
    installed_version = metadata.version(dist_name)
    return installed_version
def get_ampl_version_from_dir(dirname):
    """Get the AMPL versions for all the model tarballs stored under a directory, recursively.

    Every ``*.tar.gz`` file under ``dirname`` (including subdirectories) is inspected with
    get_ampl_version_from_model. Tarballs that cannot be read (corrupt archive, invalid
    metadata JSON, file vanished during the scan) are logged and skipped, so one bad file
    does not abort the whole scan. A tarball without AMPL metadata is listed with version None.

    Args:
        dirname (str): directory to search.

    Returns:
        str: one ``'<absolute path>, <version>'`` line per tarball read, joined with
        newlines and ordered by path; an empty string if no tarball was read.
    """
    versions = []
    # sort so the report order does not depend on filesystem enumeration order
    for path in sorted(Path(dirname).rglob('*.tar.gz')):
        abs_path = path.absolute()
        try:
            version = get_ampl_version_from_model(abs_path)
        except (json.decoder.JSONDecodeError, FileNotFoundError, tarfile.TarError) as e:
            # best effort: report the unreadable file and keep scanning the rest
            logger.exception("Exception message: {}".format(e))
            continue
        versions.append('{}, {}'.format(abs_path, version))
    return '\n'.join(versions)
def get_ampl_version_from_model(filename):
    """Get the AMPL version recorded in a model tarball's model_metadata.json.

    Args:
        filename (str or Path): path to a gzip-compressed model tarball.

    Returns:
        str or None: the ``ampl_version`` stored under ``model_parameters`` in the
        metadata, ``'probably 1.0.0'`` if the metadata does not record one, or None
        if the tarball contains no model_metadata.json.

    Raises:
        tarfile.TarError: if the file cannot be opened as a gzip-compressed tarball.
        json.JSONDecodeError: if model_metadata.json is not valid JSON.
    """
    with tarfile.open(filename, mode='r:gz') as tar:
        meta_info = None
        # the member may be stored with or without a leading './' depending on how the
        # archive was built; try the historical './' form first
        for member_name in ('./model_metadata.json', 'model_metadata.json'):
            try:
                meta_info = tar.getmember(member_name)
                break
            except KeyError:
                continue
        if meta_info is None:
            logger.warning("%s is not an AMPL model tarball", filename)
            return None
        with tar.extractfile(meta_info) as meta_fd:
            metadata_dict = json.loads(meta_fd.read())
    # tolerate metadata whose model_parameters section is missing or null
    model_params = metadata_dict.get("model_parameters") or {}
    version = model_params.get("ampl_version", 'probably 1.0.0')
    logger.info('{}, {}'.format(filename, version))
    return version
def get_major_version(full_version):
    """Reduce a version string to its first two dot-separated components.

    Args:
        full_version (str): a version such as ``'1.5.2'``.

    Returns:
        str: the leading ``major.minor`` part (e.g. ``'1.5'``); a string with fewer
        than two components is returned unchanged.
    """
    leading_parts = full_version.split('.')[:2]
    major_minor = '.'.join(leading_parts)
    return major_minor
def get_ampl_version_from_json(metadata_path):
    """Get the AMPL version recorded in an already-extracted model_metadata.json file.

    Args:
        metadata_path (str or Path): path to a model_metadata.json file.

    Returns:
        str: the ``ampl_version`` stored under ``model_parameters``, or
        ``'probably 1.0.0'`` if the metadata does not record one.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's default locale encoding
    with open(metadata_path, 'r', encoding='utf-8') as data_file:
        metadata_dict = json.load(data_file)
    # tolerate metadata whose model_parameters section is missing or null
    model_params = metadata_dict.get("model_parameters") or {}
    return model_params.get("ampl_version", 'probably 1.0.0')
def validate_version(input):
    """Check that a string fully matches the module's version_pattern.

    Args:
        input (str): candidate version string, e.g. ``'1.5.0'``.

    Returns:
        bool: True whenever the check passes.

    Raises:
        ValueError: if ``input`` does not fully match ``version_pattern``.
    """
    if version_pattern.fullmatch(input) is not None:
        return True
    raise ValueError("Input {} is not valid version format.".format(input))
def check_version_compatible(input, ignore_check=False):
    """Compare a model's AMPL version against the running AMPL version.

    Only the ``major.minor`` parts are compared. Two versions are compatible when they map
    to the same group in ``comp_dict``; a version missing from ``comp_dict`` stands for
    itself, so it only matches the identical ``major.minor``.

    Args:
        input (str): path to a model tarball, or a version string such as ``'1.5.0'``.
        ignore_check (bool): if True, return False instead of raising when the versions are
            incompatible or the tarball carries no AMPL metadata.

    Returns:
        bool: True if the versions are compatible; False only when ``ignore_check`` is True.

    Raises:
        ValueError: if ``input`` is not an existing file and is not a valid version string;
            or, unless ``ignore_check`` is True, if the tarball has no AMPL metadata or the
            versions are incompatible.
    """
    if os.path.isfile(input):
        # a model tarball: read the version recorded in its model_metadata.json
        full_version = get_ampl_version_from_model(input)
        if full_version is None:
            # no AMPL metadata in the tarball, so there is no model version to compare
            msg = 'Version compatible check: {} is not an AMPL model tarball'.format(input)
            if ignore_check:
                logger.warning(msg)
                return False
            raise ValueError(msg)
        model_ampl_version = get_major_version(full_version.strip())
    else:
        # not a file: treat the input as a version string such as '1.5.0'
        validate_version(input)
        model_ampl_version = get_major_version(input)
    # only the major.minor releases are compared
    ampl_version = get_major_version(get_ampl_version())
    logger.info('Version compatible check: {} version = "{}", AMPL version = "{}"'.format(input, model_ampl_version, ampl_version))
    match = (comp_dict.get(ampl_version, ampl_version) == comp_dict.get(model_ampl_version, model_ampl_version))
    # raise an exception if the versions don't match and the caller didn't ask to ignore it
    if not match and not ignore_check:
        raise ValueError('Version compatible check: {} version: "{}" not matching AMPL compatible version group: "{}"'.format(input, model_ampl_version, ampl_version))
    return match
#----------------
# main
#----------------
def main(argv=None):
    """Command-line entry point: report the AMPL version of one model tarball, or of every
    model tarball under a directory.

    The versions are reported through the ``logger.info`` messages emitted while each
    tarball is read.

    Args:
        argv (list of str, optional): command-line arguments without the program name.
            When None, argparse falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # input file/dir (required)
    parser.add_argument('-i', '--input', required=True, help='input model directory/file')
    # parse the arguments we were given rather than re-reading sys.argv
    args = parser.parse_args(argv)
    finput = args.input
    if os.path.isdir(finput):
        get_ampl_version_from_dir(finput)
    elif os.path.isfile(finput):
        get_ampl_version_from_model(finput)
    else:
        # a path that does not exist is a usage error (exits with status 2)
        parser.error('input {} is not an existing file or directory'.format(finput))


if __name__ == "__main__":
    main(sys.argv[1:])