Update String Evaluator (#6615)

- Add a protocol for `evaluate_strings`
- Move the criteria evaluator out so it isn't restricted to being applied to traced runs
Zander Chase
2023-06-26 14:16:14 -07:00
committed by GitHub
parent b3f8324de9
commit c460b04c64
15 changed files with 1001 additions and 76 deletions
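
From the description and the tests below, the new `StringEvaluator` protocol presumably reduces to something like the following sketch. The exact definition lives in `langchain.evaluation.schema`; the keyword names are taken from the test calls, and the `runtime_checkable` decorator is inferred from the `isinstance` assertions.

```python
# Sketch of the string-evaluator protocol, inferred from the tests in this
# commit; the authoritative definition is in langchain.evaluation.schema.
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class StringEvaluator(Protocol):
    """Grade a prediction string against an optional reference and input."""

    def evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Return a dict such as {"score": ..., "value": ..., "reasoning": ...}."""
        ...
```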


@@ -0,0 +1,31 @@
"""Test the criteria eval chain."""
from langchain.evaluation.criteria.eval_chain import (
HELPFULNESS_CRITERION,
CriteriaEvalChain,
)
from langchain.evaluation.schema import StringEvaluator
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_resolve_criteria() -> None:
assert CriteriaEvalChain.resolve_criteria("helpfulness") == HELPFULNESS_CRITERION
assert CriteriaEvalChain.resolve_criteria(["helpfulness"]) == HELPFULNESS_CRITERION
def test_criteria_eval_chain() -> None:
chain = CriteriaEvalChain.from_llm(
llm=FakeLLM(
queries={"text": "The meaning of life\nY"}, sequential_responses=True
),
criteria={"my criterion": "my criterion description"},
)
result = chain.evaluate_strings(
prediction="my prediction", reference="my reference", input="my input"
)
assert result["reasoning"] == "The meaning of life"
def test_implements_string_protocol() -> None:
assert isinstance(CriteriaEvalChain, StringEvaluator)
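
Because the chain now grades plain strings, it can be called directly, outside any tracing session — which is the point of moving the criteria evaluator out. A minimal usage sketch, assuming an OpenAI chat model; the criterion, prompts, and printed result shape are illustrative:

```python
# Hypothetical direct usage; model choice and criterion are illustrative.
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

chain = CriteriaEvalChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    criteria="helpfulness",  # or a custom {"name": "description"} mapping
)
result = chain.evaluate_strings(
    prediction="Paris is the capital of France.",
    input="What is the capital of France?",
)
print(result)  # e.g. {"reasoning": "...", "value": "Y", "score": 1}
```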


@@ -4,11 +4,13 @@ from typing import Type
import pytest

from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_chain import (
    ContextQAEvalChain,
    CotQAEvalChain,
    QAEvalChain,
)
from langchain.evaluation.schema import StringEvaluator

from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -44,3 +46,24 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
    assert outputs[0] == outputs[1]
    assert "text" in outputs[0]
    assert outputs[0]["text"] == "foo"


@pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])
def test_implements_string_evaluator_protocol(
    chain_cls: Type[LLMChain],
) -> None:
    assert isinstance(chain_cls, StringEvaluator)


@pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])
def test_returns_expected_results(
    chain_cls: Type[LLMChain],
) -> None:
    fake_llm = FakeLLM(
        queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
    )
    chain = chain_cls.from_llm(fake_llm)  # type: ignore
    results = chain.evaluate_strings(
        prediction="my prediction", reference="my reference", input="my input"
    )
    assert results["score"] == 1


@@ -0,0 +1,54 @@
"""Test run evaluator implementations basic functionality."""
from uuid import UUID
import pytest
from langchainplus_sdk.schemas import Example, Run
from langchain.evaluation.run_evaluators import get_criteria_evaluator, get_qa_evaluator
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def run() -> Run:
return Run(
id=UUID("f77cd087-48f7-4c62-9e0e-297842202107"),
name="My Run",
inputs={"input": "What is the answer to life, the universe, and everything?"},
outputs={"output": "The answer is 42."},
start_time="2021-07-20T15:00:00.000000+00:00",
end_time="2021-07-20T15:00:00.000000+00:00",
run_type="chain",
execution_order=1,
)
@pytest.fixture
def example() -> Example:
return Example(
id=UUID("f77cd087-48f7-4c62-9e0e-297842202106"),
dataset_id=UUID("f77cd087-48f7-4c62-9e0e-297842202105"),
inputs={"input": "What is the answer to life, the universe, and everything?"},
outputs={"output": "The answer is 42."},
created_at="2021-07-20T15:00:00.000000+00:00",
)
def test_get_qa_evaluator(run: Run, example: Example) -> None:
"""Test get_qa_evaluator."""
eval_llm = FakeLLM(
queries={"a": "This checks out.\nCORRECT"}, sequential_responses=True
)
qa_evaluator = get_qa_evaluator(eval_llm)
res = qa_evaluator.evaluate_run(run, example)
assert res.value == "CORRECT"
assert res.score == 1
def test_get_criteria_evaluator(run: Run, example: Example) -> None:
"""Get a criteria evaluator."""
eval_llm = FakeLLM(queries={"a": "This checks out.\nY"}, sequential_responses=True)
criteria_evaluator = get_criteria_evaluator(eval_llm, criteria="conciseness")
res = criteria_evaluator.evaluate_run(run, example)
assert res.value == "Y"
assert res.score == 1
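
These factories wrap the string evaluators for use on traced runs. Judging by the fixtures above, a run evaluator plausibly maps run and example fields onto the string protocol along these lines — a sketch only; the key and helper names here are assumptions, and the real mapping lives in `langchain.evaluation.run_evaluators`:

```python
# How a run evaluator plausibly bridges traced runs to the string protocol.
def evaluate_run_sketch(string_evaluator, run, example):
    return string_evaluator.evaluate_strings(
        prediction=run.outputs["output"],     # the model's answer
        reference=example.outputs["output"],  # the ground-truth answer
        input=run.inputs["input"],            # the original question
    )  # e.g. {"value": "CORRECT", "score": 1, "reasoning": "..."}
```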


@@ -1,5 +1,6 @@
"""Test the BaseOutputParser class and its sub-classes."""
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Set, Type
import pytest
@@ -42,12 +43,12 @@ def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
def test_all_subclasses_implement_unique_type() -> None:
types = []
types = defaultdict(list)
for cls in _NON_ABSTRACT_PARSERS:
try:
types.append(cls._type)
types[cls._type].append(cls.__name__)
except NotImplementedError:
# This is handled in the previous test
pass
dups = set([t for t in types if types.count(t) > 1])
dups = {t: names for t, names in types.items() if len(names) > 1}
assert not dups, f"Duplicate types: {dups}"
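
The change swaps a flat list scanned with `count()` for a `defaultdict(list)` index, so a failing assertion now names the clashing parser classes instead of only the duplicated type string. The same pattern in isolation (the registry contents are illustrative):

```python
from collections import defaultdict

# Group class names by their declared type, then keep only contested types.
registry = [("json", "JsonParser"), ("list", "ListParser"), ("json", "JsonLinesParser")]
by_type = defaultdict(list)
for type_name, cls_name in registry:
    by_type[type_name].append(cls_name)

dups = {t: names for t, names in by_type.items() if len(names) > 1}
assert dups == {"json": ["JsonParser", "JsonLinesParser"]}
```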