diff --git a/src/fmeval/model_runners/composers/composers.py b/src/fmeval/model_runners/composers/composers.py
index 0b38589a..da6fafe1 100644
--- a/src/fmeval/model_runners/composers/composers.py
+++ b/src/fmeval/model_runners/composers/composers.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Union, List
+from typing import Any, Dict, Union, List, Optional
 
 from fmeval.exceptions import EvalAlgorithmClientError
 from fmeval.model_runners.composers.template import VanillaTemplate
@@ -16,22 +16,24 @@ def __init__(self, template: str, placeholder: str):
         self.placeholder = placeholder
         self.vanilla_template = VanillaTemplate(template)
 
-    def _get_filled_in_template(self, data: str) -> str:
+    def _get_filled_in_template(self, placeholder_data_dict: Dict[str, str]) -> str:
         """
-        Returns the string that results from replacing self.placeholder
-        in self.template with `data`.
+        Returns the string that results from replacing each key of placeholder_data_dict
+        in self.template with its corresponding value.
 
-        :param data: Data to replace placeholder.
+        :param placeholder_data_dict: Mapping of placeholder names to the data that replaces them.
         :return: A template that has its placeholders "filled in".
         """
-        return self.vanilla_template.substitute(**{self.placeholder: data})
+        return self.vanilla_template.substitute(**placeholder_data_dict)
 
     @abstractmethod
-    def compose(self, data: str) -> Any:
+    def compose(self, data: Optional[str] = None, placeholder_data_dict: Optional[Dict[str, str]] = None) -> Any:
         """
-        Composes an object using the input data, self.vanilla_template, and self.placeholder.
+        Composes an object using the input data and/or placeholder_data_dict,
+        self.vanilla_template, and self.placeholder.
 
         :param data: The data used to compose a new object.
+        :param placeholder_data_dict: Mapping of template placeholder names to the data used for composing.
         :return: A new object composed using `data`, self.vanilla_template, and self.placeholder.
         """
 
@@ -60,7 +62,7 @@ def compose(self, data: str) -> Union[str, List, Dict]:
         :return: A JSON object representing a prompt that will be consumed by a model.
         """
         try:
-            return json.loads(self._get_filled_in_template(json.dumps(data)))
+            return json.loads(self._get_filled_in_template({self.placeholder: json.dumps(data)}))
         except Exception as e:
             raise EvalAlgorithmClientError(
                 f"Unable to load a JSON object with template '{self.vanilla_template.template}' using data {data} ",
@@ -78,16 +80,31 @@ class PromptComposer(Composer):
     def __init__(self, template: str):
         super().__init__(template=template, placeholder=self.PLACEHOLDER)
 
-    def compose(self, data: str) -> str:
+    def compose(self, data: Optional[str] = None, placeholder_data_dict: Optional[Dict[str, str]] = None) -> str:
         """
-        Composes a prompt that will be fed to an LLM.
+        Composes a prompt from `data` and/or `placeholder_data_dict` that will be fed to an LLM.
+        When both `data` and `placeholder_data_dict` are given and there are duplicates,
+        the placeholders from placeholder_data_dict take precedence.
         Example:
             data = "London is the capital of"
+            template =
+            "[INST] <<SYS>>Answer the following question in as few words as possible.<</SYS>>
+            Question: $model_input [/INST]"
             composed prompt =
             "[INST] <<SYS>>Answer the following question in as few words as possible.<</SYS>>
             Question: London is the capital of [/INST]"
 
-        :param data: The original string that forms the basis of the returned prompt.
-        :return: A prompt composed by replacing self.placeholder in self.vanilla_template with `data`.
+        :param data: The original string that forms the basis of the returned prompt. It fills
+            self.placeholder whenever it is not None; an empty string is a valid value.
+        :param placeholder_data_dict: Optional mapping of template placeholder names to the
+            strings that replace them.
+        :return: A prompt composed by replacing self.placeholder in self.vanilla_template with `data`,
+            and/or replacing keys of `placeholder_data_dict` with their corresponding values.
         """
-        return self._get_filled_in_template(data)
+        mapping_obj: Dict[str, str] = {}
+        if data is not None:
+            mapping_obj[self.placeholder] = data
+        if placeholder_data_dict:
+            # Applied last so that explicit placeholder values override `data`.
+            mapping_obj.update(placeholder_data_dict)
+        return self._get_filled_in_template(placeholder_data_dict=mapping_obj)
diff --git a/test/unit/model_runners/composers/test_composers.py b/test/unit/model_runners/composers/test_composers.py
index 0ef02eb9..28f85e0b 100644
--- a/test/unit/model_runners/composers/test_composers.py
+++ b/test/unit/model_runners/composers/test_composers.py
@@ -1,5 +1,5 @@
 import re
-from typing import NamedTuple, Union, List, Dict
+from typing import NamedTuple, Union, List, Dict, Optional
 
 import pytest
 
@@ -52,12 +52,65 @@ def test_invalid_template(self):
 
 
 class TestPromptComposer:
-    def test_compose(self):
-        composer = PromptComposer(template="Answer the following question: $model_input")
-        prompt = "London is the capital of?"
-        expected_result = "Answer the following question: London is the capital of?"
-        result = composer.compose(prompt)
-        assert result == expected_result
+    class TestCaseCompose(NamedTuple):
+        template: str
+        prompt: Optional[str]
+        placeholder_data_dict: Dict[str, str]
+        expected_result: str
+
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            # Test case to verify composing a prompt with `data` only
+            TestCaseCompose(
+                template="Answer the following question: $model_input",
+                prompt="London is the capital of?",
+                placeholder_data_dict={},
+                expected_result="Answer the following question: London is the capital of?",
+            ),
+            # Test case to verify that an empty-string `data` still fills the placeholder
+            TestCaseCompose(
+                template="Answer the following question: $model_input",
+                prompt="",
+                placeholder_data_dict={},
+                expected_result="Answer the following question: ",
+            ),
+            # Test case to verify composing a prompt with placeholder_data_dict only
+            TestCaseCompose(
+                template="Question: $model_input \n context: $context \n statement: $statements",
+                prompt=None,
+                placeholder_data_dict={
+                    "model_input": "sample question",
+                    "context": "sample context",
+                    "statements": "statement1",
+                },
+                expected_result="Question: sample question \n context: sample context \n statement: statement1",
+            ),
+            # Test case to verify that placeholder_data_dict takes precedence over `data`
+            TestCaseCompose(
+                template="Question: $model_input",
+                prompt="question from prompt",
+                placeholder_data_dict={"model_input": "question from kwargs"},
+                expected_result="Question: question from kwargs",
+            ),
+            # Test case to verify composing a prompt with both `data` and placeholder_data_dict
+            TestCaseCompose(
+                template="Question: $model_input \n Context: $context",
+                prompt="question from prompt",
+                placeholder_data_dict={"context": "some context"},
+                expected_result="Question: question from prompt \n Context: some context",
+            ),
+        ],
+    )
+    def test_compose(self, test_case):
+        composer = PromptComposer(template=test_case.template)
+        result = composer.compose(test_case.prompt, test_case.placeholder_data_dict)
+        assert result == test_case.expected_result
+
+    def test_compose_without_placeholder_data_dict(self):
+        composer = PromptComposer(template="Answer the following question: $model_input")
+        result = composer.compose("London is the capital of?")
+        assert result == "Answer the following question: London is the capital of?"
 
     def test_invalid_template(self):
         composer = PromptComposer(template="Answer the following question: $invalid")