Merged
48 changes: 46 additions & 2 deletions src/fmeval/eval_algorithms/qa_accuracy.py
@@ -7,7 +7,7 @@

from dataclasses import dataclass

from nltk.metrics.scores import f_measure
from nltk.metrics.scores import f_measure, precision, recall

import fmeval.util as util
from fmeval.constants import (
@@ -49,6 +49,8 @@
F1_SCORE = "f1_score"
EXACT_MATCH_SCORE = "exact_match_score"
QUASI_EXACT_MATCH_SCORE = "quasi_exact_match_score"
PRECISION = "precision"
RECALL = "recall"

PROMPT_COLUMN_NAME = "prompt"
logger = logging.getLogger(__name__)
@@ -120,6 +122,46 @@ def _f1_score(
return float(ret)


def _precision(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
"""
Given the model output and the target output, compute the precision.
Precision is the fraction of words in the prediction that are also found in the target output.
Before computing precision, we normalize the text following the QuAC protocol.

:param model_output: The output of a model that we want to evaluate.
:param target_output: The reference or the "ground truth" output.
    :param normalize_text: Normalize the text before computing precision.
:returns: Precision.
"""
if normalize_text: # pragma: no branch
model_output, target_output = (_normalize_text_quac_protocol(text) for text in (model_output, target_output))
ret = precision(reference=set(target_output.split(" ")), test=set(model_output.split(" ")))
[Review comment] Contributor: why set and not list? we want to discard repetitions of words?
[Review reply] Contributor: Valid question, but set seems standard. At least that's what the NLTK metric assumes that we use here.

if ret is None: # pragma: no cover
return 0.0
else:
return float(ret)
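
On the set-vs-list question in the review thread above: nltk.metrics.scores.precision takes a reference set and a test set and returns the fraction of test items that appear in the reference, so repeated words in the model output are counted only once. A minimal sketch (reviewer illustration, not part of the diff):

from nltk.metrics.scores import precision

target_tokens = set("the cat sat on the mat".split(" "))  # {'the', 'cat', 'sat', 'on', 'mat'}
output_tokens = set("the the the cat".split(" "))          # {'the', 'cat'} -- repetitions collapse
print(precision(reference=target_tokens, test=output_tokens))  # 1.0: every distinct predicted word is in the target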


def _recall(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
"""
Given the model output and the target output, compute the recall.
    Recall is the fraction of words in the target output that are also found in the model output.
Before computing recall, we normalize the text following the QuAC protocol.

:param model_output: The output of a model that we want to evaluate.
:param target_output: The reference or the "ground truth" output.
    :param normalize_text: Normalize the text before computing recall.
:returns: Recall.
"""
if normalize_text: # pragma: no branch
model_output, target_output = (_normalize_text_quac_protocol(text) for text in (model_output, target_output))
ret = recall(reference=set(target_output.split(" ")), test=set(model_output.split(" ")))
[Review comment] Contributor: same comment (set vs. list) as on _precision above.

if ret is None: # pragma: no cover
return 0.0
else:
return float(ret)
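
A worked example (reviewer sketch) tying these helpers to the new "model output longer than target" test case added later in this PR, assuming the private helpers are importable from fmeval.eval_algorithms.qa_accuracy: after QuAC normalization the prediction "Yes. That is true." yields the token set {"yes", "that", "is", "true"} and the target "yes" yields {"yes"}; one of four predicted tokens matches the target (precision 0.25), while the single target token is fully covered (recall 1.0).

from fmeval.eval_algorithms.qa_accuracy import _precision, _recall  # private helpers added in this diff

print(_precision("Yes. That is true.", "yes", normalize_text=True))  # 0.25
print(_recall("Yes. That is true.", "yes", normalize_text=True))     # 1.0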


def _exact_match_score(model_output: str, target_output: str) -> float:
"""
Inspired by HELM: https://github.com/stanford-crfm/helm/blob/62f817eb695a31e8389e3f7be30609d3f0871837/src/helm/benchmark/metrics/basic_metrics.py#L137
@@ -150,6 +192,8 @@ def _quasi_exact_match_score(model_output: str, target_output: str) -> float:
F1_SCORE: partial(_f1_score, normalize_text=True, strip_text=True),
EXACT_MATCH_SCORE: _exact_match_score,
QUASI_EXACT_MATCH_SCORE: _quasi_exact_match_score,
PRECISION: partial(_precision, normalize_text=True),
RECALL: partial(_recall, normalize_text=True),
}
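
A note on the mapping entries above (pattern sketch, not new behavior): functools.partial pre-binds the keyword-only normalize_text flag so every entry can be invoked with just (model_output, target_output), matching the existing F1 and exact-match entries. Illustration with a stand-in scorer:

from functools import partial

def toy_scorer(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
    # stand-in for _precision / _recall; the real logic is defined earlier in this file
    normalize = (lambda s: s.lower()) if normalize_text else (lambda s: s)
    return 1.0 if normalize(model_output) == normalize(target_output) else 0.0

score_fn = partial(toy_scorer, normalize_text=True)  # same pattern as the PRECISION / RECALL entries
print(score_fn("London", "london"))  # 1.0 -- callers never pass normalize_text themselves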


@@ -236,7 +280,7 @@ def _generate_eval_scores(row: Dict[str, Any]) -> Dict[str, Any]: # pragma: no
dataset = dataset.map(_generate_eval_scores).materialize()

dataset_scores, category_scores = aggregate_evaluation_scores(
dataset, [F1_SCORE, EXACT_MATCH_SCORE, QUASI_EXACT_MATCH_SCORE], agg_method=MEAN
dataset, [F1_SCORE, EXACT_MATCH_SCORE, QUASI_EXACT_MATCH_SCORE, PRECISION, RECALL], agg_method=MEAN
)

eval_outputs.append(
12 changes: 12 additions & 0 deletions src/fmeval/eval_algorithms/qa_accuracy_semantic_robustness.py
@@ -58,6 +58,8 @@
F1_SCORE,
EXACT_MATCH_SCORE,
QUASI_EXACT_MATCH_SCORE,
PRECISION,
RECALL,
QAAccuracy,
QAAccuracyConfig,
)
@@ -72,6 +74,8 @@
DELTA_F1_SCORE = PREFIX_FOR_DELTA_SCORES + F1_SCORE
DELTA_EXACT_MATCH_SCORE = PREFIX_FOR_DELTA_SCORES + EXACT_MATCH_SCORE
DELTA_QUASI_EXACT_MATCH_SCORE = PREFIX_FOR_DELTA_SCORES + QUASI_EXACT_MATCH_SCORE
DELTA_PRECISION = PREFIX_FOR_DELTA_SCORES + PRECISION
DELTA_RECALL = PREFIX_FOR_DELTA_SCORES + RECALL

PROMPT_COLUMN_NAME = "prompt"
logger = logging.getLogger(__name__)
@@ -231,9 +235,13 @@ def _generate_score_columns(row: Dict[str, Any]) -> Dict[str, Any]: # pragma: n
F1_SCORE,
EXACT_MATCH_SCORE,
QUASI_EXACT_MATCH_SCORE,
PRECISION,
RECALL,
DELTA_F1_SCORE,
DELTA_EXACT_MATCH_SCORE,
DELTA_QUASI_EXACT_MATCH_SCORE,
DELTA_PRECISION,
DELTA_RECALL,
],
agg_method=MEAN,
)
@@ -259,9 +267,13 @@ def _generate_score_columns(row: Dict[str, Any]) -> Dict[str, Any]: # pragma: n
F1_SCORE,
EXACT_MATCH_SCORE,
QUASI_EXACT_MATCH_SCORE,
PRECISION,
RECALL,
DELTA_F1_SCORE,
DELTA_EXACT_MATCH_SCORE,
DELTA_QUASI_EXACT_MATCH_SCORE,
DELTA_PRECISION,
DELTA_RECALL,
],
path=generate_output_dataset_path(
path_to_parent_dir=self._eval_results_path,
6 changes: 6 additions & 0 deletions test/integration/test_qa_accuracy.py
@@ -6,6 +6,8 @@
F1_SCORE,
EXACT_MATCH_SCORE,
QUASI_EXACT_MATCH_SCORE,
PRECISION,
RECALL,
)

from fmeval.data_loaders.data_config import DataConfig
@@ -60,3 +62,7 @@ def test_evaluate(self, integration_tests_dir):
assert eval_score.value == approx(0.060606, abs=ABS_TOL)
elif eval_score.name == QUASI_EXACT_MATCH_SCORE:
assert eval_score.value == approx(0.303030, abs=ABS_TOL)
elif eval_score.name == PRECISION:
assert eval_score.value == approx(0.357660, abs=ABS_TOL)
elif eval_score.name == RECALL:
assert eval_score.value == approx(0.381313, abs=ABS_TOL)
28 changes: 27 additions & 1 deletion test/integration/test_qa_accuracy_semantic_robustness.py
@@ -12,8 +12,10 @@
DELTA_F1_SCORE,
DELTA_EXACT_MATCH_SCORE,
DELTA_QUASI_EXACT_MATCH_SCORE,
DELTA_PRECISION,
DELTA_RECALL,
)
from fmeval.eval_algorithms.qa_accuracy import F1_SCORE, QUASI_EXACT_MATCH_SCORE, EXACT_MATCH_SCORE
from fmeval.eval_algorithms.qa_accuracy import F1_SCORE, QUASI_EXACT_MATCH_SCORE, EXACT_MATCH_SCORE, PRECISION, RECALL
from fmeval.data_loaders.data_config import DataConfig
from fmeval.constants import MIME_TYPE_JSONLINES
from test.integration.models.model_runners import sm_model_runner
@@ -44,9 +46,13 @@ class TestCaseEvaluateSample(NamedTuple):
F1_SCORE: 1.0,
EXACT_MATCH_SCORE: 1.0,
QUASI_EXACT_MATCH_SCORE: 1.0,
PRECISION: 1.0,
RECALL: 1.0,
DELTA_F1_SCORE: 0.8,
DELTA_EXACT_MATCH_SCORE: 0.8,
DELTA_QUASI_EXACT_MATCH_SCORE: 0.8,
DELTA_PRECISION: 0.8,
DELTA_RECALL: 0.8,
},
),
TestCaseEvaluateSample(
@@ -59,9 +65,13 @@
F1_SCORE: 1.0,
EXACT_MATCH_SCORE: 1.0,
QUASI_EXACT_MATCH_SCORE: 1.0,
PRECISION: 1.0,
RECALL: 1.0,
DELTA_F1_SCORE: 1.0,
DELTA_EXACT_MATCH_SCORE: 1.0,
DELTA_QUASI_EXACT_MATCH_SCORE: 1.0,
DELTA_PRECISION: 1.0,
DELTA_RECALL: 1.0,
},
),
TestCaseEvaluateSample(
@@ -75,9 +85,13 @@
F1_SCORE: 1.0,
EXACT_MATCH_SCORE: 1.0,
QUASI_EXACT_MATCH_SCORE: 1.0,
PRECISION: 1.0,
RECALL: 1.0,
DELTA_F1_SCORE: 0.8,
DELTA_EXACT_MATCH_SCORE: 0.8,
DELTA_QUASI_EXACT_MATCH_SCORE: 0.8,
DELTA_PRECISION: 0.8,
DELTA_RECALL: 0.8,
},
),
],
@@ -109,9 +123,13 @@ class TestCaseEvaluate(NamedTuple):
F1_SCORE: 0.3606,
EXACT_MATCH_SCORE: 0.0606,
QUASI_EXACT_MATCH_SCORE: 0.3030,
PRECISION: 0.3577,
RECALL: 0.3813,
DELTA_F1_SCORE: 0.2277,
DELTA_EXACT_MATCH_SCORE: 0.0586,
DELTA_QUASI_EXACT_MATCH_SCORE: 0.2101,
DELTA_PRECISION: 0.2284,
DELTA_RECALL: 0.2316,
},
),
TestCaseEvaluate(
@@ -124,9 +142,13 @@
F1_SCORE: 0.3606,
EXACT_MATCH_SCORE: 0.0606,
QUASI_EXACT_MATCH_SCORE: 0.3030,
PRECISION: 0.3577,
RECALL: 0.3813,
DELTA_F1_SCORE: 0.1876,
DELTA_EXACT_MATCH_SCORE: 0.0687,
DELTA_QUASI_EXACT_MATCH_SCORE: 0.1798,
DELTA_PRECISION: 0.194,
DELTA_RECALL: 0.1804,
},
),
TestCaseEvaluate(
@@ -140,9 +162,13 @@
F1_SCORE: 0.3606,
EXACT_MATCH_SCORE: 0.0606,
QUASI_EXACT_MATCH_SCORE: 0.3030,
PRECISION: 0.3577,
RECALL: 0.3813,
DELTA_F1_SCORE: 0.1709,
DELTA_EXACT_MATCH_SCORE: 0.0525,
DELTA_QUASI_EXACT_MATCH_SCORE: 0.1535,
DELTA_PRECISION: 0.1761,
DELTA_RECALL: 0.1705,
},
),
],
62 changes: 56 additions & 6 deletions test/unit/eval_algorithms/test_qa_accuracy.py
@@ -32,6 +32,8 @@
F1_SCORE,
EXACT_MATCH_SCORE,
QUASI_EXACT_MATCH_SCORE,
PRECISION,
RECALL,
_f1_score,
_exact_match_score,
)
@@ -74,6 +76,13 @@
MODEL_OUTPUT_COLUMN_NAME: "2022",
CATEGORY_COLUMN_NAME: "sports",
},
    # Model output is longer than the target output.
{
MODEL_INPUT_COLUMN_NAME: "Did RMS Titanic sink in 1912?",
TARGET_OUTPUT_COLUMN_NAME: "yes",
MODEL_OUTPUT_COLUMN_NAME: "Yes. That is true.",
CATEGORY_COLUMN_NAME: "history",
},
]
)

@@ -94,6 +103,8 @@
EvalScore(name=F1_SCORE, value=2 / 3),
EvalScore(name=EXACT_MATCH_SCORE, value=1 / 3),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=2 / 3),
EvalScore(name=PRECISION, value=2 / 3),
EvalScore(name=RECALL, value=2 / 3),
],
),
CategoryScore(
@@ -102,6 +113,8 @@
EvalScore(name=F1_SCORE, value=2 / 3),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1 / 2),
],
),
CategoryScore(
@@ -110,14 +123,28 @@
EvalScore(name=F1_SCORE, value=1.0),
EvalScore(name=EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1.0),
],
),
CategoryScore(
name="history",
scores=[
EvalScore(name=F1_SCORE, value=2 / 5),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=PRECISION, value=1 / 4),
EvalScore(name=RECALL, value=1.0),
],
),
]

DATASET_SCORES = [
EvalScore(name=F1_SCORE, value=11 / 15),
EvalScore(name=EXACT_MATCH_SCORE, value=2 / 5),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=3 / 5),
EvalScore(name=F1_SCORE, value=61 / 90),
EvalScore(name=EXACT_MATCH_SCORE, value=2 / 6),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=3 / 6),
EvalScore(name=PRECISION, value=17 / 24),
EvalScore(name=RECALL, value=3 / 4),
]
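
Reviewer arithmetic check (not part of the test file): the updated dataset-level PRECISION and RECALL follow from the category means above, assuming the first category covers 3 of the 6 samples (consistent with its exact-match mean of 1/3) and the remaining three categories cover 1 sample each.

def weighted_mean(counts, means):
    # sample-weighted mean across categories, i.e. the plain mean over all rows
    return sum(n * m for n, m in zip(counts, means)) / sum(counts)

counts = [3, 1, 1, 1]                     # assumed samples per category, in CATEGORY_SCORES order
cat_precision = [2 / 3, 1.0, 1.0, 1 / 4]  # per-category precision means from CATEGORY_SCORES
cat_recall = [2 / 3, 1 / 2, 1.0, 1.0]     # per-category recall means from CATEGORY_SCORES
assert abs(weighted_mean(counts, cat_precision) - 17 / 24) < 1e-9
assert abs(weighted_mean(counts, cat_recall) - 3 / 4) < 1e-9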

EVAL_RESULTS_PATH = DEFAULT_EVAL_RESULTS_PATH
@@ -161,9 +188,11 @@ class TestCaseQAAccuracyEvaluateSampleInvalid(NamedTuple):
model_output="London",
target_output="London",
expected_response=[
EvalScore(name=F1_SCORE, value=1),
EvalScore(name=EXACT_MATCH_SCORE, value=1),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=1),
EvalScore(name=F1_SCORE, value=1.0),
EvalScore(name=EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1.0),
],
),
# Partial match
@@ -175,6 +204,8 @@ class TestCaseQAAccuracyEvaluateSampleInvalid(NamedTuple):
EvalScore(name=F1_SCORE, value=2 / 3),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1 / 2),
],
),
# Wrong answer. All scores should be zero.
Expand All @@ -186,6 +217,8 @@ class TestCaseQAAccuracyEvaluateSampleInvalid(NamedTuple):
EvalScore(name=F1_SCORE, value=0.0),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=PRECISION, value=0.0),
EvalScore(name=RECALL, value=0.0),
],
),
# Correct answer but with punctuation added.
Expand All @@ -197,6 +230,8 @@ class TestCaseQAAccuracyEvaluateSampleInvalid(NamedTuple):
EvalScore(name=F1_SCORE, value=1.0),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1.0),
],
),
# Many correct answers.
Expand All @@ -208,6 +243,21 @@ class TestCaseQAAccuracyEvaluateSampleInvalid(NamedTuple):
EvalScore(name=F1_SCORE, value=1.0),
EvalScore(name=EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=1.0),
EvalScore(name=PRECISION, value=1.0),
EvalScore(name=RECALL, value=1.0),
],
),
        # Model output is longer than the target output.
TestCaseQAAccuracyEvaluateSample(
model_input="Did RMS Titanic sink in 1912?",
model_output="Yes. That is true.",
target_output="yes",
expected_response=[
EvalScore(name=F1_SCORE, value=0.4),
EvalScore(name=EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=QUASI_EXACT_MATCH_SCORE, value=0.0),
EvalScore(name=PRECISION, value=0.25),
EvalScore(name=RECALL, value=1.0),
],
),
],