Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-21 21:56:38 +00:00)

Compare commits (1 commit): langchain-...wfh/json_o...

Commit 61302fa033
@@ -3,27 +3,14 @@ from typing import Any, Dict, List, Optional, Sequence, Type, Union
 from langchain.chains.base import Chain
 from langchain.chat_models.openai import ChatOpenAI
-from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
-from langchain.evaluation.comparison import PairwiseStringEvalChain
-from langchain.evaluation.comparison.eval_chain import LabeledPairwiseStringEvalChain
-from langchain.evaluation.criteria.eval_chain import (
-    CriteriaEvalChain,
-    LabeledCriteriaEvalChain,
-)
-from langchain.evaluation.embedding_distance.base import (
-    EmbeddingDistanceEvalChain,
-    PairwiseEmbeddingDistanceEvalChain,
-)
-from langchain.evaluation.parsing.base import (
-    JsonEqualityEvaluator,
-    JsonValidityEvaluator,
-)
-from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
+from langchain.evaluation import qa
+from langchain.evaluation.agents import trajectory_eval_chain
+from langchain.evaluation.comparison import eval_chain as comparison
+from langchain.evaluation.criteria import eval_chain as criteria
+from langchain.evaluation.embedding_distance import base as embedding
+from langchain.evaluation.parsing import base as json_parsing
 from langchain.evaluation.schema import EvaluatorType, LLMEvalChain, StringEvaluator
-from langchain.evaluation.string_distance.base import (
-    PairwiseStringDistanceEvalChain,
-    StringDistanceEvalChain,
-)
+from langchain.evaluation.string_distance import base as string_distance
 from langchain.schema.language_model import BaseLanguageModel
@@ -64,20 +51,25 @@ def load_dataset(uri: str) -> List[Dict]:
 _EVALUATOR_MAP: Dict[
     EvaluatorType, Union[Type[LLMEvalChain], Type[Chain], Type[StringEvaluator]]
 ] = {
-    EvaluatorType.QA: QAEvalChain,
-    EvaluatorType.COT_QA: CotQAEvalChain,
-    EvaluatorType.CONTEXT_QA: ContextQAEvalChain,
-    EvaluatorType.PAIRWISE_STRING: PairwiseStringEvalChain,
-    EvaluatorType.LABELED_PAIRWISE_STRING: LabeledPairwiseStringEvalChain,
-    EvaluatorType.AGENT_TRAJECTORY: TrajectoryEvalChain,
-    EvaluatorType.CRITERIA: CriteriaEvalChain,
-    EvaluatorType.LABELED_CRITERIA: LabeledCriteriaEvalChain,
-    EvaluatorType.STRING_DISTANCE: StringDistanceEvalChain,
-    EvaluatorType.PAIRWISE_STRING_DISTANCE: PairwiseStringDistanceEvalChain,
-    EvaluatorType.EMBEDDING_DISTANCE: EmbeddingDistanceEvalChain,
-    EvaluatorType.PAIRWISE_EMBEDDING_DISTANCE: PairwiseEmbeddingDistanceEvalChain,
-    EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
-    EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
+    EvaluatorType.QA: qa.QAEvalChain,
+    EvaluatorType.COT_QA: qa.CotQAEvalChain,
+    EvaluatorType.CONTEXT_QA: qa.ContextQAEvalChain,
+    EvaluatorType.PAIRWISE_STRING: comparison.PairwiseStringEvalChain,
+    EvaluatorType.LABELED_PAIRWISE_STRING: comparison.LabeledPairwiseStringEvalChain,
+    EvaluatorType.AGENT_TRAJECTORY: trajectory_eval_chain.TrajectoryEvalChain,
+    EvaluatorType.CRITERIA: criteria.CriteriaEvalChain,
+    EvaluatorType.LABELED_CRITERIA: criteria.LabeledCriteriaEvalChain,
+    EvaluatorType.STRING_DISTANCE: string_distance.StringDistanceEvalChain,
+    EvaluatorType.PAIRWISE_STRING_DISTANCE: string_distance.PairwiseStringDistanceEvalChain,  # noqa: E501
+    EvaluatorType.EMBEDDING_DISTANCE: embedding.EmbeddingDistanceEvalChain,
+    EvaluatorType.PAIRWISE_EMBEDDING_DISTANCE: embedding.PairwiseEmbeddingDistanceEvalChain,  # noqa: E501
+    EvaluatorType.JSON_VALIDITY: json_parsing.JsonValidityEvaluator,
+    EvaluatorType.JSON_EQUALITY: json_parsing.JsonEqualityEvaluator,
+    EvaluatorType.JSON_ACCURACY: json_parsing.JsonAccuracyEvaluator,
+    EvaluatorType.JSON_PRECISION: json_parsing.JsonPrecisionEvaluator,
+    EvaluatorType.JSON_RECALL: json_parsing.JsonRecallEvaluator,
+    EvaluatorType.JSON_IOU: json_parsing.JsonIoUEvaluator,
+    EvaluatorType.JSON_F1: json_parsing.JsonF1Evaluator,
 }
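For orientation (not part of the diff): assuming the existing `load_evaluator` helper in `langchain.evaluation.loading` dispatches through `_EVALUATOR_MAP`, each newly registered JSON metric can be loaded by its `EvaluatorType` member. None of them needs an LLM; they only compare a prediction string against a reference string.

```python
from langchain.evaluation.loading import load_evaluator
from langchain.evaluation.schema import EvaluatorType

# JSON_RECALL is one of the EvaluatorType members added in this change.
evaluator = load_evaluator(EvaluatorType.JSON_RECALL)
result = evaluator.evaluate_strings(prediction='{"a": 1}', reference='{"a": 1, "b": 2}')
print(result)  # expected: {'score': 0.5}
```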
@@ -1,6 +1,7 @@
 """Evaluators for parsing strings."""
+from abc import abstractmethod
 from operator import eq
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, Dict, Optional, Union, cast

 from langchain.evaluation.schema import StringEvaluator
 from langchain.output_parsers.json import parse_json_markdown
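Both the existing and the new evaluators parse predictions through `parse_json_markdown`, imported above. A quick illustration of that helper's assumed behavior (values here are only an example):

```python
from langchain.output_parsers.json import parse_json_markdown

# Plain JSON parses directly; as used in this module, the helper is also
# expected to unwrap JSON that a model emits inside a fenced Markdown block.
print(parse_json_markdown('{"a": [1, 2], "b": null}'))  # {'a': [1, 2], 'b': None}
```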
@@ -74,7 +75,122 @@ class JsonValidityEvaluator(StringEvaluator):
             return {"score": 0, "reasoning": str(e)}


-class JsonEqualityEvaluator(StringEvaluator):
+class _JsonComparisonEvaluator(StringEvaluator):
+    """Evaluated by comparing the predicted structured object to the
+    reference structured object.
+
+    It does not require an input string.
+
+    Attributes:
+        requires_input (bool): Whether this evaluator requires an
+            input string. Always False.
+        requires_reference (bool): Whether this evaluator requires
+            a reference string. Always True.
+        evaluation_name (str): The name of the evaluation metric.
+
+    Examples:
+        >>> evaluator = JsonEqualityEvaluator()
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 1}')
+        {'score': True}
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 2}')
+        {'score': False}
+
+        >>> evaluator = JsonEqualityEvaluator(operator=lambda x, y: x['a'] == y['a'])
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 1}')
+        {'score': True}
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 2}')
+        {'score': False}
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+
+    @property
+    def requires_input(self) -> bool:
+        return False
+
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_equality"
+
+    def _parse_json(
+        self, string: str
+    ) -> Union[dict, list, None, float, bool, int, str]:
+        return parse_json_markdown(string)
+
+    @abstractmethod
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        """Compare the prediction and reference objects.
+
+        Args:
+            prediction (Any): The prediction object.
+            reference (Any): The reference object.
+
+        Returns:
+            dict: A dictionary containing the evaluation score.
+        """
+
+    def _evaluate_strings(
+        self,
+        prediction: str,
+        input: Optional[str] = None,
+        reference: Optional[str] = None,
+        **kwargs: Any
+    ) -> dict:
+        """Evaluate the prediction string.
+
+        Args:
+            prediction (str): The prediction string to evaluate.
+            input (str, optional): Not used in this evaluator.
+            reference (str): The reference string to compare against.
+
+        Returns:
+            dict: A dictionary containing the evaluation score.
+        """
+        parsed = self._parse_json(prediction)
+        label = self._parse_json(cast(str, reference))
+        return self._compare_objects(parsed, label)
+
+
+class JsonRecallEvaluator(_JsonComparisonEvaluator):
+    """Evaluates the recall of JSON field extraction.
+
+    Recall is calculated as (True Positives) / (True Positives + False Negatives).
+
+    Attributes:
+        operator (Callable): A custom function to compare field values.
+
+    Examples:
+        >>> evaluator = JsonRecallEvaluator()
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 1, "b": 2}')
+        {'score': 0.5}
+    """
+
+    def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.operator = operator or eq
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_recall"
+
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        true_positives = 0
+        total_actual = len(reference.keys())
+
+        for k, v in reference.items():
+            if k in prediction and self.operator(v, prediction[k]):
+                true_positives += 1
+
+        return {"score": true_positives / total_actual if total_actual > 0 else 0}
+
+
+class JsonEqualityEvaluator(_JsonComparisonEvaluator):
     """Evaluates whether the prediction is equal to the reference after
     parsing both as JSON.

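The `_JsonComparisonEvaluator` base added above concentrates the JSON parsing and the `requires_input`/`requires_reference` plumbing, so each metric only implements `_compare_objects` on already-parsed objects. As a sketch of how a further metric could be added the same way (not part of the diff; the class below is hypothetical):

```python
from typing import Any

from langchain.evaluation.parsing.base import _JsonComparisonEvaluator


class JsonKeyOverlapEvaluator(_JsonComparisonEvaluator):
    """Hypothetical metric: the fraction of reference keys that appear in the
    prediction, ignoring values (a looser check than JsonRecallEvaluator)."""

    @property
    def evaluation_name(self) -> str:
        return "json_key_overlap"

    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
        if not reference:
            return {"score": 0}
        shared = sum(1 for k in reference if k in prediction)
        return {"score": shared / len(reference)}
```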
@@ -106,48 +222,161 @@ class JsonEqualityEvaluator(StringEvaluator):
     """

     def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
-        super().__init__()
+        super().__init__(**kwargs)
         self.operator = operator or eq

-    @property
-    def requires_input(self) -> bool:
-        return False
-
-    @property
-    def requires_reference(self) -> bool:
-        return True
-
-    @property
-    def evaluation_name(self) -> str:
-        return "json_equality"
-
-    def _parse_json(
-        self, string: str
-    ) -> Union[dict, list, None, float, bool, int, str]:
-        return parse_json_markdown(string)
-
-    def _evaluate_strings(
-        self,
-        prediction: str,
-        input: Optional[str] = None,
-        reference: Optional[str] = None,
-        **kwargs: Any
-    ) -> dict:
-        """Evaluate the prediction string.
-
-        Args:
-            prediction (str): The prediction string to evaluate.
-            input (str, optional): Not used in this evaluator.
-            reference (str): The reference string to compare against.
-
-        Returns:
-            dict: A dictionary containing the evaluation score.
-        """
-        parsed = self._parse_json(prediction)
-        label = self._parse_json(cast(str, reference))
-        if isinstance(label, list):
-            if not isinstance(parsed, list):
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        if isinstance(reference, list):
+            if not isinstance(prediction, list):
                 return {"score": 0}
-            parsed = sorted(parsed, key=lambda x: str(x))
-            label = sorted(label, key=lambda x: str(x))
-        return {"score": self.operator(parsed, label)}
+            prediction = sorted(prediction, key=lambda x: str(x))
+            reference = sorted(reference, key=lambda x: str(x))
+        return {"score": self.operator(prediction, reference)}
+
+
+class JsonAccuracyEvaluator(_JsonComparisonEvaluator):
+    """Evaluates the accuracy of JSON field extraction.
+
+    Accuracy is calculated as (True Positives + True Negatives)
+    / (Total Predicted + Total Actual).
+
+    Attributes:
+        operator (Callable): A custom function to compare field values.
+
+    Examples:
+        >>> evaluator = JsonAccuracyEvaluator()
+        >>> evaluator.evaluate_strings('{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}')
+        {'score': 1.0}
+    """
+
+    def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.operator = operator or eq
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_accuracy"
+
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        true_positives = 0
+        total = len(set(prediction).union(reference))
+
+        for k, v in reference.items():
+            if k in prediction and self.operator(v, prediction[k]):
+                true_positives += 1
+
+        return {"score": true_positives / total if total > 0 else 0}
+
+
+class JsonPrecisionEvaluator(_JsonComparisonEvaluator):
+    """Evaluates the precision of JSON field extraction.
+
+    Precision is calculated as (True Positives) / (True Positives + False Positives).
+
+    Attributes:
+        operator (Callable): A custom function to compare field values.
+
+    Examples:
+        >>> evaluator = JsonPrecisionEvaluator()
+        >>> evaluator.evaluate_strings('{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}')
+        {'score': 1.0}
+        >>> evaluator.evaluate_strings('{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}')
+        {'score': 0.5}
+    """
+
+    def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.operator = operator or eq
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_precision"
+
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        true_positives = 0
+        total_predicted = len(prediction.keys())
+
+        for k, v in prediction.items():
+            if k in reference and self.operator(v, reference[k]):
+                true_positives += 1
+
+        return {"score": true_positives / total_predicted if total_predicted > 0 else 0}
+
+
+class JsonIoUEvaluator(_JsonComparisonEvaluator):
+    """Evaluates the Intersection over Union (IoU) of JSON field extraction.
+
+    Attributes:
+        operator (Callable): A custom function to compare field values.
+
+    Examples:
+        >>> evaluator = JsonIoUEvaluator()
+        >>> evaluator.evaluate_strings('{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}')
+        {'score': 1.0}
+        >>> evaluator.evaluate_strings('{"a": 1}', reference='{"a": 1, "b": 2}')
+        {'score': 0.5}
+    """
+
+    def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.operator = operator or eq
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_iou"
+
+    def _compare_objects(self, prediction: Any, reference: Any) -> dict:
+        intersection = 0
+        union = len(set(prediction).union(reference))
+
+        for k, v in reference.items():
+            if k in prediction and self.operator(prediction[k], v):
+                intersection += 1
+
+        return {"score": intersection / union if union > 0 else 0}
+
+
+class JsonF1Evaluator(_JsonComparisonEvaluator):
+    """Evaluates the F1 score of JSON field extraction.
+
+    F1 is the harmonic mean of precision and recall.
+
+    Attributes:
+        operator (Callable): A custom function to compare field values.
+
+    Examples:
+        >>> evaluator = JsonF1Evaluator()
+        >>> evaluator.evaluate_strings('{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}')
+        {'score': 1.0}
+    """
+
+    def __init__(self, operator: Optional[Callable] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.operator = operator or eq
+
+    @property
+    def evaluation_name(self) -> str:
+        return "json_f1"
+
+    def _compare_objects(self, prediction: Any, reference: Any) -> Dict[str, Any]:
+        true_positives = 0
+        total_actual = len(reference.keys())
+        total_predicted = len(prediction.keys())
+
+        for k, v in reference.items():
+            if k in prediction and self.operator(v, prediction[k]):
+                true_positives += 1
+
+        if true_positives == 0:
+            return {"score": 0}
+
+        precision = true_positives / total_predicted
+        recall = true_positives / total_actual
+
+        f1 = 2 * (precision * recall) / (precision + recall)
+        return {"score": f1}
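Spelling out the arithmetic these `_compare_objects` implementations perform on one concrete pair (default `==` operator); the resulting scores match the unit tests further down:

```python
prediction = {"a": 1, "b": 2, "c": 3}
reference = {"a": 1, "b": 2}

# Keys whose values match under the default operator (==): "a" and "b".
tp = sum(1 for k, v in reference.items() if k in prediction and prediction[k] == v)  # 2

precision = tp / len(prediction)                       # 2/3  (JsonPrecisionEvaluator)
recall = tp / len(reference)                           # 1.0  (JsonRecallEvaluator)
accuracy = tp / len(set(prediction) | set(reference))  # 2/3  (JsonAccuracyEvaluator)
iou = tp / len(set(prediction) | set(reference))       # 2/3  (JsonIoUEvaluator shares
                                                       #      accuracy's denominator here)
f1 = 2 * precision * recall / (precision + recall)     # 0.8  (JsonF1Evaluator)
```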
@@ -54,6 +54,20 @@ class EvaluatorType(str, Enum):
     """Check if a prediction is valid JSON."""
     JSON_EQUALITY = "json_equality"
     """Check if a prediction is equal to a reference JSON."""
+    JSON_ACCURACY = "json_accuracy"
+    """Check the mean equivalence of a prediction to a reference JSON across keys."""
+    JSON_PRECISION = "json_precision"
+    """Check the mean proportion of keys in the prediction that are
+    also in the reference JSON."""
+    JSON_RECALL = "json_recall"
+    """Check the mean proportion of keys in the reference JSON that
+    are also in the prediction."""
+    JSON_IOU = "json_iou"
+    """Check the mean intersection over union of keys between the
+    prediction and reference JSON."""
+    JSON_F1 = "json_f1"
+    """Check the harmonic mean of precision and recall of keys
+    between the prediction and reference JSON."""


 class LLMEvalChain(Chain):
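Because each new member is also registered in `_EVALUATOR_MAP`, several JSON metrics can be instantiated at once. A minimal sketch, assuming the existing `load_evaluators` helper in `langchain.evaluation.loading` (no LLM is needed for these string evaluators):

```python
from langchain.evaluation.loading import load_evaluators
from langchain.evaluation.schema import EvaluatorType

evaluators = load_evaluators(
    [EvaluatorType.JSON_PRECISION, EvaluatorType.JSON_RECALL, EvaluatorType.JSON_F1]
)
for ev in evaluators:
    score = ev.evaluate_strings(prediction='{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}')
    print(ev.evaluation_name, score)
# expected: json_precision {'score': 0.5}
#           json_recall {'score': 0.5}
#           json_f1 {'score': 0.5}
```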
@@ -3,7 +3,12 @@ import random
 import pytest

 from langchain.evaluation.parsing.base import (
+    JsonAccuracyEvaluator,
     JsonEqualityEvaluator,
+    JsonF1Evaluator,
+    JsonIoUEvaluator,
+    JsonPrecisionEvaluator,
+    JsonRecallEvaluator,
     JsonValidityEvaluator,
 )
@@ -175,3 +180,164 @@ def test_json_equality_evaluator_evaluate_lists_permutation_invariant() -> None:
     )
     result = evaluator.evaluate_strings(prediction=prediction, reference=reference)
     assert result == {"score": False}
+
+
+def test_json_recall_evaluator() -> None:
+    evaluator = JsonRecallEvaluator()
+
+    # Test 1: Exact match
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    # Test 2: Missing field
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.5}
+
+    # Test 3: Extra field is irrelevant for recall
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2, "c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    # Test 4: Completely different
+    result = evaluator.evaluate_strings(
+        prediction='{"x": 5}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.0}
+
+    # Test 5: Empty reference
+    result = evaluator.evaluate_strings(prediction='{"a": 1}', reference="{}")
+    assert result == {"score": 0}
+
+
+def test_json_equality_evaluator() -> None:
+    evaluator = JsonEqualityEvaluator()
+
+    # Test 1: Exact match
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": True}
+
+    # Test 2: One missing key
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": False}
+
+    # Test 3: Different values for a key
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": False}
+
+    # Test 4: Extra field
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2, "c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": False}
+
+
+def test_json_accuracy_evaluator() -> None:
+    evaluator = JsonAccuracyEvaluator()
+
+    # Test 1: Exact match
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    # Test 2: One incorrect value
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.5}
+
+    # Test 3: Additional keys don't affect accuracy
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2, "c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result["score"] == pytest.approx(0.6667, abs=1e-4)
+
+
+def test_json_precision_evaluator() -> None:
+    evaluator = JsonPrecisionEvaluator()
+
+    # Test 1: Exact match
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    # Test 2: One incorrect value
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.5}
+
+    # Test 3: Additional key reduces precision
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2, "c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result["score"] == pytest.approx(0.6667, abs=1e-4)
+
+
+def test_json_iou_evaluator() -> None:
+    evaluator = JsonIoUEvaluator()
+
+    # Test 1: Exact match
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    # Test 2: Partial overlap
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.5}
+
+    # Test 3: No overlap
+    result = evaluator.evaluate_strings(
+        prediction='{"x": 5}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0.0}
+
+
+def test_json_f1_evaluator() -> None:
+    evaluator = JsonF1Evaluator()
+
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 1.0}
+
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1}', reference='{"a": 1, "b": 2}'
+    )
+    assert result["score"] == pytest.approx(0.6667, 0.001)
+
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result["score"] == 0.5
+
+    result = evaluator.evaluate_strings(
+        prediction='{"a": 1, "b": 2, "c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result["score"] == 0.8
+
+    result = evaluator.evaluate_strings(
+        prediction='{"c": 3}', reference='{"a": 1, "b": 2}'
+    )
+    assert result == {"score": 0}
+
+    result = evaluator.evaluate_strings(prediction="{}", reference='{"a": 1, "b": 2}')
+    assert result == {"score": 0}
+
+    result = evaluator.evaluate_strings(prediction='{"a": 1, "b": 2}', reference="{}")
+    assert result == {"score": 0}
@@ -40,7 +40,14 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
             EvaluatorType.LABELED_CRITERIA,
             EvaluatorType.LABELED_PAIRWISE_STRING,
         ],
-        [EvaluatorType.JSON_EQUALITY],
+        [
+            EvaluatorType.JSON_EQUALITY,
+            EvaluatorType.JSON_ACCURACY,
+            EvaluatorType.JSON_F1,
+            EvaluatorType.JSON_IOU,
+            EvaluatorType.JSON_PRECISION,
+            EvaluatorType.JSON_RECALL,
+        ],
     ],
 )
 def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
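The parametrization above asserts that each of these evaluator types refuses to run without a reference. The same behavior in isolation, as a sketch that assumes the shared `StringEvaluator` argument validation raises `ValueError` when a required reference is missing:

```python
import pytest

from langchain.evaluation.parsing.base import JsonRecallEvaluator

evaluator = JsonRecallEvaluator()

# requires_reference is True via the shared base class, so this should raise.
with pytest.raises(ValueError):
    evaluator.evaluate_strings(prediction='{"a": 1}')
```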