Add Exact match and Regex Match Evaluators (#11132)
@@ -67,8 +67,10 @@ from langchain.evaluation.embedding_distance import (
     EmbeddingDistanceEvalChain,
     PairwiseEmbeddingDistanceEvalChain,
 )
+from langchain.evaluation.exact_match.base import ExactMatchStringEvaluator
 from langchain.evaluation.loading import load_dataset, load_evaluator, load_evaluators
 from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
+from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
 from langchain.evaluation.schema import (
     AgentTrajectoryEvaluator,
     EvaluatorType,
@@ -83,6 +85,8 @@ from langchain.evaluation.string_distance import (

 __all__ = [
     "EvaluatorType",
+    "ExactMatchStringEvaluator",
+    "RegexMatchStringEvaluator",
     "PairwiseStringEvalChain",
     "LabeledPairwiseStringEvalChain",
     "QAEvalChain",
libs/langchain/langchain/evaluation/exact_match/base.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import string
from typing import Any, List

from langchain.evaluation.schema import StringEvaluator


class ExactMatchStringEvaluator(StringEvaluator):
    """Compute an exact match between the prediction and the reference.

    Examples
    ----------
    >>> evaluator = ExactMatchStringEvaluator()
    >>> evaluator.evaluate_strings(
            prediction="Mindy is the CTO",
            reference="Mindy is the CTO",
        )  # This will return {'score': 1.0}

    >>> evaluator.evaluate_strings(
            prediction="Mindy is the CTO",
            reference="Mindy is the CEO",
        )  # This will return {'score': 0.0}
    """

    def __init__(
        self,
        *,
        ignore_case: bool = False,
        ignore_punctuation: bool = False,
        ignore_numbers: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.ignore_numbers = ignore_numbers

    @property
    def requires_input(self) -> bool:
        """
        This evaluator does not require input.
        """
        return False

    @property
    def requires_reference(self) -> bool:
        """
        This evaluator requires a reference.
        """
        return True

    @property
    def input_keys(self) -> List[str]:
        """
        Get the input keys.

        Returns:
            List[str]: The input keys.
        """
        return ["reference", "prediction"]

    @property
    def evaluation_name(self) -> str:
        """
        Get the evaluation name.

        Returns:
            str: The evaluation name.
        """
        return "exact_match"

    def _evaluate_strings(  # type: ignore[arg-type,override]
        self,
        *,
        prediction: str,
        reference: str,
        **kwargs: Any,
    ) -> dict:
        """
        Evaluate the exact match between the prediction and the reference.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference string.

        Returns:
            dict: The evaluation results containing the score.
        """
        if self.ignore_case:
            prediction = prediction.lower()
            reference = reference.lower()
        if self.ignore_punctuation:
            prediction = prediction.translate(str.maketrans("", "", string.punctuation))
            reference = reference.translate(str.maketrans("", "", string.punctuation))
        if self.ignore_numbers:
            prediction = prediction.translate(str.maketrans("", "", string.digits))
            reference = reference.translate(str.maketrans("", "", string.digits))
        return {"score": int(prediction == reference)}
@@ -14,11 +14,13 @@ from langchain.evaluation.embedding_distance.base import (
     EmbeddingDistanceEvalChain,
     PairwiseEmbeddingDistanceEvalChain,
 )
+from langchain.evaluation.exact_match.base import ExactMatchStringEvaluator
 from langchain.evaluation.parsing.base import (
     JsonEqualityEvaluator,
     JsonValidityEvaluator,
 )
 from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
+from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
 from langchain.evaluation.schema import EvaluatorType, LLMEvalChain, StringEvaluator
 from langchain.evaluation.string_distance.base import (
     PairwiseStringDistanceEvalChain,
@@ -78,6 +80,8 @@ _EVALUATOR_MAP: Dict[
     EvaluatorType.PAIRWISE_EMBEDDING_DISTANCE: PairwiseEmbeddingDistanceEvalChain,
     EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
     EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
+    EvaluatorType.REGEX_MATCH: RegexMatchStringEvaluator,
+    EvaluatorType.EXACT_MATCH: ExactMatchStringEvaluator,
 }


@@ -111,7 +115,7 @@ def load_evaluator(
     if evaluator not in _EVALUATOR_MAP:
         raise ValueError(
             f"Unknown evaluator type: {evaluator}"
-            f"Valid types are: {list(_EVALUATOR_MAP.keys())}"
+            f"\nValid types are: {list(_EVALUATOR_MAP.keys())}"
         )
     evaluator_cls = _EVALUATOR_MAP[evaluator]
     if issubclass(evaluator_cls, LLMEvalChain):
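Because both evaluators are now registered in _EVALUATOR_MAP, they can also be obtained via load_evaluator. A hedged sketch, assuming constructor keyword arguments are forwarded as they are for the other non-LLM evaluators:

from langchain.evaluation import EvaluatorType, load_evaluator

exact = load_evaluator(EvaluatorType.EXACT_MATCH, ignore_case=True)
regex = load_evaluator(EvaluatorType.REGEX_MATCH)

print(exact.evaluate_strings(prediction="Paris", reference="paris"))  # {'score': 1}
print(regex.evaluate_strings(prediction="Paris", reference="^Par"))   # {'score': 1}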
libs/langchain/langchain/evaluation/regex_match/base.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import re
from typing import Any, List

from langchain.evaluation.schema import StringEvaluator


class RegexMatchStringEvaluator(StringEvaluator):
    """Compute a regex match between the prediction and the reference.

    Examples
    ----------
    >>> evaluator = RegexMatchStringEvaluator(flags=re.IGNORECASE)
    >>> evaluator.evaluate_strings(
            prediction="Mindy is the CTO",
            reference="^mindy.*cto$",
        )  # This will return {'score': 1.0} due to the IGNORECASE flag

    >>> evaluator = RegexMatchStringEvaluator()
    >>> evaluator.evaluate_strings(
            prediction="Mindy is the CTO",
            reference="^Mike.*CEO$",
        )  # This will return {'score': 0.0}

    >>> evaluator.evaluate_strings(
            prediction="Mindy is the CTO",
            reference="^Mike.*CEO$|^Mindy.*CTO$",
        )  # This will return {'score': 1.0} as the prediction matches the second pattern in the union
    """  # noqa: E501

    def __init__(self, *, flags: int = 0, **kwargs: Any):  # Default is no flags
        super().__init__()
        self.flags = flags

    @property
    def requires_input(self) -> bool:
        """
        This evaluator does not require input.
        """
        return False

    @property
    def requires_reference(self) -> bool:
        """
        This evaluator requires a reference.
        """
        return True

    @property
    def input_keys(self) -> List[str]:
        """
        Get the input keys.

        Returns:
            List[str]: The input keys.
        """
        return ["reference", "prediction"]

    @property
    def evaluation_name(self) -> str:
        """
        Get the evaluation name.

        Returns:
            str: The evaluation name.
        """
        return "regex_match"

    def _evaluate_strings(  # type: ignore[arg-type,override]
        self,
        *,
        prediction: str,
        reference: str,
        **kwargs: Any,
    ) -> dict:
        """
        Evaluate the regex match between the prediction and the reference.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference regex pattern.

        Returns:
            dict: The evaluation results containing the score.
        """
        match = re.match(reference, prediction, flags=self.flags)
        return {"score": int(bool(match))}
@@ -44,6 +44,10 @@ class EvaluatorType(str, Enum):
     custom set of criteria, with a reference label."""
     STRING_DISTANCE = "string_distance"
     """Compare predictions to a reference answer using string edit distances."""
+    EXACT_MATCH = "exact_match"
+    """Compare predictions to a reference answer using exact matching."""
+    REGEX_MATCH = "regex_match"
+    """Compare predictions to a reference answer using regular expressions."""
     PAIRWISE_STRING_DISTANCE = "pairwise_string_distance"
     """Compare predictions based on string edit distances."""
     EMBEDDING_DISTANCE = "embedding_distance"
@@ -261,4 +261,34 @@ class RunEvalConfig(BaseModel):

         evaluator_type: EvaluatorType = EvaluatorType.JSON_EQUALITY

+    class ExactMatch(EvalConfig):
+        """Configuration for an exact match string evaluator.
+
+        Parameters
+        ----------
+        ignore_case : bool
+            Whether to ignore case when comparing strings.
+        ignore_punctuation : bool
+            Whether to ignore punctuation when comparing strings.
+        ignore_numbers : bool
+            Whether to ignore numbers when comparing strings.
+        """
+
+        evaluator_type: EvaluatorType = EvaluatorType.EXACT_MATCH
+        ignore_case: bool = False
+        ignore_punctuation: bool = False
+        ignore_numbers: bool = False
+
+    class RegexMatch(EvalConfig):
+        """Configuration for a regex match string evaluator.
+
+        Parameters
+        ----------
+        flags : int
+            The flags to pass to the regex. Example: re.IGNORECASE.
+        """
+
+        evaluator_type: EvaluatorType = EvaluatorType.REGEX_MATCH
+        flags: int = 0
+
     # TODO: Trajectory
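For completeness, a hedged sketch of how these new config classes would typically be wired into a RunEvalConfig; the langchain.smith import path and the evaluators field are assumed from the existing EvalConfig pattern and are not part of this diff.

import re

from langchain.smith import RunEvalConfig

# Assumed usage: EvalConfig entries are passed in RunEvalConfig.evaluators.
eval_config = RunEvalConfig(
    evaluators=[
        RunEvalConfig.ExactMatch(ignore_case=True, ignore_punctuation=True),
        RunEvalConfig.RegexMatch(flags=re.IGNORECASE),
    ]
)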
@@ -0,0 +1,49 @@
import pytest

from langchain.evaluation import ExactMatchStringEvaluator


@pytest.fixture
def exact_match_string_evaluator() -> ExactMatchStringEvaluator:
    """Create an ExactMatchStringEvaluator with default configuration."""
    return ExactMatchStringEvaluator()


@pytest.fixture
def exact_match_string_evaluator_ignore_case() -> ExactMatchStringEvaluator:
    """Create an ExactMatchStringEvaluator with ignore_case set to True."""
    return ExactMatchStringEvaluator(ignore_case=True)


def test_default_exact_matching(
    exact_match_string_evaluator: ExactMatchStringEvaluator,
) -> None:
    prediction = "Mindy is the CTO"
    reference = "Mindy is the CTO"
    result = exact_match_string_evaluator.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 1.0

    reference = "Mindy is the CEO"
    result = exact_match_string_evaluator.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 0.0


def test_exact_matching_with_ignore_case(
    exact_match_string_evaluator_ignore_case: ExactMatchStringEvaluator,
) -> None:
    prediction = "Mindy is the CTO"
    reference = "mindy is the cto"
    result = exact_match_string_evaluator_ignore_case.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 1.0

    reference = "mindy is the CEO"
    result = exact_match_string_evaluator_ignore_case.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 0.0
@@ -0,0 +1,45 @@
import re

import pytest

from langchain.evaluation import RegexMatchStringEvaluator


@pytest.fixture
def regex_match_string_evaluator() -> RegexMatchStringEvaluator:
    """Create a RegexMatchStringEvaluator with default configuration."""
    return RegexMatchStringEvaluator()


@pytest.fixture
def regex_match_string_evaluator_ignore_case() -> RegexMatchStringEvaluator:
    """Create a RegexMatchStringEvaluator with IGNORECASE flag."""
    return RegexMatchStringEvaluator(flags=re.IGNORECASE)


def test_default_regex_matching(
    regex_match_string_evaluator: RegexMatchStringEvaluator,
) -> None:
    prediction = "Mindy is the CTO"
    reference = "^Mindy.*CTO$"
    result = regex_match_string_evaluator.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 1.0

    reference = "^Mike.*CEO$"
    result = regex_match_string_evaluator.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 0.0


def test_regex_matching_with_ignore_case(
    regex_match_string_evaluator_ignore_case: RegexMatchStringEvaluator,
) -> None:
    prediction = "Mindy is the CTO"
    reference = "^mindy.*cto$"
    result = regex_match_string_evaluator_ignore_case.evaluate_strings(
        prediction=prediction, reference=reference
    )
    assert result["score"] == 1.0
@@ -41,6 +41,7 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
             EvaluatorType.LABELED_PAIRWISE_STRING,
         ],
         [EvaluatorType.JSON_EQUALITY],
+        [EvaluatorType.EXACT_MATCH, EvaluatorType.REGEX_MATCH],
     ],
 )
 def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None: