Add Exact match and Regex Match Evaluators (#11132)

This commit is contained in:
William FH
2023-09-27 14:18:07 -07:00
committed by GitHub
parent e355606b11
commit 33da8bd711
25 changed files with 3641 additions and 3211 deletions

View File

@@ -67,8 +67,10 @@ from langchain.evaluation.embedding_distance import (
EmbeddingDistanceEvalChain,
PairwiseEmbeddingDistanceEvalChain,
)
from langchain.evaluation.exact_match.base import ExactMatchStringEvaluator
from langchain.evaluation.loading import load_dataset, load_evaluator, load_evaluators
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import (
AgentTrajectoryEvaluator,
EvaluatorType,
@@ -83,6 +85,8 @@ from langchain.evaluation.string_distance import (
__all__ = [
"EvaluatorType",
"ExactMatchStringEvaluator",
"RegexMatchStringEvaluator",
"PairwiseStringEvalChain",
"LabeledPairwiseStringEvalChain",
"QAEvalChain",

View File

@@ -0,0 +1,97 @@
import string
from typing import Any, List
from langchain.evaluation.schema import StringEvaluator
class ExactMatchStringEvaluator(StringEvaluator):
"""Compute an exact match between the prediction and the reference.
Examples
----------
>>> evaluator = ExactMatchChain()
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="Mindy is the CTO",
) # This will return {'score': 1.0}
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="Mindy is the CEO",
) # This will return {'score': 0.0}
"""
def __init__(
self,
*,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
**kwargs: Any,
):
super().__init__()
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.ignore_numbers = ignore_numbers
@property
def requires_input(self) -> bool:
"""
This evaluator does not require input.
"""
return False
@property
def requires_reference(self) -> bool:
"""
This evaluator requires a reference.
"""
return True
@property
def input_keys(self) -> List[str]:
"""
Get the input keys.
Returns:
List[str]: The input keys.
"""
return ["reference", "prediction"]
@property
def evaluation_name(self) -> str:
"""
Get the evaluation name.
Returns:
str: The evaluation name.
"""
return "exact_match"
def _evaluate_strings( # type: ignore[arg-type,override]
self,
*,
prediction: str,
reference: str,
**kwargs: Any,
) -> dict:
"""
Evaluate the exact match between the prediction and the reference.
Args:
prediction (str): The prediction string.
reference (Optional[str], optional): The reference string.
Returns:
dict: The evaluation results containing the score.
"""
if self.ignore_case:
prediction = prediction.lower()
reference = reference.lower()
if self.ignore_punctuation:
prediction = prediction.translate(str.maketrans("", "", string.punctuation))
reference = reference.translate(str.maketrans("", "", string.punctuation))
if self.ignore_numbers:
prediction = prediction.translate(str.maketrans("", "", string.digits))
reference = reference.translate(str.maketrans("", "", string.digits))
return {"score": int(prediction == reference)}

View File

@@ -14,11 +14,13 @@ from langchain.evaluation.embedding_distance.base import (
EmbeddingDistanceEvalChain,
PairwiseEmbeddingDistanceEvalChain,
)
from langchain.evaluation.exact_match.base import ExactMatchStringEvaluator
from langchain.evaluation.parsing.base import (
JsonEqualityEvaluator,
JsonValidityEvaluator,
)
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import EvaluatorType, LLMEvalChain, StringEvaluator
from langchain.evaluation.string_distance.base import (
PairwiseStringDistanceEvalChain,
@@ -78,6 +80,8 @@ _EVALUATOR_MAP: Dict[
EvaluatorType.PAIRWISE_EMBEDDING_DISTANCE: PairwiseEmbeddingDistanceEvalChain,
EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
EvaluatorType.REGEX_MATCH: RegexMatchStringEvaluator,
EvaluatorType.EXACT_MATCH: ExactMatchStringEvaluator,
}
@@ -111,7 +115,7 @@ def load_evaluator(
if evaluator not in _EVALUATOR_MAP:
raise ValueError(
f"Unknown evaluator type: {evaluator}"
f"Valid types are: {list(_EVALUATOR_MAP.keys())}"
f"\nValid types are: {list(_EVALUATOR_MAP.keys())}"
)
evaluator_cls = _EVALUATOR_MAP[evaluator]
if issubclass(evaluator_cls, LLMEvalChain):

View File

@@ -0,0 +1,86 @@
import re
from typing import Any, List
from langchain.evaluation.schema import StringEvaluator
class RegexMatchStringEvaluator(StringEvaluator):
"""Compute a regex match between the prediction and the reference.
Examples
----------
>>> evaluator = RegexMatchStringEvaluator(flags=re.IGNORECASE)
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="^mindy.*cto$",
) # This will return {'score': 1.0} due to the IGNORECASE flag
>>> evaluator = RegexMatchStringEvaluator()
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="^Mike.*CEO$",
) # This will return {'score': 0.0}
>>> evaluator.evaluate_strings(
prediction="Mindy is the CTO",
reference="^Mike.*CEO$|^Mindy.*CTO$",
) # This will return {'score': 1.0} as the prediction matches the second pattern in the union
""" # noqa: E501
def __init__(self, *, flags: int = 0, **kwargs: Any): # Default is no flags
super().__init__()
self.flags = flags
@property
def requires_input(self) -> bool:
"""
This evaluator does not require input.
"""
return False
@property
def requires_reference(self) -> bool:
"""
This evaluator requires a reference.
"""
return True
@property
def input_keys(self) -> List[str]:
"""
Get the input keys.
Returns:
List[str]: The input keys.
"""
return ["reference", "prediction"]
@property
def evaluation_name(self) -> str:
"""
Get the evaluation name.
Returns:
str: The evaluation name.
"""
return "regex_match"
def _evaluate_strings( # type: ignore[arg-type,override]
self,
*,
prediction: str,
reference: str,
**kwargs: Any,
) -> dict:
"""
Evaluate the regex match between the prediction and the reference.
Args:
prediction (str): The prediction string.
reference (Optional[str], optional): The reference regex pattern.
Returns:
dict: The evaluation results containing the score.
"""
match = re.match(reference, prediction, flags=self.flags)
return {"score": int(bool(match))}

View File

@@ -44,6 +44,10 @@ class EvaluatorType(str, Enum):
custom set of criteria, with a reference label."""
STRING_DISTANCE = "string_distance"
"""Compare predictions to a reference answer using string edit distances."""
EXACT_MATCH = "exact_match"
"""Compare predictions to a reference answer using exact matching."""
REGEX_MATCH = "regex_match"
"""Compare predictions to a reference answer using regular expressions."""
PAIRWISE_STRING_DISTANCE = "pairwise_string_distance"
"""Compare predictions based on string edit distances."""
EMBEDDING_DISTANCE = "embedding_distance"

View File

@@ -261,4 +261,34 @@ class RunEvalConfig(BaseModel):
evaluator_type: EvaluatorType = EvaluatorType.JSON_EQUALITY
class ExactMatch(EvalConfig):
"""Configuration for an exact match string evaluator.
Parameters
----------
ignore_case : bool
Whether to ignore case when comparing strings.
ignore_punctuation : bool
Whether to ignore punctuation when comparing strings.
ignore_numbers : bool
Whether to ignore numbers when comparing strings.
"""
evaluator_type: EvaluatorType = EvaluatorType.STRING_DISTANCE
ignore_case: bool = False
ignore_punctuation: bool = False
ignore_numbers: bool = False
class RegexMatch(EvalConfig):
"""Configuration for a regex match string evaluator.
Parameters
----------
flags : int
The flags to pass to the regex. Example: re.IGNORECASE.
"""
evaluator_type: EvaluatorType = EvaluatorType.REGEX_MATCH
flags: int = 0
# TODO: Trajectory

View File

@@ -0,0 +1,49 @@
import pytest
from langchain.evaluation import ExactMatchStringEvaluator
@pytest.fixture
def exact_match_string_evaluator() -> ExactMatchStringEvaluator:
"""Create an ExactMatchStringEvaluator with default configuration."""
return ExactMatchStringEvaluator()
@pytest.fixture
def exact_match_string_evaluator_ignore_case() -> ExactMatchStringEvaluator:
"""Create an ExactMatchStringEvaluator with ignore_case set to True."""
return ExactMatchStringEvaluator(ignore_case=True)
def test_default_exact_matching(
exact_match_string_evaluator: ExactMatchStringEvaluator,
) -> None:
prediction = "Mindy is the CTO"
reference = "Mindy is the CTO"
result = exact_match_string_evaluator.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 1.0
reference = "Mindy is the CEO"
result = exact_match_string_evaluator.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 0.0
def test_exact_matching_with_ignore_case(
exact_match_string_evaluator_ignore_case: ExactMatchStringEvaluator,
) -> None:
prediction = "Mindy is the CTO"
reference = "mindy is the cto"
result = exact_match_string_evaluator_ignore_case.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 1.0
reference = "mindy is the CEO"
result = exact_match_string_evaluator_ignore_case.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 0.0

View File

@@ -0,0 +1,45 @@
import re
import pytest
from langchain.evaluation import RegexMatchStringEvaluator
@pytest.fixture
def regex_match_string_evaluator() -> RegexMatchStringEvaluator:
"""Create a RegexMatchStringEvaluator with default configuration."""
return RegexMatchStringEvaluator()
@pytest.fixture
def regex_match_string_evaluator_ignore_case() -> RegexMatchStringEvaluator:
"""Create a RegexMatchStringEvaluator with IGNORECASE flag."""
return RegexMatchStringEvaluator(flags=re.IGNORECASE)
def test_default_regex_matching(
regex_match_string_evaluator: RegexMatchStringEvaluator,
) -> None:
prediction = "Mindy is the CTO"
reference = "^Mindy.*CTO$"
result = regex_match_string_evaluator.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 1.0
reference = "^Mike.*CEO$"
result = regex_match_string_evaluator.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 0.0
def test_regex_matching_with_ignore_case(
regex_match_string_evaluator_ignore_case: RegexMatchStringEvaluator,
) -> None:
prediction = "Mindy is the CTO"
reference = "^mindy.*cto$"
result = regex_match_string_evaluator_ignore_case.evaluate_strings(
prediction=prediction, reference=reference
)
assert result["score"] == 1.0

View File

@@ -41,6 +41,7 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
EvaluatorType.LABELED_PAIRWISE_STRING,
],
[EvaluatorType.JSON_EQUALITY],
[EvaluatorType.EXACT_MATCH, EvaluatorType.REGEX_MATCH],
],
)
def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None: