feat: added the precision and recall metrics for QA accuracy #157
Changes from all commits
@@ -7,7 +7,7 @@
 from dataclasses import dataclass

-from nltk.metrics.scores import f_measure
+from nltk.metrics.scores import f_measure, precision, recall

 import fmeval.util as util
 from fmeval.constants import (

@@ -49,6 +49,8 @@
 F1_SCORE = "f1_score"
 EXACT_MATCH_SCORE = "exact_match_score"
 QUASI_EXACT_MATCH_SCORE = "quasi_exact_match_score"
+PRECISION = "precision"
+RECALL = "recall"

 PROMPT_COLUMN_NAME = "prompt"
 logger = logging.getLogger(__name__)

@@ -120,6 +122,46 @@ def _f1_score(
         return float(ret)


+def _precision(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
+    """
+    Given the model output and the target output, compute the precision.
+    Precision is the fraction of words in the prediction that are also found in the target output.
+    Before computing precision, we normalize the text following the QuAC protocol.
+
+    :param model_output: The output of a model that we want to evaluate.
+    :param target_output: The reference or the "ground truth" output.
+    :param normalize_text: Normalize the text before computing precision.
+    :returns: Precision.
+    """
+    if normalize_text:  # pragma: no branch
+        model_output, target_output = (_normalize_text_quac_protocol(text) for text in (model_output, target_output))
+    ret = precision(reference=set(target_output.split(" ")), test=set(model_output.split(" ")))
Review comment on the `precision(...)` call: why set and not list? we want to discard repetitions of words?

Reply: Valid question, but set seems standard. At least that's what the NLTK metric assumes that we use here.
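To make the set-vs-list trade-off concrete, here is a small illustration (not part of the PR) of what the set-based NLTK call computes when the prediction repeats words:

```python
from nltk.metrics.scores import precision

target_output = "the cat sat on the mat"
model_output = "the cat the cat"  # every predicted word is repeated

# Set-based, as in this PR: repetitions are discarded, every unique predicted
# word occurs in the target, so precision is 1.0.
print(precision(reference=set(target_output.split(" ")), test=set(model_output.split(" "))))  # 1.0

# A bag-of-words (multiset) variant would count tokens instead: "the" matches twice,
# but only one of the two predicted "cat" tokens matches, giving 3/4 = 0.75.
```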
+    if ret is None:  # pragma: no cover
+        return 0.0
+    else:
+        return float(ret)
+
+
+def _recall(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
+    """
+    Given the model output and the target output, compute the recall.
+    Recall is the fraction of words in the target output that are also found in the model output.
+    Before computing recall, we normalize the text following the QuAC protocol.
+
+    :param model_output: The output of a model that we want to evaluate.
+    :param target_output: The reference or the "ground truth" output.
+    :param normalize_text: Normalize the text before computing recall.
+    :returns: Recall.
+    """
+    if normalize_text:  # pragma: no branch
+        model_output, target_output = (_normalize_text_quac_protocol(text) for text in (model_output, target_output))
+    ret = recall(reference=set(target_output.split(" ")), test=set(model_output.split(" ")))
Review comment on the `recall(...)` call: same comment as above (set vs. list).
+    if ret is None:  # pragma: no cover
+        return 0.0
+    else:
+        return float(ret)
+
+
 def _exact_match_score(model_output: str, target_output: str) -> float:
     """
     Inspired by HELM: https://github.com/stanford-crfm/helm/blob/62f817eb695a31e8389e3f7be30609d3f0871837/src/helm/benchmark/metrics/basic_metrics.py#L137

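As a quick sanity check of how the two new helpers behave together, here is a hedged usage sketch. The module path and the use of these private helpers are assumptions based on the surrounding diff (which imports `fmeval.util` and `fmeval.constants`), not a confirmed public API:

```python
# Hypothetical usage sketch; import path assumed, helpers are module-private.
from fmeval.eval_algorithms.qa_accuracy import _precision, _recall

target = "The capital of France is Paris."
model = "Paris"

print(_precision(model, target, normalize_text=True))  # 1.0: the only predicted word occurs in the target
print(_recall(model, target, normalize_text=True))     # well below 1.0: most target words are missing
```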
@@ -150,6 +192,8 @@ def _quasi_exact_match_score(model_output: str, target_output: str) -> float:
     F1_SCORE: partial(_f1_score, normalize_text=True, strip_text=True),
     EXACT_MATCH_SCORE: _exact_match_score,
     QUASI_EXACT_MATCH_SCORE: _quasi_exact_match_score,
+    PRECISION: partial(_precision, normalize_text=True),
+    RECALL: partial(_recall, normalize_text=True),
 }

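The mapping binds `normalize_text=True` up front via `functools.partial`, so the evaluation loop can call every registered metric with just the two strings. A minimal sketch of that dispatch pattern, where the dict name and the crude normalization are stand-ins rather than the PR's actual code:

```python
from functools import partial

def _precision(model_output: str, target_output: str, *, normalize_text: bool = False) -> float:
    # Crude stand-in for the real helper: lowercasing as a fake "normalization".
    if normalize_text:
        model_output, target_output = model_output.lower(), target_output.lower()
    predicted, reference = set(model_output.split()), set(target_output.split())
    return len(predicted & reference) / len(predicted) if predicted else 0.0

SCORES_TO_FUNCS = {"precision": partial(_precision, normalize_text=True)}  # hypothetical dict name

# Callers never pass normalize_text; partial() has already bound it.
print(SCORES_TO_FUNCS["precision"]("Paris", "the capital of france is paris"))  # 1.0
```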
@@ -236,7 +280,7 @@ def _generate_eval_scores(row: Dict[str, Any]) -> Dict[str, Any]:  # pragma: no
     dataset = dataset.map(_generate_eval_scores).materialize()

     dataset_scores, category_scores = aggregate_evaluation_scores(
-        dataset, [F1_SCORE, EXACT_MATCH_SCORE, QUASI_EXACT_MATCH_SCORE], agg_method=MEAN
+        dataset, [F1_SCORE, EXACT_MATCH_SCORE, QUASI_EXACT_MATCH_SCORE, PRECISION, RECALL], agg_method=MEAN
     )

     eval_outputs.append(
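With PRECISION and RECALL added to the aggregation list, the two scores should surface in both per-sample and dataset-level results. A hedged end-to-end sketch, assuming the usual `QAAccuracy.evaluate_sample` entry point (not taken from this diff):

```python
from fmeval.eval_algorithms.qa_accuracy import QAAccuracy

eval_algo = QAAccuracy()
scores = eval_algo.evaluate_sample(
    target_output="London is the capital of the United Kingdom",
    model_output="London",
)
# Expected to now include precision and recall alongside f1_score,
# exact_match_score, and quasi_exact_match_score.
for score in scores:
    print(score.name, score.value)
```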