diff --git a/poetry.lock b/poetry.lock index 24c2accc..a5f3aa5b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -263,17 +263,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.34.113" +version = "1.34.143" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.113-py3-none-any.whl", hash = "sha256:7e59f0a848be477a4c98a90e7a18a0e284adfb643f7879d2b303c5f493661b7a"}, - {file = "boto3-1.34.113.tar.gz", hash = "sha256:009cd143509f2ff4c37582c3f45d50f28c95eed68e8a5c36641206bdb597a9ea"}, + {file = "boto3-1.34.143-py3-none-any.whl", hash = "sha256:0d16832f23e6bd3ae94e35ea8e625529850bfad9baccd426de96ad8f445d8e03"}, + {file = "boto3-1.34.143.tar.gz", hash = "sha256:b590ce80c65149194def43ebf0ea1cf0533945502507837389a8d22e3ecbcf05"}, ] [package.dependencies] -botocore = ">=1.34.113,<1.35.0" +botocore = ">=1.34.143,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -282,13 +282,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.113" +version = "1.34.143" description = "Low-level, data-driven core of boto 3." 
optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.113-py3-none-any.whl", hash = "sha256:8ca87776450ef41dd25c327eb6e504294230a5756940d68bcfdedc4a7cdeca97"}, - {file = "botocore-1.34.113.tar.gz", hash = "sha256:449912ba3c4ded64f21d09d428146dd9c05337b2a112e15511bf2c4888faae79"}, + {file = "botocore-1.34.143-py3-none-any.whl", hash = "sha256:094aea179e8aaa1bc957ad49cc27d93b189dd3a1f3075d8b0ca7c445a2a88430"}, + {file = "botocore-1.34.143.tar.gz", hash = "sha256:059f032ec05733a836e04e869c5a15534420102f93116f3bc9a5b759b0651caf"}, ] [package.dependencies] @@ -297,7 +297,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.20.9)"] +crt = ["awscrt (==0.20.11)"] [[package]] name = "certifi" @@ -3938,18 +3938,18 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] [[package]] name = "sagemaker" -version = "2.221.1" +version = "2.225.0" description = "Open source library for training and deploying models on Amazon SageMaker." 
optional = false python-versions = ">=3.8" files = [ - {file = "sagemaker-2.221.1-py3-none-any.whl", hash = "sha256:07e81e61c55fd6a89a9c54037ee1d30bff3f4e9698e9ff6b5253b831a5d18a21"}, - {file = "sagemaker-2.221.1.tar.gz", hash = "sha256:b4afe9ef86f33f820cffd4e1349e6dd0a3e7d4a3d6bf4aa9f96e59f450860053"}, + {file = "sagemaker-2.225.0-py3-none-any.whl", hash = "sha256:c91e2d12bbaecca9412f4beebd02d27793ac227d283598595528ca86bbb28dfd"}, + {file = "sagemaker-2.225.0.tar.gz", hash = "sha256:04f52550d484c9800f16d46c6ad127f172f953ae934d4da17ab7ebf6432997f0"}, ] [package.dependencies] attrs = ">=23.1.0,<24" -boto3 = ">=1.33.3,<2.0" +boto3 = ">=1.34.142,<2.0" cloudpickle = "2.2.1" docker = "*" google-pasta = "*" @@ -3971,12 +3971,12 @@ tqdm = "*" urllib3 = ">=1.26.8,<3.0.0" [package.extras] -all = ["PyYAML (>=5.4.1,<7)", "accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<7.0.0)", "pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3.3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)"] +all = ["PyYAML (>=5.4.1,<7)", "accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3.3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)"] feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3.3"] huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "sagemaker-schema-inference-artifacts (>=0.0.5)"] -local = ["PyYAML (>=5.4.1,<7)", "docker (>=5.0.2,<7.0.0)", "urllib3 (>=1.26.8,<3.0.0)"] +local = ["PyYAML (>=5.4.1,<7)", "docker (>=5.0.2,<8.0.0)", "urllib3 (>=1.26.8,<3.0.0)"] scipy = ["scipy (==1.10.1)"] -test = ["Jinja2 (==3.1.4)", "PyYAML (==6.0)", "PyYAML (>=5.4.1,<7)", "accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.0)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", 
"docker (>=5.0.2,<7.0.0)", "fabric (==2.6.0)", "flake8 (==4.0.1)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "requests (==2.31.0)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3.3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "xgboost (>=1.6.2,<=1.7.6)"] +test = ["Jinja2 (==3.1.4)", "PyYAML (==6.0)", "PyYAML (>=5.4.1,<7)", "accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.2)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3.3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "xgboost (>=1.6.2,<=1.7.6)"] [[package]] name = "schema" @@ -4887,4 +4887,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [metadata] lock-version = "2.0" python-versions = 
"^3.10" -content-hash = "7029988320c2cc687014c7fd2710b9a617dd06d689b32fcb4649ae64791b3709" +content-hash = "64ffcc42f92a252dbbe84f437143ca54e300d043f91c81e9376fbb692ee46a10" diff --git a/pyproject.toml b/pyproject.toml index 33ff67a3..023062f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ bert-score = "^0.3.13" scikit-learn = "^1.3.1" jiwer = "^3.0.3" transformers = "^4.36.0" -sagemaker = "^2.219.0" +sagemaker = "^2.225.0" testbook = "^0.4.2" ipykernel = "^6.26.0" mypy-boto3-bedrock = "^1.33.2" diff --git a/src/fmeval/constants.py b/src/fmeval/constants.py index 32550a9a..f31422be 100644 --- a/src/fmeval/constants.py +++ b/src/fmeval/constants.py @@ -123,6 +123,7 @@ class DatasetColumns(Enum): # Jumpstart JUMPSTART_MODEL_ID = "jumpstart_model_id" JUMPSTART_MODEL_VERSION = "jumpstart_model_version" +JUMPSTART_MODEL_TYPE = "jumpstart_model_type" MODEL_ID = "model_id" SPEC_KEY = "spec_key" DEFAULT_PAYLOADS = "default_payloads" diff --git a/src/fmeval/model_runners/extractors/__init__.py b/src/fmeval/model_runners/extractors/__init__.py index 7db598be..a83e294d 100644 --- a/src/fmeval/model_runners/extractors/__init__.py +++ b/src/fmeval/model_runners/extractors/__init__.py @@ -1,6 +1,14 @@ from typing import Optional -from fmeval.constants import MIME_TYPE_JSON, JUMPSTART_MODEL_ID, JUMPSTART_MODEL_VERSION, IS_EMBEDDING_MODEL +from sagemaker.jumpstart.enums import JumpStartModelType + +from fmeval.constants import ( + MIME_TYPE_JSON, + JUMPSTART_MODEL_ID, + JUMPSTART_MODEL_VERSION, + JUMPSTART_MODEL_TYPE, + IS_EMBEDDING_MODEL, +) from fmeval.exceptions import EvalAlgorithmClientError from fmeval.model_runners.extractors.json_extractor import JsonExtractor from fmeval.model_runners.extractors.jumpstart_extractor import JumpStartExtractor @@ -25,6 +33,9 @@ def create_extractor( extractor = JumpStartExtractor( jumpstart_model_id=kwargs[JUMPSTART_MODEL_ID], jumpstart_model_version=kwargs[JUMPSTART_MODEL_VERSION] if JUMPSTART_MODEL_VERSION in kwargs 
else "*", + jumpstart_model_type=kwargs[JUMPSTART_MODEL_TYPE] + if JUMPSTART_MODEL_TYPE in kwargs + else JumpStartModelType.OPEN_WEIGHTS, is_embedding_model=kwargs[IS_EMBEDDING_MODEL] if IS_EMBEDDING_MODEL in kwargs else False, ) else: # pragma: no cover diff --git a/src/fmeval/model_runners/extractors/jumpstart_extractor.py b/src/fmeval/model_runners/extractors/jumpstart_extractor.py index 92caf246..6a299ff0 100644 --- a/src/fmeval/model_runners/extractors/jumpstart_extractor.py +++ b/src/fmeval/model_runners/extractors/jumpstart_extractor.py @@ -1,4 +1,5 @@ import json +import logging import os from typing import Union, List, Dict, Optional from urllib import request @@ -7,19 +8,21 @@ from functional import seq from jmespath.exceptions import JMESPathError from sagemaker import Session +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from fmeval import util from fmeval.constants import ( MODEL_ID, - SPEC_KEY, GENERATED_TEXT_JMESPATH_EXPRESSION, SDK_MANIFEST_FILE, PROPRIETARY_SDK_MANIFEST_FILE, - DEFAULT_PAYLOADS, JUMPSTART_BUCKET_BASE_URL_FORMAT, JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, INPUT_LOG_PROBS_JMESPATH_EXPRESSION, EMBEDDING_JMESPATH_EXPRESSION, + SPEC_KEY, + DEFAULT_PAYLOADS, ) from fmeval.exceptions import EvalAlgorithmClientError, EvalAlgorithmInternalError from fmeval.model_runners.extractors.extractor import Extractor @@ -27,6 +30,8 @@ # The expected model response location for Jumpstart that do produce the log probabilities from fmeval.model_runners.util import get_sagemaker_session +logger = logging.getLogger(__name__) + class JumpStartExtractor(Extractor): """ @@ -37,6 +42,7 @@ def __init__( self, jumpstart_model_id: str, jumpstart_model_version: str, + jumpstart_model_type: str, is_embedding_model: Optional[bool] = False, sagemaker_session: Optional[Session] = None, ): @@ -51,6 +57,7 @@ def __init__( """ self._model_id = jumpstart_model_id self._model_version 
= jumpstart_model_version + self._model_type = jumpstart_model_type self._sagemaker_session = sagemaker_session if sagemaker_session else get_sagemaker_session() self._is_embedding_model = is_embedding_model @@ -62,22 +69,41 @@ def __init__( lambda x: x.get(MODEL_ID, None) == jumpstart_model_id ) util.require(model_manifest, f"Model {jumpstart_model_id} is not a valid JumpStart Model") - model_spec_key = self.get_jumpstart_sdk_spec( + + model_spec = self.get_jumpstart_sdk_spec( model_manifest.get(SPEC_KEY, None), self._sagemaker_session.boto_region_name, ) - util.require( - DEFAULT_PAYLOADS in model_spec_key, f"JumpStart Model: {jumpstart_model_id} is not supported at this time" - ) + if DEFAULT_PAYLOADS not in model_spec: + # Model spec contains alt configs, which should + # be obtained through JumpStart util function. + logger.info( + "default_payloads not found as a top-level attribute of model spec. " + "Searching for default_payloads in inference configs instead." + ) + model_spec = verify_model_region_and_return_specs( + region=self._sagemaker_session.boto_region_name, + model_id=self._model_id, + version=self._model_version, + model_type=self._model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self._sagemaker_session, + ) + configs = model_spec.inference_configs # type: ignore[attr-defined] + util.require(configs, f"JumpStart Model: {jumpstart_model_id} is not supported at this time") + default_payloads = configs.get_top_config_from_ranking().resolved_metadata_config.get(DEFAULT_PAYLOADS) + else: + # Continue to extract default payloads by manually parsing the spec json object. + # TODO: update this code when the `default_payloads` attribute of JumpStartModelSpecs + # returns the full data, including fields like generated_text. 
+ default_payloads = model_spec[DEFAULT_PAYLOADS] + + util.require(default_payloads, f"JumpStart Model: {jumpstart_model_id} is not supported at this time") - output_jmespath_expressions = None - input_log_probs_jmespath_expressions = None try: - output_jmespath_expressions = jmespath.compile(GENERATED_TEXT_JMESPATH_EXPRESSION).search( - model_spec_key[DEFAULT_PAYLOADS] - ) + output_jmespath_expressions = jmespath.compile(GENERATED_TEXT_JMESPATH_EXPRESSION).search(default_payloads) input_log_probs_jmespath_expressions = jmespath.compile(INPUT_LOG_PROBS_JMESPATH_EXPRESSION).search( - model_spec_key[DEFAULT_PAYLOADS] + default_payloads ) except (TypeError, JMESPathError) as e: raise EvalAlgorithmInternalError( diff --git a/src/fmeval/model_runners/sm_jumpstart_model_runner.py b/src/fmeval/model_runners/sm_jumpstart_model_runner.py index 22f465db..82aed11d 100644 --- a/src/fmeval/model_runners/sm_jumpstart_model_runner.py +++ b/src/fmeval/model_runners/sm_jumpstart_model_runner.py @@ -54,7 +54,17 @@ def __init__( :param component_name: Name of the Amazon SageMaker inference component corresponding the predictor """ + sagemaker_session = get_sagemaker_session() + util.require( + is_endpoint_in_service(sagemaker_session, endpoint_name), + f"Endpoint {endpoint_name} is not in service", + ) + # Default model type is always OPEN_WEIGHTS. 
See https://tinyurl.com/yc58s6wj + jumpstart_model_type = JumpStartModelType.OPEN_WEIGHTS + if is_proprietary_js_model(sagemaker_session.boto_region_name, model_id): + jumpstart_model_type = JumpStartModelType.PROPRIETARY is_text_embedding_model = is_text_embedding_js_model(model_id) + super().__init__( content_template=content_template, output=output, @@ -64,6 +74,7 @@ def __init__( accept_type=MIME_TYPE_JSON, jumpstart_model_id=model_id, jumpstart_model_version=model_version, + jumpstart_model_type=jumpstart_model_type, is_embedding_model=is_text_embedding_model, ) self._endpoint_name = endpoint_name @@ -77,17 +88,6 @@ def __init__( self._component_name = component_name self._is_embedding_model = is_text_embedding_model - sagemaker_session = get_sagemaker_session() - util.require( - is_endpoint_in_service(sagemaker_session, self._endpoint_name), - f"Endpoint {self._endpoint_name} is not in service", - ) - - # Default model type is always OPEN_WEIGHTS. See https://tinyurl.com/yc58s6wj - jumpstart_model_type = JumpStartModelType.OPEN_WEIGHTS - if is_proprietary_js_model(sagemaker_session.boto_region_name, self._model_id): - jumpstart_model_type = JumpStartModelType.PROPRIETARY - predictor = sagemaker.predictor.retrieve_default( endpoint_name=self._endpoint_name, model_id=self._model_id, diff --git a/test/integration/test_create_extractor.py b/test/integration/test_create_extractor.py new file mode 100644 index 00000000..5ae2a402 --- /dev/null +++ b/test/integration/test_create_extractor.py @@ -0,0 +1,42 @@ +import pytest + +from fmeval.constants import MIME_TYPE_JSON +from fmeval.exceptions import EvalAlgorithmClientError +from fmeval.model_runners.extractors import create_extractor, JumpStartExtractor + + +class TestCreateExtractor: + """ + These tests are under integration tests instead of unit tests because + credentials are required to call the JumpStart util function + verify_model_region_and_return_specs. 
+ + See test/unit/model_runners/extractors/test_create_extractor.py + for corresponding unit tests. + """ + + def test_create_extractor_jumpstart(self): + """ + GIVEN a model whose default payloads are not found at the top level of + the model spec, but instead nested under the inference_configs attribute. + WHEN create_extractor is called with this model id. + THEN a JumpStartExtractor is successfully created for this model. + """ + # default payloads found in inference_configs + jumpstart_model_id = "huggingface-llm-mistral-7b" + assert isinstance( + create_extractor(model_accept_type=MIME_TYPE_JSON, jumpstart_model_id=jumpstart_model_id), + JumpStartExtractor, + ) + + def test_create_extractor_jumpstart_no_default_payloads(self): + """ + GIVEN a model whose spec does not contain default payloads data anywhere. + WHEN create_extractor is called with this model id. + THEN the correct exception is raised. + """ + with pytest.raises( + EvalAlgorithmClientError, + match="JumpStart Model: xgboost-regression-snowflake is not supported at this time", + ): + create_extractor(model_accept_type=MIME_TYPE_JSON, jumpstart_model_id="xgboost-regression-snowflake") diff --git a/test/unit/model_runners/extractors/test_create_extractor.py b/test/unit/model_runners/extractors/test_create_extractor.py index 58021cbc..9856ccfe 100644 --- a/test/unit/model_runners/extractors/test_create_extractor.py +++ b/test/unit/model_runners/extractors/test_create_extractor.py @@ -9,9 +9,17 @@ def test_create_extractor(): assert isinstance(create_extractor(model_accept_type=MIME_TYPE_JSON, output_location="output"), JsonExtractor) -def test_create_extractor_jumpstart(): +@pytest.mark.parametrize( + "jumpstart_model_id", ["huggingface-llm-falcon-7b-bf16"] # default payloads found at the top level of the model spec +) +def test_create_extractor_jumpstart(jumpstart_model_id): + """ + Note: the test case for a model whose default payloads are found in inference_configs + (instead of as a top-level 
attribute of the model spec) is an integration test, + since unit tests don't run with the credentials required. + """ assert isinstance( - create_extractor(model_accept_type=MIME_TYPE_JSON, jumpstart_model_id="huggingface-llm-falcon-7b-bf16"), + create_extractor(model_accept_type=MIME_TYPE_JSON, jumpstart_model_id=jumpstart_model_id), JumpStartExtractor, ) diff --git a/test/unit/model_runners/extractors/test_jumpstart_extractor.py b/test/unit/model_runners/extractors/test_jumpstart_extractor.py index 00691969..04af9006 100644 --- a/test/unit/model_runners/extractors/test_jumpstart_extractor.py +++ b/test/unit/model_runners/extractors/test_jumpstart_extractor.py @@ -1,9 +1,10 @@ from unittest import mock -from unittest.mock import patch +from unittest.mock import patch, Mock import pytest from _pytest.fixtures import fixture from _pytest.python_api import approx +from sagemaker.jumpstart.enums import JumpStartModelType from fmeval.exceptions import EvalAlgorithmClientError, EvalAlgorithmInternalError from fmeval.model_runners.extractors.jumpstart_extractor import JumpStartExtractor @@ -69,6 +70,7 @@ def extractor(self, sagemaker_session): return JumpStartExtractor( jumpstart_model_id="huggingface-llm-falcon-7b-bf16", jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=sagemaker_session, ) @@ -77,6 +79,7 @@ def embedding_model_extractor(self): return JumpStartExtractor( jumpstart_model_id=EMBEDDING_MODEL_ID, jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, is_embedding_model=True, ) @@ -131,6 +134,7 @@ def test_extractor_with_bad_output_expression(self, sagemaker_session): JumpStartExtractor( jumpstart_model_id="huggingface-llm-falcon-7b-bf16", jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=sagemaker_session, ) @@ -152,6 +156,7 @@ def test_extractor_with_bad_input_log_probability(self, sagemaker_session): JumpStartExtractor( 
jumpstart_model_id="huggingface-llm-falcon-7b-bf16", jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=sagemaker_session, ) @@ -168,5 +173,87 @@ def test_extractor_with_invalid_default_payload(self, sagemaker_session): JumpStartExtractor( jumpstart_model_id="huggingface-llm-falcon-7b-bf16", jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=sagemaker_session, + ) + + @patch("sagemaker.session.Session") + def test_extractor_when_default_payloads_found_in_inference_configs(self, sagemaker_session): + sagemaker_session.boto_region_name = "us-west-2" + model_spec_json = {} + model_spec_js_object = Mock() + inference_configs = Mock() + top_config = Mock() + + resolved_metadata_config = {"default_payloads": {"test": {"output_keys": {"generated_text": "Hi"}}}} + top_config.resolved_metadata_config = resolved_metadata_config + inference_configs.get_top_config_from_ranking.return_value = top_config + model_spec_js_object.inference_configs = inference_configs + + with patch( + "fmeval.model_runners.extractors.jumpstart_extractor.JumpStartExtractor.get_jumpstart_sdk_spec", + return_value=model_spec_json, + ), patch( + "fmeval.model_runners.extractors.jumpstart_extractor.verify_model_region_and_return_specs", + return_value=model_spec_js_object, + ): + JumpStartExtractor( + jumpstart_model_id="huggingface-llm-falcon-7b-bf16", + jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=sagemaker_session, + ) + + @patch("sagemaker.session.Session") + def test_extractor_when_model_spec_is_missing_inference_configs(self, sagemaker_session): + """ + GIVEN a model whose spec does not contain default_payloads as a top-level attribute + WHEN the `inference_configs` field of the JumpStartModelSpecs object returned by + verify_model_region_and_return_specs is None + THEN the correct exception is raised. 
+ """ + sagemaker_session.boto_region_name = "us-west-2" + model_spec_json = {"not_default_payloads": {"test": {"output_keys": {"generated_text": "{"}}}} + model_spec_js_object = Mock() + model_spec_js_object.inference_configs = None + + with patch( + "fmeval.model_runners.extractors.jumpstart_extractor.JumpStartExtractor.get_jumpstart_sdk_spec", + return_value=model_spec_json, + ), patch( + "fmeval.model_runners.extractors.jumpstart_extractor.verify_model_region_and_return_specs", + return_value=model_spec_js_object, + ), pytest.raises( + EvalAlgorithmClientError, + match="JumpStart Model: huggingface-llm-falcon-7b-bf16 is not supported at this time", + ): + JumpStartExtractor( + jumpstart_model_id="huggingface-llm-falcon-7b-bf16", + jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=sagemaker_session, + ) + + @patch("sagemaker.session.Session") + def test_extractor_when_default_payloads_is_empty(self, sagemaker_session): + """ + GIVEN the default payloads data that is found is empty. + WHEN a JumpStartExtractor is being initialized. + THEN the correct exception is raised. 
+ """ + sagemaker_session.boto_region_name = "us-west-2" + model_spec_json = {"default_payloads": None} + + with patch( + "fmeval.model_runners.extractors.jumpstart_extractor.JumpStartExtractor.get_jumpstart_sdk_spec", + return_value=model_spec_json, + ), pytest.raises( + EvalAlgorithmClientError, + match="JumpStart Model: huggingface-llm-falcon-7b-bf16 is not supported at this time", + ): + JumpStartExtractor( + jumpstart_model_id="huggingface-llm-falcon-7b-bf16", + jumpstart_model_version="*", + jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=sagemaker_session, ) diff --git a/test/unit/model_runners/test_sm_jumpstart_model_runner.py b/test/unit/model_runners/test_sm_jumpstart_model_runner.py index 9f2e5cfc..b3e5c952 100644 --- a/test/unit/model_runners/test_sm_jumpstart_model_runner.py +++ b/test/unit/model_runners/test_sm_jumpstart_model_runner.py @@ -87,6 +87,8 @@ def test_jumpstart_model_runner_init( model_type=JumpStartModelType.PROPRIETARY, sagemaker_session=sagemaker_session_class.return_value, region=None, + hub_arn=None, + config_name=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, ) @@ -98,6 +100,8 @@ def test_jumpstart_model_runner_init( model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=sagemaker_session_class.return_value, region=None, + hub_arn=None, + config_name=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, )