From efc1d8942230fdce39cc47edc25f57d5b40c9a6c Mon Sep 17 00:00:00 2001 From: Daniel Zhu Date: Tue, 26 Mar 2024 13:40:05 -0700 Subject: [PATCH 1/3] chore: restore evaluate_sample and evaluate signatures in EvalAlgorithmInterface --- .../classification_accuracy.py | 1 + ...sification_accuracy_semantic_robustness.py | 1 + src/fmeval/eval_algorithms/eval_algorithm.py | 68 ++++-- .../eval_algorithms/factual_knowledge.py | 1 + .../general_semantic_robustness.py | 14 +- .../eval_algorithms/prompt_stereotyping.py | 5 +- src/fmeval/eval_algorithms/qa_accuracy.py | 1 + .../qa_accuracy_semantic_robustness.py | 1 + .../semantic_robustness_utils.py | 17 +- .../eval_algorithms/summarization_accuracy.py | 130 ++++++----- ...marization_accuracy_semantic_robustness.py | 1 + src/fmeval/eval_algorithms/toxicity.py | 1 + src/fmeval/eval_algorithms/util.py | 8 + src/fmeval/util.py | 12 + .../test_general_semantic_robustness.py | 6 +- .../test_summarization_accuracy.py | 213 ++++++++---------- test/unit/test_util.py | 21 +- 17 files changed, 292 insertions(+), 209 deletions(-) diff --git a/src/fmeval/eval_algorithms/classification_accuracy.py b/src/fmeval/eval_algorithms/classification_accuracy.py index c9a472e9..39c79ed0 100644 --- a/src/fmeval/eval_algorithms/classification_accuracy.py +++ b/src/fmeval/eval_algorithms/classification_accuracy.py @@ -122,6 +122,7 @@ def __init__(self, eval_algorithm_config: ClassificationAccuracyConfig = Classif :param eval_algorithm_config: Classification Accuracy eval algorithm config. """ + super().__init__(eval_algorithm_config) self._eval_algorithm_config = eval_algorithm_config self._valid_labels = self._eval_algorithm_config.valid_labels diff --git a/src/fmeval/eval_algorithms/classification_accuracy_semantic_robustness.py b/src/fmeval/eval_algorithms/classification_accuracy_semantic_robustness.py index c09b8405..2af9fe56 100644 --- a/src/fmeval/eval_algorithms/classification_accuracy_semantic_robustness.py +++ b/src/fmeval/eval_algorithms/classification_accuracy_semantic_robustness.py @@ -138,6 +138,7 @@ def __init__( :param eval_algorithm_config: Classification Accuracy Semantic Robustness eval algorithm config. """ + super().__init__(eval_algorithm_config) self.eval_name = CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS self._eval_algorithm_config = eval_algorithm_config self._classification_accuracy_eval_algo = ClassificationAccuracy( diff --git a/src/fmeval/eval_algorithms/eval_algorithm.py b/src/fmeval/eval_algorithms/eval_algorithm.py index f8177ffa..cb51b216 100644 --- a/src/fmeval/eval_algorithms/eval_algorithm.py +++ b/src/fmeval/eval_algorithms/eval_algorithm.py @@ -1,47 +1,77 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional + +from fmeval.data_loaders.data_config import DataConfig from fmeval.eval_algorithms import EvalScore, EvalOutput +from fmeval.model_runners.model_runner import ModelRunner class EvalAlgorithmConfig: - """Configuration class to be used or extended to provide evaluation algorithm-specific parameters.""" + """Configuration class to be inherited from to provide evaluation algorithm-specific parameters.""" class EvalAlgorithmInterface(ABC): """Interface for evaluation algorithms. This interface defines two required methods that all evaluation algorithms must implement. - The signatures of these methods is intentionally as generic as possible, to allow for - maximum freedom when implementing a new evaluation algorithm. 
""" + def __init__(self, eval_algorithm_config: EvalAlgorithmConfig): + """Initialize an evaluation algorithm instance. + + :param eval_algorithm_config: Contains all configurable parameters for the evaluation algorithm. + """ + @abstractmethod - def evaluate_sample(self, *args, **kwargs) -> List[EvalScore]: + def evaluate_sample( + self, + model_input: Optional[str] = None, + model_output: Optional[str] = None, + target_output: Optional[str] = None, + model: Optional[ModelRunner] = None, + ) -> List[EvalScore]: """Compute metrics for a single sample, where a sample is defined by the particular algorithm. - The arguments to this method should be any data that pertains to the sample to be evaluated, - and any additional data used to compute the relevant metrics/scores. + The `evaluate_sample` method implemented by different algorithms should use a subset of + these input parameters, but not all of them are required. - Example: - The evaluate_sample method of the FactualKnowledge evaluation algorithm takes - two arguments: `target_output` and `model_output`, which are used to compute the factual - knowledge score. + :param model_input: The input passed to `model`. If this parameter is not None, + `model` should likewise not be None. + :param model_output: The output from invoking a model. If provided, `model` generally + will not be required, as the output is already available. + :param target_output: The reference output that `model_output` will be compared against. + Note that if `model_output` is not provided but `model` and `model_input` are provided + instead, the output from invoking `model` will take the place of `model_output`. + :param model: A ModelRunner representing the model being evaluated. :returns: A list of EvalScore objects, where each EvalScore represents a single score/metric that is computed by the evaluation algorithm. - See the built-in evaluation algorithms (ex: FactualKnowledge, SummarizationAccuracy) - for concrete examples. """ @abstractmethod - def evaluate(self, *args, **kwargs) -> List[EvalOutput]: + def evaluate( + self, + model: Optional[ModelRunner] = None, + dataset_config: Optional[DataConfig] = None, + prompt_template: Optional[str] = None, + num_records: int = 100, + save: bool = False, + ) -> List[EvalOutput]: """Compute metrics on all samples in one or more datasets. - The format that the dataset(s) in question take is up to the implementer - of the evaluation algorithm. All built-in evaluation algorithms in fmeval - currently utilize Ray Datasets. See the built-in evaluation algorithms - (ex: FactualKnowledge, SummarizationAccuracy) for concrete examples of how - to implement the `evaluate` method. + :param model: An instance of ModelRunner representing the model being evaluated. + :param dataset_config: Configures the single dataset used for the evaluation. + If not provided, this method will run evaluations using all of its supported + built-in datasets. + :param prompt_template: A template used to generate prompts from raw text inputs. + This parameter is not required if you with to run evaluations using the built-in + datasets, as they have their own default prompt templates pre-configured. + :param save: If True, model responses and scores will be saved to a file. + By default, the directory that this output file gets written to is + DEFAULT_EVAL_RESULTS_PATH, but this directory can be configured through + the EVAL_RESULTS_PATH environment variable. 
+ :param num_records: The number of records to be randomly sampled from the input dataset + that is used for the evaluation. :returns: A list of EvalOutput objects, where an EvalOutput encapsulates the EvalScores (and optionally, CategoryScores) generated by the evaluation, diff --git a/src/fmeval/eval_algorithms/factual_knowledge.py b/src/fmeval/eval_algorithms/factual_knowledge.py index 5ceadb1f..809c6597 100644 --- a/src/fmeval/eval_algorithms/factual_knowledge.py +++ b/src/fmeval/eval_algorithms/factual_knowledge.py @@ -73,6 +73,7 @@ def __init__(self, eval_algorithm_config: FactualKnowledgeConfig = FactualKnowle :param eval_algorithm_config: Factual knowledge eval algorithm config. """ + super().__init__(eval_algorithm_config) self.eval_name = FACTUAL_KNOWLEDGE self._eval_algorithm_config = eval_algorithm_config diff --git a/src/fmeval/eval_algorithms/general_semantic_robustness.py b/src/fmeval/eval_algorithms/general_semantic_robustness.py index b1856c3c..45a8a1bc 100644 --- a/src/fmeval/eval_algorithms/general_semantic_robustness.py +++ b/src/fmeval/eval_algorithms/general_semantic_robustness.py @@ -127,6 +127,7 @@ def __init__( `evaluate_sample` method, which is a computationally cheap operation that does not require utilizing Ray for parallel execution. """ + super().__init__(eval_algorithm_config) self.num_perturbations = eval_algorithm_config.num_perturbations self.num_baseline_samples = eval_algorithm_config.num_baseline_samples self.perturbation_transform = get_perturbation_transform(eval_algorithm_config) @@ -134,7 +135,7 @@ def __init__( if use_ray: self.bertscore_model = create_shared_resource(self.bertscore_model) - def build_pipeline( + def _build_pipeline( self, model: ModelRunner, prompt_template: str, @@ -159,12 +160,15 @@ def build_pipeline( :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`. 
""" - transforms = get_model_responses_from_perturbed_inputs( + ( + get_perturbed_inputs, + gen_perturbed_prompts, + get_perturbed_responses, + ) = get_model_responses_from_perturbed_inputs( self.perturbation_transform, prompt_template, model, ) - get_perturbed_inputs, gen_perturbed_prompts, get_perturbed_responses = transforms original_model_output_key = DatasetColumns.MODEL_OUTPUT.value.name # Compute BERTScores with target_output = the original model output @@ -286,7 +290,7 @@ def evaluate_sample( DatasetColumns.PROMPT.value.name: prompt, DatasetColumns.MODEL_OUTPUT.value.name: model_output, } - pipeline = self.build_pipeline(model, prompt_template, is_deterministic=is_deterministic) + pipeline = self._build_pipeline(model, prompt_template, is_deterministic=is_deterministic) output_record = pipeline.execute_record(sample) bert_score_dissimilarity_value = output_record[BERT_SCORE_DISSIMILARITY] @@ -329,7 +333,7 @@ def evaluate( model_invocation_pipeline = create_model_invocation_pipeline(model, dataset_prompt_template) dataset = model_invocation_pipeline.execute(dataset) is_deterministic = verify_model_determinism(model, dataset, DatasetColumns.PROMPT.value.name) - pipeline = self.build_pipeline(model, dataset_prompt_template, is_deterministic=is_deterministic) + pipeline = self._build_pipeline(model, dataset_prompt_template, is_deterministic=is_deterministic) eval_output = compute_and_aggregate_metrics( pipeline=pipeline, dataset=dataset, diff --git a/src/fmeval/eval_algorithms/prompt_stereotyping.py b/src/fmeval/eval_algorithms/prompt_stereotyping.py index c7760173..01d4aaa5 100644 --- a/src/fmeval/eval_algorithms/prompt_stereotyping.py +++ b/src/fmeval/eval_algorithms/prompt_stereotyping.py @@ -8,7 +8,7 @@ MEAN, ) from fmeval.data_loaders.util import DataConfig, get_dataset -from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface +from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig from fmeval.eval_algorithms import ( EvalAlgorithm, EvalOutput, @@ -53,6 +53,9 @@ class PromptStereotyping(EvalAlgorithmInterface): eval_name = PROMPT_STEREOTYPING + def __init__(self): + super().__init__(EvalAlgorithmConfig()) + def evaluate( self, model: Optional[ModelRunner] = None, diff --git a/src/fmeval/eval_algorithms/qa_accuracy.py b/src/fmeval/eval_algorithms/qa_accuracy.py index 8aac41fe..e1835f20 100644 --- a/src/fmeval/eval_algorithms/qa_accuracy.py +++ b/src/fmeval/eval_algorithms/qa_accuracy.py @@ -242,6 +242,7 @@ def __init__(self, eval_algorithm_config: QAAccuracyConfig = QAAccuracyConfig()) :param eval_algorithm_config: QA Accuracy eval algorithm config. """ + super().__init__(eval_algorithm_config) self._eval_algorithm_config = eval_algorithm_config def evaluate( diff --git a/src/fmeval/eval_algorithms/qa_accuracy_semantic_robustness.py b/src/fmeval/eval_algorithms/qa_accuracy_semantic_robustness.py index 4662c030..e454c93b 100644 --- a/src/fmeval/eval_algorithms/qa_accuracy_semantic_robustness.py +++ b/src/fmeval/eval_algorithms/qa_accuracy_semantic_robustness.py @@ -139,6 +139,7 @@ def __init__( :param eval_algorithm_config: QA Accuracy Semantic Robustness eval algorithm config. 
""" + super().__init__(eval_algorithm_config) self.eval_name = QA_ACCURACY_SEMANTIC_ROBUSTNESS self._eval_algorithm_config = eval_algorithm_config diff --git a/src/fmeval/eval_algorithms/semantic_robustness_utils.py b/src/fmeval/eval_algorithms/semantic_robustness_utils.py index 1997fdb1..fc5f0c17 100644 --- a/src/fmeval/eval_algorithms/semantic_robustness_utils.py +++ b/src/fmeval/eval_algorithms/semantic_robustness_utils.py @@ -2,6 +2,7 @@ from typing import Tuple from fmeval.constants import BUTTER_FINGER, RANDOM_UPPER_CASE, WHITESPACE_ADD_REMOVE, DatasetColumns +from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmConfig from fmeval.model_runners.model_runner import ModelRunner from fmeval.transforms.common import GeneratePrompt, GetModelOutputs from fmeval.transforms.semantic_perturbations import ( @@ -21,7 +22,7 @@ @dataclass(frozen=True) -class SemanticRobustnessConfig: +class SemanticRobustnessConfig(EvalAlgorithmConfig): """Configures the semantic robustness evaluation algorithms. :param perturbation_type: Perturbation type for generating perturbed inputs. @@ -53,6 +54,12 @@ def __post_init__(self): def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation: + """Returns a semantic perturbation transform based on parameters in `config`. + + :param config: A config that specifies a perturbation type, which dictates the + SemanticPerturbation that gets returned, and its configurable parameters. + :returns: A SemanticPerturbation instance, initialized with parameters passed via `config`. + """ if config.perturbation_type == BUTTER_FINGER: return ButterFinger( input_key=DatasetColumns.MODEL_INPUT.value.name, @@ -91,6 +98,14 @@ def get_model_responses_from_perturbed_inputs( prompt_template: str, model: ModelRunner, ) -> Tuple[SemanticPerturbation, GeneratePrompt, GetModelOutputs]: + """Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs. + + :param perturbation: The semantic perturbation transform used to perturb inputs. + :param prompt_template: The template used for composing prompts out of the perturbed inputs. + :param model: The model that is invoked on the prompts constructed from perturbed inputs. + :returns: A tuple of three transforms, where the first is the same SemanticPerturbation + that was passed in, and the second two are created in this function. 
+ """ # Generate prompts from perturbed inputs gen_perturbed_prompts = GeneratePrompt( input_keys=perturbation.output_keys, diff --git a/src/fmeval/eval_algorithms/summarization_accuracy.py b/src/fmeval/eval_algorithms/summarization_accuracy.py index c9b42e26..250b6207 100644 --- a/src/fmeval/eval_algorithms/summarization_accuracy.py +++ b/src/fmeval/eval_algorithms/summarization_accuracy.py @@ -1,14 +1,21 @@ import logging from dataclasses import dataclass -from typing import Dict, Any, Optional, List, Union, Tuple +from typing import List, Optional, Tuple, Union + from ray import ObjectRef -from fmeval.eval_algorithms import EvalScore, EvalOutput, EvalAlgorithm +from fmeval.data_loaders.util import get_dataset +from fmeval.eval_algorithms import EvalAlgorithm, EvalOutput, EvalScore from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig -from fmeval.eval_algorithms import util -from fmeval.eval_algorithms.util import get_dataset_configs -from fmeval.util import assert_condition, require, create_shared_resource, get_eval_results_path -from fmeval.constants import BERTSCORE_DEFAULT_MODEL, DatasetColumns +from fmeval.eval_algorithms.util import get_dataset_configs, validate_dataset, evaluate_dataset +from fmeval.util import ( + assert_condition, + require, + create_shared_resource, + get_eval_results_path, + cleanup_shared_resource, +) +from fmeval.constants import BERTSCORE_DEFAULT_MODEL, DatasetColumns, MEAN from fmeval.transforms.transform_pipeline import TransformPipeline from fmeval.data_loaders.data_config import DataConfig from fmeval.helper_models import BertscoreModelTypes, BertscoreModel @@ -82,33 +89,30 @@ class SummarizationAccuracy(EvalAlgorithmInterface): eval_name = EvalAlgorithm.SUMMARIZATION_ACCURACY.value - def __init__(self, config: SummarizationAccuracyConfig = SummarizationAccuracyConfig(), use_ray: bool = True): + def __init__(self, eval_algorithm_config: SummarizationAccuracyConfig = SummarizationAccuracyConfig()): """SummarizationAccuracy initializer. - :param config: Summarization Accuracy evaluation algorithm config. - :param use_ray: Whether to create a Ray actor for the BertscoreModel used by this evaluation - algorithm instance. Currently, `evaluate` will only work if `use_ray` is set to True, - as the execution of the transform pipeline relies on the BertscoreModel existing - in shared memory. This flag can be set to False if you only plan on invoking the - `evaluate_sample` method, which is a computationally cheap operation that does not - require utilizing Ray for parallel execution. + :param eval_algorithm_config: Summarization Accuracy evaluation algorithm config. 
""" - self.use_ray = use_ray - meteor_score, rouge_score, bert_score, self.bertscore_model = SummarizationAccuracy.build_pipeline( + super().__init__(eval_algorithm_config) + self.bertscore_model = BertscoreModel(eval_algorithm_config.model_type_for_bertscore) + meteor_score, rouge_score, bert_score = SummarizationAccuracy._create_transforms( target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name], model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name], meteor_keys=[METEOR_SCORE], rouge_keys=[ROUGE_SCORE], bertscore_keys=[BERT_SCORE], - rouge_type=config.rouge_type, - use_stemmer_for_rouge=config.use_stemmer_for_rouge, - model_type_for_bertscore=config.model_type_for_bertscore, - use_ray=use_ray, + rouge_type=eval_algorithm_config.rouge_type, + use_stemmer_for_rouge=eval_algorithm_config.use_stemmer_for_rouge, + bertscore_model=self.bertscore_model, ) + self.meteor_score = meteor_score + self.rouge_score = rouge_score + self.bert_score = bert_score self.pipeline = TransformPipeline([meteor_score, rouge_score, bert_score]) @staticmethod - def build_pipeline( + def _create_transforms( target_output_keys: List[str], model_output_keys: List[str], meteor_keys: List[str], @@ -116,10 +120,21 @@ def build_pipeline( bertscore_keys: List[str], rouge_type: str, use_stemmer_for_rouge: bool, - bertscore_model: Optional[Union[BertscoreModel, ObjectRef]] = None, - model_type_for_bertscore: Optional[str] = None, - use_ray: bool = True, - ) -> Tuple[MeteorScore, RougeScore, BertScore, Union[BertscoreModel, ObjectRef]]: + bertscore_model: Union[BertscoreModel, ObjectRef], + ) -> Tuple[MeteorScore, RougeScore, BertScore]: + """Create a TransformPipeline containing summarization accuracy score transforms. + + :param target_output_keys: See the corresponding parameter in MeteorScore, RougeScore, and BertScore. + :param model_output_keys: See the corresponding parameter in MeteorScore, RougeScore, and BertScore. + :param meteor_keys: The `output_keys` parameter for the returned MeteorScore instance. + :param rouge_keys: The `output_keys` parameter for the returned RougeScore instance. + :param bertscore_keys: The `output_keys` parameter for the returned BertScore instance. + :param rouge_type: See the corresponding parameter in RougeScore. + :param use_stemmer_for_rouge: See `use_stemmer` in RougeScore. + :param bertscore_model: A BertscoreModel or Ray actor handle corresponding to a BertscoreModel + (i.e. a shared resource) used in the creation of the returned BertScore instance. + :returns: A tuple containing the created MeteorScore, RougeScore, and BertScore instances. 
+ """ meteor_transform = MeteorScore( target_output_keys=target_output_keys, model_output_keys=model_output_keys, @@ -134,14 +149,6 @@ def build_pipeline( rouge_type=rouge_type, use_stemmer=use_stemmer_for_rouge, ) - if bertscore_model is None: # pragma: no branch - require( - model_type_for_bertscore is not None, - "model_type_for_bertscore must not be None when bertscore_model is not provided.", - ) - bertscore_model = BertscoreModel(model_type_for_bertscore) - if use_ray: # pragma: no branch - bertscore_model = create_shared_resource(bertscore_model) bert_transform = BertScore( target_output_keys=target_output_keys, model_output_keys=model_output_keys, @@ -149,30 +156,19 @@ def build_pipeline( allow_duplicate_input_keys=True, bertscore_model=bertscore_model, ) - return meteor_transform, rouge_transform, bert_transform, bertscore_model - - @staticmethod - def create_sample(target_output: str, model_output: str) -> Dict[str, Any]: - """Create a sample in the record format used by Transforms. - - This function's primary use is to be called by evaluate_sample. - - :param target_output: The target_output parameter passed to evaluate_sample. - :param model_output: The model_output parameter passed to evaluate_sample. - """ - return { - DatasetColumns.TARGET_OUTPUT.value.name: target_output, - DatasetColumns.MODEL_OUTPUT.value.name: model_output, - } + return meteor_transform, rouge_transform, bert_transform - def evaluate_sample(self, target_output: str, model_output: str) -> List[EvalScore]: + def evaluate_sample(self, target_output: str, model_output: str) -> List[EvalScore]: # type: ignore[override] """Compute summarization accuracy metrics for a single sample. :param target_output: The expected/desired model output. :param model_output: The actual model output. :returns: A list of EvalScore objects, one for each of the summarization accuracy metrics. """ - sample = SummarizationAccuracy.create_sample(target_output=target_output, model_output=model_output) + sample = { + DatasetColumns.TARGET_OUTPUT.value.name: target_output, + DatasetColumns.MODEL_OUTPUT.value.name: model_output, + } output_record = self.pipeline.execute_record(sample) assert_condition( all(metric_name in output_record for metric_name in METRIC_NAMES), @@ -197,7 +193,7 @@ def evaluate( :param dataset_config: Configures the single dataset used for evaluation. If not provided, evaluations will be run on all of this algorithm's built-in datasets. :param prompt_template: A template used to generate prompts that are fed to the model. - If not provided, defaults will be used. + If not provided, defaults will be used. If provided, `model` must not be None. :param num_records: The number of records to be sampled randomly from the input dataset(s) used to perform the evaluation(s). :param save: If set to true, prompt responses and scores will be saved to a file. @@ -206,25 +202,39 @@ def evaluate( :return: A list of EvalOutput objects. """ - require( - self.use_ray, - "The use_ray instance attribute of SummarizationAccuracy must be True in order " - "for the evaluate method to run successfully.", + # Create a shared resource to be used during the evaluation. + bertscore_shared_resource = create_shared_resource(self.bertscore_model) + # Create a new pipeline that uses the shared resource instead of self.bertscore_model. 
+ meteor_score, rouge_score, bert_score = SummarizationAccuracy._create_transforms( + target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name], + model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name], + meteor_keys=[METEOR_SCORE], + rouge_keys=[ROUGE_SCORE], + bertscore_keys=[BERT_SCORE], + rouge_type=self.rouge_score.rouge_type, + use_stemmer_for_rouge=self.rouge_score.use_stemmer, + bertscore_model=bertscore_shared_resource, ) + pipeline = TransformPipeline([meteor_score, rouge_score, bert_score]) + dataset_configs = get_dataset_configs(dataset_config, self.eval_name) eval_outputs = [] for dataset_config in dataset_configs: - eval_output = util.evaluate_dataset( - dataset_config=dataset_config, - pipeline=self.pipeline, + dataset = get_dataset(dataset_config, num_records) + validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name]) + eval_output = evaluate_dataset( + dataset=dataset, + pipeline=pipeline, + dataset_name=dataset_config.dataset_name, eval_name=self.eval_name, metric_names=METRIC_NAMES, - required_columns=[DatasetColumns.TARGET_OUTPUT.value.name, DatasetColumns.MODEL_INPUT.value.name], eval_results_path=get_eval_results_path(), model=model, prompt_template=prompt_template, - num_records=num_records, + agg_method=MEAN, save=save, ) eval_outputs.append(eval_output) + + cleanup_shared_resource(bertscore_shared_resource) return eval_outputs diff --git a/src/fmeval/eval_algorithms/summarization_accuracy_semantic_robustness.py b/src/fmeval/eval_algorithms/summarization_accuracy_semantic_robustness.py index f28cccd2..f1d46036 100644 --- a/src/fmeval/eval_algorithms/summarization_accuracy_semantic_robustness.py +++ b/src/fmeval/eval_algorithms/summarization_accuracy_semantic_robustness.py @@ -187,6 +187,7 @@ def __init__( :param eval_algorithm_config: Summarization Accuracy Semantic Robustness eval algorithm config. """ + super().__init__(eval_algorithm_config) self.eval_name = EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS.value self._eval_algorithm_config = eval_algorithm_config diff --git a/src/fmeval/eval_algorithms/toxicity.py b/src/fmeval/eval_algorithms/toxicity.py index 57f034fa..a7f8b130 100644 --- a/src/fmeval/eval_algorithms/toxicity.py +++ b/src/fmeval/eval_algorithms/toxicity.py @@ -75,6 +75,7 @@ def __init__(self, eval_algorithm_config: ToxicityConfig = ToxicityConfig()): :param eval_algorithm_config: Toxicity eval algorithm config """ + super().__init__(eval_algorithm_config) self.eval_name = TOXICITY self._eval_algorithm_config = eval_algorithm_config self._helper_model = TOXICITY_HELPER_MODEL_MAPPING[self._eval_algorithm_config.model_type]() diff --git a/src/fmeval/eval_algorithms/util.py b/src/fmeval/eval_algorithms/util.py index f5027b02..de815131 100644 --- a/src/fmeval/eval_algorithms/util.py +++ b/src/fmeval/eval_algorithms/util.py @@ -407,6 +407,14 @@ def get_bert_score( def create_model_invocation_pipeline(model: ModelRunner, prompt_template: str) -> TransformPipeline: + """Create a transform pipeline for performing the standard action of invoking a model on a prompt. + + :param model: The model to be invoked. + :param prompt_template: The template used for constructing prompts (out of raw inputs) + that will be fed to the model. + :returns: A TransformPipeline instance containing a GeneratePrompt transform that uses `prompt_template` + and a GetModelOutputs transform for invoking the model on the generated prompts. 
+ """ gen_prompt = GeneratePrompt( input_keys=[DatasetColumns.MODEL_INPUT.value.name], output_keys=[DatasetColumns.PROMPT.value.name], diff --git a/src/fmeval/util.py b/src/fmeval/util.py index 378c40d0..fd507f16 100644 --- a/src/fmeval/util.py +++ b/src/fmeval/util.py @@ -112,3 +112,15 @@ def create_shared_resource(resource: object, num_cpus: int = 1) -> ObjectRef: resource_cls, serialized_data = resource.__reduce__() # type: ignore[misc] wrapped_resource_cls = ray.remote(num_cpus=num_cpus)(resource_cls) return wrapped_resource_cls.remote(*serialized_data) + + +def cleanup_shared_resource(resource: ObjectRef) -> None: + """Removes the resource from shared memory. + Concretely, this function kills the Ray actor corresponding + to `resource`, which in most cases will be an actor created + via create_shared_resource. + :param resource: A Ray actor handle to a shared resource + (ex: a BertscoreModel). + :returns: None + """ + ray.kill(resource) diff --git a/test/unit/eval_algorithms/test_general_semantic_robustness.py b/test/unit/eval_algorithms/test_general_semantic_robustness.py index b2fc4361..5a6f359e 100644 --- a/test/unit/eval_algorithms/test_general_semantic_robustness.py +++ b/test/unit/eval_algorithms/test_general_semantic_robustness.py @@ -190,14 +190,14 @@ def test_init(self, bertscore_model, create_shared_resource, perturbation_type, def test_build_pipeline(self, bertscore_model, is_deterministic, config): """ GIVEN a deterministic model. - WHEN `build_pipeline` is called. + WHEN a GeneralSemanticRobustness' `_build_pipeline` method is called. THEN a TransformPipeline with the correct Transforms is returned. """ # Mock BertscoreModel so that the actual model doesn't get loaded into memory during test. bertscore_model.return_value = Mock(spec=BertscoreModel) eval_algo = GeneralSemanticRobustness(config, use_ray=False) - pipeline = eval_algo.build_pipeline( + pipeline = eval_algo._build_pipeline( model=Mock(), prompt_template="$model_input", is_deterministic=is_deterministic, @@ -393,7 +393,7 @@ class TestCaseEvaluate(NamedTuple): ) @patch("fmeval.eval_algorithms.general_semantic_robustness.get_eval_results_path") @patch("fmeval.eval_algorithms.general_semantic_robustness.compute_and_aggregate_metrics") - @patch("fmeval.eval_algorithms.general_semantic_robustness.GeneralSemanticRobustness.build_pipeline") + @patch("fmeval.eval_algorithms.general_semantic_robustness.GeneralSemanticRobustness._build_pipeline") @patch("fmeval.eval_algorithms.general_semantic_robustness.verify_model_determinism") @patch("fmeval.eval_algorithms.general_semantic_robustness.create_model_invocation_pipeline") @patch("fmeval.eval_algorithms.general_semantic_robustness.get_dataset") diff --git a/test/unit/eval_algorithms/test_summarization_accuracy.py b/test/unit/eval_algorithms/test_summarization_accuracy.py index e13e2ab0..61d14634 100644 --- a/test/unit/eval_algorithms/test_summarization_accuracy.py +++ b/test/unit/eval_algorithms/test_summarization_accuracy.py @@ -2,14 +2,14 @@ import re import ray -from typing import NamedTuple -from unittest.mock import Mock, patch +from typing import NamedTuple, Optional +from unittest.mock import Mock, patch, call from _pytest.fixtures import fixture -from fmeval.constants import DatasetColumns, BERTSCORE_DEFAULT_MODEL +from fmeval.constants import DatasetColumns, MEAN from fmeval.eval_algorithms import EvalScore from fmeval.helper_models import BertscoreModel -from fmeval.transforms.summarization_accuracy_metrics import ROUGE_L, MeteorScore, RougeScore, BertScore 
+from fmeval.transforms.summarization_accuracy_metrics import ROUGE_L from fmeval.eval_algorithms.summarization_accuracy import ( SummarizationAccuracyConfig, SummarizationAccuracy, @@ -81,97 +81,29 @@ class TestSummarizationAccuracy: @fixture(scope="module") def eval_algo(self) -> SummarizationAccuracy: - return SummarizationAccuracy(SummarizationAccuracyConfig(), use_ray=False) + return SummarizationAccuracy(SummarizationAccuracyConfig()) @patch("fmeval.eval_algorithms.summarization_accuracy.TransformPipeline") - @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy.build_pipeline") - def test_init(self, mock_build_pipeline, mock_transform_pipeline_cls): + @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy._create_transforms") + @patch("fmeval.eval_algorithms.summarization_accuracy.BertscoreModel") + def test_init(self, bertscore_model_cls, mock_create_transforms, mock_transform_pipeline_cls): """ GIVEN default arguments. WHEN a SummarizationAccuracy is initialized. - THEN SummarizationAccuracy.build_pipeline is called, a TransformPipeline - is initialized with the correct Transforms, and said pipeline is set - to the instance's `pipeline` attribute. + THEN SummarizationAccuracy._create_transforms is called, + a TransformPipeline is initialized with the correct Transforms, + and said pipeline is set to the instance's `pipeline` attribute. """ mock_meteor, mock_rouge, mock_bertscore = Mock(), Mock(), Mock() - mock_build_pipeline.return_value = mock_meteor, mock_rouge, mock_bertscore, Mock() - summ_acc = SummarizationAccuracy() - mock_transform_pipeline_cls.assert_called_with([mock_meteor, mock_rouge, mock_bertscore]) - assert summ_acc.pipeline == mock_transform_pipeline_cls.return_value + mock_create_transforms.return_value = mock_meteor, mock_rouge, mock_bertscore + config = SummarizationAccuracyConfig() + summ_acc = SummarizationAccuracy(config) - @pytest.mark.parametrize("use_ray", [True, False]) - @patch("fmeval.eval_algorithms.summarization_accuracy.create_shared_resource") - def test_build_pipeline(self, mock_shared_resource, use_ray): - """ - GIVEN valid arguments where the bertscore_model argument is None. - WHEN SummarizationAccuracy's build_pipeline method is called. - THEN the correct outputs are returned, and create_shared_resource - is called if `use_ray` is True (and not called otherwise). - """ - meteor_score, rouge_score, bert_score, bertscore_model = SummarizationAccuracy.build_pipeline( - target_output_keys=["target_output"], - model_output_keys=["model_output"], - meteor_keys=[METEOR_SCORE], - rouge_keys=[ROUGE_SCORE], - bertscore_keys=[BERT_SCORE], - rouge_type=ROUGE_L, - use_stemmer_for_rouge=True, - model_type_for_bertscore=BERTSCORE_DEFAULT_MODEL, - use_ray=use_ray, - ) - assert isinstance(meteor_score, MeteorScore) - assert isinstance(rouge_score, RougeScore) - assert isinstance(bert_score, BertScore) - if use_ray: - mock_shared_resource.assert_called_once() - assert bertscore_model == mock_shared_resource.return_value - else: - mock_shared_resource.assert_not_called() - assert isinstance(bertscore_model, BertscoreModel) + bertscore_model_cls.assert_called_with(config.model_type_for_bertscore) + assert summ_acc.bertscore_model == bertscore_model_cls.return_value - @patch("fmeval.eval_algorithms.summarization_accuracy.BertscoreModel") - def test_build_pipeline_with_existing_bertscore_model(self, mock_bertscore_model_cls): - """ - GIVEN a `bertscore_model` argument that is not None. 
- WHEN SummarizationAccuracy's build_pipeline method is called. - THEN the bertscore_model that is returned by build_pipeline is - the same object that was passed in. - """ - bertscore_model_instance = Mock() - _, _, _, bertscore_model = SummarizationAccuracy.build_pipeline( - target_output_keys=["target_output"], - model_output_keys=["model_output"], - meteor_keys=[METEOR_SCORE], - rouge_keys=[ROUGE_SCORE], - bertscore_keys=[BERT_SCORE], - rouge_type=ROUGE_L, - use_stemmer_for_rouge=True, - bertscore_model=bertscore_model_instance, - use_ray=False, - ) - assert bertscore_model is bertscore_model_instance - mock_bertscore_model_cls.assert_not_called() - - def test_build_pipeline_missing_bertscore_model_type(self): - """ - GIVEN bertscore_model and model_type_for_bertscore arguments with value None. - WHEN SummarizationAccuracy's build_pipeline method is called. - THEN an exception is raised. - """ - with pytest.raises( - EvalAlgorithmClientError, - match="model_type_for_bertscore must not be None when bertscore_model is not provided.", - ): - SummarizationAccuracy.build_pipeline( - target_output_keys=["target_output"], - model_output_keys=["model_output"], - meteor_keys=[METEOR_SCORE], - rouge_keys=[ROUGE_SCORE], - bertscore_keys=[BERT_SCORE], - rouge_type=ROUGE_L, - use_stemmer_for_rouge=True, - use_ray=False, - ) + mock_transform_pipeline_cls.assert_called_with([mock_meteor, mock_rouge, mock_bertscore]) + assert summ_acc.pipeline == mock_transform_pipeline_cls.return_value class TestCaseSummarizationAccuracyInvalidConfig(NamedTuple): rouge_type: str @@ -206,7 +138,7 @@ def test_summarization_accuracy_invalid_config(self, rouge_type, model_type_for_ SummarizationAccuracyConfig(rouge_type=rouge_type, model_type_for_bertscore=model_type_for_bertscore) @patch("fmeval.eval_algorithms.summarization_accuracy.BertscoreModel") - def test_evaluate_sample(self, bertscore_model): + def test_evaluate_sample(self, bertscore_model_cls): """ GIVEN valid inputs. WHEN SummarizationAccuracy.evaluate_sample is called. @@ -215,7 +147,7 @@ def test_evaluate_sample(self, bertscore_model): # Mock the BertscoreModel class so that the actual model doesn't get loaded into memory. bertscore_model_instance = Mock(spec=BertscoreModel) bertscore_model_instance.invoke_model = Mock(return_value=BERTSCORE_DUMMY_VALUE) - bertscore_model.return_value = bertscore_model_instance + bertscore_model_cls.return_value = bertscore_model_instance model_output = "Berlin: Art, Heritage, Exhibitions Hub." target_output = "Berlin: an art metropolis." 
@@ -225,66 +157,109 @@ def test_evaluate_sample(self, bertscore_model): EvalScore(name=BERT_SCORE, value=BERTSCORE_DUMMY_VALUE), ] config = SummarizationAccuracyConfig(rouge_type=ROUGE_L) - eval_algorithm = SummarizationAccuracy(config, use_ray=False) + eval_algorithm = SummarizationAccuracy(config) actual_response = eval_algorithm.evaluate_sample(target_output, model_output) for actual_eval_score, expected_eval_score in zip(actual_response, expected_response): assert actual_eval_score.name == expected_eval_score.name assert actual_eval_score.value == pytest.approx(expected_eval_score.value, rel=1e-5) + class TestCaseEvaluate(NamedTuple): + user_provided_prompt_template: Optional[str] + dataset_prompt_template: str + + @pytest.mark.parametrize( + "test_case", + [ + TestCaseEvaluate( + user_provided_prompt_template="Summarize $model_input, please.", + dataset_prompt_template="Summarize $model_input, please.", + ), + TestCaseEvaluate( + user_provided_prompt_template=None, + dataset_prompt_template=None, + ), + ], + ) @patch("fmeval.eval_algorithms.summarization_accuracy.get_eval_results_path") - @patch("fmeval.eval_algorithms.summarization_accuracy.util.evaluate_dataset") + @patch("fmeval.eval_algorithms.summarization_accuracy.cleanup_shared_resource") + @patch("fmeval.eval_algorithms.summarization_accuracy.evaluate_dataset") + @patch("fmeval.eval_algorithms.summarization_accuracy.create_shared_resource") @patch("fmeval.eval_algorithms.summarization_accuracy.TransformPipeline") - @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy.build_pipeline") + @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy._create_transforms") + @patch("fmeval.eval_algorithms.summarization_accuracy.get_dataset") + @patch("fmeval.eval_algorithms.summarization_accuracy.get_dataset_configs") def test_evaluate( - self, mock_build_pipeline, mock_transform_pipeline_cls, mock_evaluate_dataset, mock_get_results_path + self, + mock_get_dataset_configs, + mock_get_dataset, + mock_create_transforms, + mock_transform_pipeline_cls, + mock_create_shared_resource, + mock_evaluate_dataset, + mock_cleanup_shared_resource, + mock_get_results_path, + test_case, ): """ - GIVEN a SummarizationAccuracy instance whose `use_ray` attribute is True. + GIVEN a SummarizationAccuracy instance. WHEN its evaluate method is called with valid arguments. - THEN `util.evaluate_dataset` is called with the correct arguments. + THEN a new TransformPipeline that uses a BertscoreModel shared resource + is created, and `evaluate_dataset` is called with the correct arguments. """ - mock_build_pipeline.return_value = Mock(), Mock(), Mock(), Mock() + # The transforms that are saved as instance attributes of the SummarizationAccuracy instance + meteor_score, rouge_score, bert_score = Mock(), Mock(), Mock() + # The transforms that get used in the pipeline that gets executed by `evaluate`. 
+ pipeline_meteor, pipeline_rouge, pipeline_bertscore = Mock(), Mock(), Mock() + + mock_create_transforms.side_effect = [ + (meteor_score, rouge_score, bert_score), + (pipeline_meteor, pipeline_rouge, pipeline_bertscore), + ] + + instance_pipeline = Mock() # The self.pipeline of the SummarizationAccuracy instance + executed_pipeline = Mock() # The pipeline that gets created and executed in `evaluate` + mock_transform_pipeline_cls.side_effect = [instance_pipeline, executed_pipeline] + mock_get_results_path.return_value = "/path/to/results" model_runner = Mock() + dataset_config = Mock() + dataset_config.dataset_name = "my_custom_dataset" + mock_get_dataset_configs.return_value = [dataset_config] + + mock_dataset = Mock() + # So that validate_dataset does not error + mock_dataset.columns = Mock( + return_value=[DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name] + ) + mock_get_dataset.return_value = mock_dataset summ_acc = SummarizationAccuracy() output = summ_acc.evaluate( model=model_runner, dataset_config=dataset_config, - prompt_template="Summarize $model_input, please.", + prompt_template=test_case.user_provided_prompt_template, num_records=162, save=True, ) + mock_create_shared_resource.assert_called_once_with(summ_acc.bertscore_model) + assert mock_create_transforms.call_count == 2 # once during initialization, once during evaluate + mock_transform_pipeline_cls.assert_has_calls( + [call([meteor_score, rouge_score, bert_score]), call([pipeline_meteor, pipeline_rouge, pipeline_bertscore])] + ) mock_evaluate_dataset.assert_called_once_with( + dataset=mock_dataset, + pipeline=executed_pipeline, + dataset_name=dataset_config.dataset_name, eval_name=summ_acc.eval_name, - pipeline=summ_acc.pipeline, metric_names=METRIC_NAMES, - required_columns=[DatasetColumns.TARGET_OUTPUT.value.name, DatasetColumns.MODEL_INPUT.value.name], eval_results_path="/path/to/results", model=model_runner, - dataset_config=dataset_config, - prompt_template="Summarize $model_input, please.", - num_records=162, + prompt_template=test_case.dataset_prompt_template, + agg_method=MEAN, save=True, ) + mock_cleanup_shared_resource.assert_called_once_with(mock_create_shared_resource.return_value) assert output == [mock_evaluate_dataset.return_value] - - @patch("fmeval.eval_algorithms.summarization_accuracy.TransformPipeline") - @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy.build_pipeline") - def test_evaluate_failure(self, mock_build_pipeline, mock_transform_pipeline_cls): - """ - GIVEN a SummarizationAccuracy instance whose `use_ray` attribute is False. - WHEN its evaluate method is called. - THEN an exception is raised. - """ - - mock_build_pipeline.return_value = Mock(), Mock(), Mock(), Mock() - summ_acc = SummarizationAccuracy(use_ray=False) - err_msg = ( - "The use_ray instance attribute of SummarizationAccuracy must be True in order " - "for the evaluate method to run successfully." 
- ) - with pytest.raises(EvalAlgorithmClientError, match=err_msg): - summ_acc.evaluate() + assert summ_acc.pipeline == instance_pipeline diff --git a/test/unit/test_util.py b/test/unit/test_util.py index 85477be2..7befcb83 100644 --- a/test/unit/test_util.py +++ b/test/unit/test_util.py @@ -6,7 +6,14 @@ from fmeval.constants import DEFAULT_EVAL_RESULTS_PATH from fmeval.exceptions import EvalAlgorithmClientError -from fmeval.util import require, project_root, singleton, create_shared_resource, get_eval_results_path +from fmeval.util import ( + require, + project_root, + singleton, + create_shared_resource, + get_eval_results_path, + cleanup_shared_resource, +) def test_require(): @@ -97,3 +104,15 @@ def __reduce__(self): mock_ray_remote.assert_called_once_with(num_cpus=num_cpus) mock_actor_class.assert_called_once_with(Dummy) mock_wrapped_resource_class.remote.assert_called_once_with("C", 2) + + +@patch("fmeval.util.ray.kill") +def test_cleanup_shared_resource(mock_ray_kill): + """ + GIVEN a shared resource. + WHEN cleanup_shared_resource is called. + THEN ray.kill is called on this resource. + """ + resource = Mock() + cleanup_shared_resource(resource) + mock_ray_kill.assert_called_once_with(resource) From 781c8e29eacec3e8f349c1eb03f8bfa1a8f21f08 Mon Sep 17 00:00:00 2001 From: Daniel Zhu Date: Tue, 26 Mar 2024 14:09:03 -0700 Subject: [PATCH 2/3] fix: update call to evaluate_dataset --- src/fmeval/eval_algorithms/summarization_accuracy.py | 10 ++++------ .../eval_algorithms/test_summarization_accuracy.py | 5 +++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/fmeval/eval_algorithms/summarization_accuracy.py b/src/fmeval/eval_algorithms/summarization_accuracy.py index 250b6207..c92b06ea 100644 --- a/src/fmeval/eval_algorithms/summarization_accuracy.py +++ b/src/fmeval/eval_algorithms/summarization_accuracy.py @@ -4,10 +4,9 @@ from ray import ObjectRef -from fmeval.data_loaders.util import get_dataset from fmeval.eval_algorithms import EvalAlgorithm, EvalOutput, EvalScore from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig -from fmeval.eval_algorithms.util import get_dataset_configs, validate_dataset, evaluate_dataset +from fmeval.eval_algorithms.util import evaluate_dataset, get_dataset_configs from fmeval.util import ( assert_condition, require, @@ -220,17 +219,16 @@ def evaluate( dataset_configs = get_dataset_configs(dataset_config, self.eval_name) eval_outputs = [] for dataset_config in dataset_configs: - dataset = get_dataset(dataset_config, num_records) - validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name]) eval_output = evaluate_dataset( - dataset=dataset, + dataset_config=dataset_config, pipeline=pipeline, - dataset_name=dataset_config.dataset_name, eval_name=self.eval_name, metric_names=METRIC_NAMES, + required_columns=[DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name], eval_results_path=get_eval_results_path(), model=model, prompt_template=prompt_template, + num_records=num_records, agg_method=MEAN, save=save, ) diff --git a/test/unit/eval_algorithms/test_summarization_accuracy.py b/test/unit/eval_algorithms/test_summarization_accuracy.py index 61d14634..e0ae3719 100644 --- a/test/unit/eval_algorithms/test_summarization_accuracy.py +++ b/test/unit/eval_algorithms/test_summarization_accuracy.py @@ -249,14 +249,15 @@ def test_evaluate( [call([meteor_score, rouge_score, bert_score]), call([pipeline_meteor, pipeline_rouge, 
pipeline_bertscore])] ) mock_evaluate_dataset.assert_called_once_with( - dataset=mock_dataset, + dataset_config=dataset_config, pipeline=executed_pipeline, - dataset_name=dataset_config.dataset_name, eval_name=summ_acc.eval_name, metric_names=METRIC_NAMES, + required_columns=[DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name], eval_results_path="/path/to/results", model=model_runner, prompt_template=test_case.dataset_prompt_template, + num_records=162, agg_method=MEAN, save=True, ) From e215ae9138037969b1e1edb01cc4e84473567818 Mon Sep 17 00:00:00 2001 From: Daniel Zhu Date: Tue, 26 Mar 2024 14:25:28 -0700 Subject: [PATCH 3/3] fix unit test correspondingly --- test/unit/eval_algorithms/test_summarization_accuracy.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/unit/eval_algorithms/test_summarization_accuracy.py b/test/unit/eval_algorithms/test_summarization_accuracy.py index e0ae3719..95c44fa9 100644 --- a/test/unit/eval_algorithms/test_summarization_accuracy.py +++ b/test/unit/eval_algorithms/test_summarization_accuracy.py @@ -186,12 +186,10 @@ class TestCaseEvaluate(NamedTuple): @patch("fmeval.eval_algorithms.summarization_accuracy.create_shared_resource") @patch("fmeval.eval_algorithms.summarization_accuracy.TransformPipeline") @patch("fmeval.eval_algorithms.summarization_accuracy.SummarizationAccuracy._create_transforms") - @patch("fmeval.eval_algorithms.summarization_accuracy.get_dataset") @patch("fmeval.eval_algorithms.summarization_accuracy.get_dataset_configs") def test_evaluate( self, mock_get_dataset_configs, - mock_get_dataset, mock_create_transforms, mock_transform_pipeline_cls, mock_create_shared_resource, @@ -227,13 +225,6 @@ def test_evaluate( dataset_config.dataset_name = "my_custom_dataset" mock_get_dataset_configs.return_value = [dataset_config] - mock_dataset = Mock() - # So that validate_dataset does not error - mock_dataset.columns = Mock( - return_value=[DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name] - ) - mock_get_dataset.return_value = mock_dataset - summ_acc = SummarizationAccuracy() output = summ_acc.evaluate( model=model_runner,
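
Usage sketch for the restored interface (illustrative only, not part of the patch series above): patch 1/3 restores explicit `evaluate_sample` and `evaluate` signatures on EvalAlgorithmInterface, and the snippet below shows how a built-in algorithm such as SummarizationAccuracy is driven through both entry points. It assumes an installed fmeval package; the ModelRunner for the dataset-level call is a hypothetical placeholder (`my_model_runner`) rather than a concrete implementation, which is why that call is left commented out.

    # Minimal sketch of the two entry points whose signatures patch 1/3 restores.
    from fmeval.eval_algorithms.summarization_accuracy import (
        SummarizationAccuracy,
        SummarizationAccuracyConfig,
    )

    algo = SummarizationAccuracy(SummarizationAccuracyConfig())

    # Per-sample scoring: no ModelRunner is needed when the model output is already available.
    scores = algo.evaluate_sample(
        target_output="Berlin: an art metropolis.",
        model_output="Berlin: Art, Heritage, Exhibitions Hub.",
    )
    print([(score.name, score.value) for score in scores])  # METEOR, ROUGE, and BERTScore entries

    # Dataset-level evaluation: requires a concrete ModelRunner (placeholder below).
    # Omitting dataset_config makes the algorithm fall back to its built-in datasets.
    # outputs = algo.evaluate(model=my_model_runner, num_records=100, save=True)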