Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/fmeval/model_runners/extractors/jumpstart_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(
model_manifest.get(SPEC_KEY, None),
self._sagemaker_session.boto_region_name,
)

default_payloads = None
if DEFAULT_PAYLOADS not in model_spec:
# Model spec contains alt configs, which should
# be obtained through JumpStart util function.
Expand All @@ -90,8 +92,8 @@ def __init__(
sagemaker_session=self._sagemaker_session,
)
configs = model_spec.inference_configs # type: ignore[attr-defined]
util.require(configs, f"JumpStart Model: {jumpstart_model_id} is not supported at this time")
default_payloads = configs.get_top_config_from_ranking().resolved_metadata_config[DEFAULT_PAYLOADS]
if configs is not None:
default_payloads = configs.get_top_config_from_ranking().resolved_metadata_config[DEFAULT_PAYLOADS]
else:
# Continue to extract default payloads by manually parsing the spec json object.
# TODO: update this code when the `default_payloads` attribute of JumpStartModelSpecs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fmeval.model_runners.composers.jumpstart_composer import JumpStartComposer

# JumpStart model identifiers and the prompt shared by the tests in this module.
OSS_MODEL_ID = "huggingface-eqa-roberta-large"
# Supersedes the earlier value "cohere-gpt-medium"; only one assignment is kept
# so the constant has a single, unambiguous definition.
PROPRIETARY_MODEL_ID = "ai21-summarization"
EMBEDDING_MODEL_ID = "tcembedding-model-id"
PROMPT = "Hello, how are you?"

Expand Down
7 changes: 6 additions & 1 deletion test/unit/model_runners/extractors/test_create_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fmeval.constants import MIME_TYPE_JSON
from fmeval.exceptions import EvalAlgorithmClientError
from fmeval.model_runners.extractors import create_extractor, JsonExtractor, JumpStartExtractor
from sagemaker.jumpstart.enums import JumpStartModelType


def test_create_extractor():
Expand All @@ -26,7 +27,11 @@ def test_create_extractor_jumpstart(jumpstart_model_id):

def test_create_extractor_jumpstart_proprietary():
    """Check that ``create_extractor`` builds a ``JumpStartExtractor`` for a proprietary model.

    The model ID used here replaces the earlier "cohere-gpt-medium", and the
    model type is passed explicitly as ``JumpStartModelType.PROPRIETARY``.
    """
    # Exactly one object is handed to isinstance(); keeping the superseded
    # create_extractor(...) call next to the new one would make this a
    # three-argument isinstance() call, which raises TypeError.
    extractor = create_extractor(
        model_accept_type=MIME_TYPE_JSON,
        jumpstart_model_id="ai21-summarization",
        jumpstart_model_type=JumpStartModelType.PROPRIETARY,
    )
    assert isinstance(extractor, JumpStartExtractor)

Expand Down
16 changes: 16 additions & 0 deletions test/unit/model_runners/extractors/test_jumpstart_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,22 @@ def test_extractor_when_model_spec_is_missing_inference_configs(self, sagemaker_
sagemaker_session=sagemaker_session,
)

@patch("sagemaker.session.Session")
def test_extractor_when_default_payloads_found_outside_inference_configs(self, sagemaker_session):
    """Build a ``JumpStartExtractor`` from a spec whose ``default_payloads`` is a top-level key.

    ``get_jumpstart_sdk_spec`` is patched to return a plain dict that carries
    ``default_payloads`` directly. The test makes no explicit assertion; it
    passes when constructing the extractor does not raise.
    """
    sagemaker_session.boto_region_name = "us-west-2"

    spec_with_top_level_payloads = {
        "default_payloads": {
            "test": {"output_keys": {"generated_text": "Hi"}},
        },
    }
    sdk_spec_patch = patch(
        "fmeval.model_runners.extractors.jumpstart_extractor.JumpStartExtractor.get_jumpstart_sdk_spec",
        return_value=spec_with_top_level_payloads,
    )

    with sdk_spec_patch:
        JumpStartExtractor(
            jumpstart_model_id="huggingface-llm-falcon-7b-bf16",
            jumpstart_model_version="*",
            jumpstart_model_type=JumpStartModelType.OPEN_WEIGHTS,
            sagemaker_session=sagemaker_session,
        )

@patch("sagemaker.session.Session")
def test_extractor_when_default_payloads_is_empty(self, sagemaker_session):
"""
Expand Down
2 changes: 1 addition & 1 deletion test/unit/model_runners/test_sm_jumpstart_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# Shared fixture values for the JumpStart model runner tests.
CUSTOM_ATTRIBUTES = "CustomAttributes"
INFERENCE_COMPONENT_NAME = "valid_inference_component_name"
MODEL_ID = "AwesomeModel"
# Supersedes the earlier value "cohere-gpt-medium"; only one assignment is kept
# so the constant has a single, unambiguous definition.
PROPRIETARY_MODEL_ID = "ai21-summarization"
MODEL_VERSION = "v1.2.3"

CONTENT_TEMPLATE = '{"data":$prompt}'
Expand Down
2 changes: 1 addition & 1 deletion test/unit/model_runners/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_is_proprietary_js_model_false():


def test_is_proprietary_js_model_true():
    """Check that ``is_proprietary_js_model`` reports True for a proprietary model ID in us-west-2."""
    # "ai21-summarization" replaces the earlier "cohere-gpt-medium" ID, matching
    # the other tests touched by the same change; the superseded assertion is
    # dropped rather than kept alongside its replacement.
    assert is_proprietary_js_model("us-west-2", "ai21-summarization") == True


@patch("fmeval.model_runners.util.list_jumpstart_models", return_value=["tcembedding-model-id"])
Expand Down