Add Better Errors for Comparison Chain (#7033)

+ Change to ABC - this lets us add things like the evaluation name for loading
William FH 2023-07-06 06:37:04 -07:00 committed by GitHub
parent e61cfb6e99
commit ec66d5188c
9 changed files with 358 additions and 55 deletions
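
In practice, the switch from a runtime-checkable Protocol to an ABC means concrete evaluators implement a private `_evaluate_strings` hook and inherit a public `evaluate_strings` wrapper that validates `input`/`reference` first. A minimal sketch of the resulting subclass pattern (the `WordCountEvaluator` name is illustrative only, not part of this commit):

from typing import Any, Optional

from langchain.evaluation.schema import StringEvaluator


class WordCountEvaluator(StringEvaluator):
    """Toy evaluator: scores a prediction by its word count."""

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        return {"score": len(prediction.split())}


# The inherited public method checks the arguments (neither input nor
# reference is required here) before delegating to _evaluate_strings.
print(WordCountEvaluator().evaluate_strings(prediction="The quick brown fox"))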

View File

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from pydantic import Field
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -186,10 +187,11 @@ The following is the expected answer. Use this to measure correctness:
     @classmethod
     def from_llm(
         cls,
-        llm: BaseChatModel,
+        llm: BaseLanguageModel,
         agent_tools: Optional[Sequence[BaseTool]] = None,
         output_parser: Optional[TrajectoryOutputParser] = None,
         return_reasoning: bool = False,
+        **kwargs: Any,
     ) -> "TrajectoryEvalChain":
         """Create a TrajectoryEvalChain object from a language model chain.
@@ -205,6 +207,10 @@ The following is the expected answer. Use this to measure correctness:
         Returns:
             TrajectoryEvalChain: The TrajectoryEvalChain object.
         """
+        if not isinstance(llm, BaseChatModel):
+            raise NotImplementedError(
+                "Only chat models supported by the current trajectory eval"
+            )
         if agent_tools:
             prompt = EVAL_CHAT_PROMPT
         else:
@@ -215,6 +221,7 @@ The following is the expected answer. Use this to measure correctness:
             return_reasoning=return_reasoning,
             eval_chain=eval_chain,
             output_parser=output_parser or TrajectoryOutputParser(),
+            **kwargs,
         )
 
     @property
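
For callers, `from_llm` keeps the wider `BaseLanguageModel` annotation but now fails fast on non-chat models and forwards extra keyword arguments to the chain constructor. A rough usage sketch (assumes `ChatOpenAI`/`OpenAI` are importable and an OpenAI key is configured; not part of this commit):

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.llms import OpenAI

# Chat models work as before; return_reasoning and any extra **kwargs are
# passed through to the TrajectoryEvalChain constructor.
eval_chain = TrajectoryEvalChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    return_reasoning=True,
)

# A completion-style LLM is now rejected up front with a clear error
# instead of failing later inside the chain.
try:
    TrajectoryEvalChain.from_llm(llm=OpenAI(temperature=0))
except NotImplementedError as err:
    print(err)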

View File

@@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
+from langchain.evaluation.schema import PairwiseStringEvaluator
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseOutputParser
@@ -50,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
         }
 
-class PairwiseStringEvalChain(LLMChain):
+class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
     """A chain for comparing the output of two models.
 
     Example:
@@ -80,13 +81,31 @@ class PairwiseStringEvalChain(LLMChain):
         default_factory=PairwiseStringResultOutputParser
     )
 
+    @property
+    def requires_reference(self) -> bool:
+        return "reference" in self.prompt.input_variables
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+            "\nTo use a reference, initialize PairwiseStringEvalChain with"
+            " `requires_reference=True` or with a prompt with 'reference' as an"
+            " input variable."
+        )
+
     @classmethod
     def from_llm(
         cls,
-        *,
         llm: BaseLanguageModel,
+        *,
         prompt: Optional[PromptTemplate] = None,
-        require_reference: bool = False,
+        requires_reference: bool = False,
         **kwargs: Any,
     ) -> PairwiseStringEvalChain:
         """Initialize the PairwiseStringEvalChain from an LLM.
@@ -94,7 +113,7 @@ class PairwiseStringEvalChain(LLMChain):
         Args:
             llm (BaseLanguageModel): The LLM to use.
             prompt (PromptTemplate, optional): The prompt to use.
-            require_reference (bool, optional): Whether to require a reference
+            requires_reference (bool, optional): Whether to require a reference
                 string. Defaults to False.
             **kwargs (Any): Additional keyword arguments.
@@ -103,13 +122,13 @@ class PairwiseStringEvalChain(LLMChain):
         """
         expected_input_vars = {"prediction", "prediction_b", "input"}
         if prompt is None:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
                 prompt_ = PROMPT_WITH_REFERENCE
             else:
                 prompt_ = PROMPT
         else:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
             prompt_ = prompt
@@ -121,23 +140,32 @@ class PairwiseStringEvalChain(LLMChain):
         return cls(llm=llm, prompt=prompt_, **kwargs)
 
     def _prepare_input(
-        self, prediction: str, prediction_b: str, input: str, reference: Optional[str]
+        self,
+        prediction: str,
+        prediction_b: str,
+        input: Optional[str],
+        reference: Optional[str],
     ) -> dict:
         input_ = {
             "prediction": prediction,
             "prediction_b": prediction_b,
-            "input": input,
         }
-        if reference is not None and "reference" in self.prompt.input_variables:
+        if self.requires_input:
+            if not input:
+                raise ValueError("Input is require for this comparison evaluator")
+            input_["input"] = input
+        if self.requires_reference:
+            if reference is None:
+                raise ValueError("Reference is required for this comparison evaluator")
             input_["reference"] = reference
         return input_
 
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
-        input: str,
+        input: Optional[str] = None,
         reference: Optional[str] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
@@ -168,13 +196,13 @@ class PairwiseStringEvalChain(LLMChain):
         )
         return result["text"]
 
-    async def aevaluate_string_pairs(
+    async def _aevaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
-        input: str,
         reference: Optional[str] = None,
+        input: Optional[str] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> dict:
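
Because the chain now reports `requires_input`/`requires_reference`, the renamed `requires_reference` flag and the stricter `_prepare_input` turn a missing argument into an immediate `ValueError` rather than a silently weaker evaluation. A rough usage sketch (assumes `ChatOpenAI` is importable and an OpenAI key is configured; not part of this commit):

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain

chain = PairwiseStringEvalChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    requires_reference=True,  # renamed from require_reference in this commit
)

# The reference-based prompt is in use, so omitting the reference now raises.
try:
    chain.evaluate_string_pairs(
        prediction="I like pie.",
        prediction_b="I love pie.",
        input="What is your favorite food?",
    )
except ValueError as err:
    print(err)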

View File

@@ -2,12 +2,13 @@ from __future__ import annotations
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
+from langchain.evaluation.schema import StringEvaluator
 from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
@@ -59,7 +60,7 @@ CRITERIA_TYPE = Union[
 ]
 
-class CriteriaEvalChain(LLMChain):
+class CriteriaEvalChain(StringEvaluator, LLMChain):
     """LLM Chain for evaluating runs against criteria.
 
     Parameters
@@ -96,11 +97,32 @@ class CriteriaEvalChain(LLMChain):
     >>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
     """
 
-    requires_reference: bool = False
-    """Whether the evaluation template expects a reference text."""
     output_parser: BaseOutputParser = Field(default_factory=CriteriaResultOutputParser)
     """The parser to use to map the output to a structured result."""
 
+    class Config:
+        """Configuration for the QAEvalChain."""
+
+        extra = Extra.ignore
+
+    @property
+    def requires_reference(self) -> bool:
+        return "reference" in self.prompt.input_variables
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+            "\nTo use a reference, initialize CriteriaEvalChain with"
+            " `require_reference=True` or with a prompt with 'reference'"
+            " as an input variable."
+        )
+
     @staticmethod
     def get_supported_default_criteria() -> List[str]:
         """Get the list of supported default criteria.
@@ -122,7 +144,7 @@ class CriteriaEvalChain(LLMChain):
     @classmethod
     def resolve_criteria(
         cls,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE],
     ) -> Dict[str, str]:
         """Resolve the criteria to evaluate.
@@ -148,6 +170,10 @@ class CriteriaEvalChain(LLMChain):
         {'relevance': 'Is the submission referring to a real quote from the text?',
          'coherence': 'Is the submission coherent, well-structured, and organized?'}
         """  # noqa: E501
+        if criteria is None:
+            return {
+                "helpfulness": _SUPPORTED_CRITERIA["helpfulness"],
+            }
         if isinstance(criteria, str):
             criteria_ = {criteria: _SUPPORTED_CRITERIA[criteria]}
         elif isinstance(criteria, ConstitutionalPrinciple):
@@ -172,7 +198,7 @@ class CriteriaEvalChain(LLMChain):
     def from_llm(
         cls,
         llm: BaseLanguageModel,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE] = None,
         *,
         prompt: Optional[BasePromptTemplate] = None,
         requires_reference: bool = False,
@@ -184,7 +210,7 @@ class CriteriaEvalChain(LLMChain):
         ----------
         llm : BaseLanguageModel
             The language model to use for evaluation.
-        criteria : CRITERIA_TYPE
+        criteria : CRITERIA_TYPE - default=None for "helpfulness"
             The criteria to evaluate the runs against. It can be:
                 - a mapping of criterion names to descriptions
                 - a sequence of criterion names
@@ -252,7 +278,7 @@ class CriteriaEvalChain(LLMChain):
             input_["reference"] = reference
         return input_
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -296,7 +322,7 @@ class CriteriaEvalChain(LLMChain):
         input_ = self._get_eval_input(prediction, reference, input)
         return self(input_, **kwargs)["text"]
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
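
With `criteria` now optional, leaving it unset resolves to the built-in helpfulness criterion; explicit criteria behave as before. A small sketch of the resolution step (no LLM call involved):

from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

# None now falls back to the "helpfulness" criterion.
print(CriteriaEvalChain.resolve_criteria(None))

# Existing usages are unchanged, e.g. the docstring's list example.
print(CriteriaEvalChain.resolve_criteria(["relevance", "coherence"]))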

View File

@@ -8,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
+from langchain.evaluation.schema import StringEvaluator
 
 
 def _parse_string_eval_output(text: str) -> dict:
@@ -38,9 +39,17 @@ def _parse_string_eval_output(text: str) -> dict:
     }
 
-class QAEvalChain(LLMChain):
+class QAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating question answering."""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def from_llm(
         cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
@@ -90,7 +99,7 @@ class QAEvalChain(LLMChain):
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -118,7 +127,7 @@ class QAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -134,9 +143,17 @@ class QAEvalChain(LLMChain):
         return _parse_string_eval_output(result["text"])
 
-class ContextQAEvalChain(LLMChain):
+class ContextQAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating QA w/o GT based on context"""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
         expected_input_vars = {"query", "context", "result"}
@@ -193,7 +210,7 @@ class ContextQAEvalChain(LLMChain):
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -208,7 +225,7 @@ class ContextQAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
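
Declaring `requires_reference` and `requires_input` as `True` means the shared wrapper from the evaluation schema enforces both arguments before the QA prompt is ever formatted. A rough usage sketch (assumes `OpenAI` is importable and a key is configured; not part of this commit):

from langchain.evaluation.qa.eval_chain import QAEvalChain
from langchain.llms import OpenAI

chain = QAEvalChain.from_llm(llm=OpenAI(temperature=0))

# Missing reference (the ground-truth answer) now raises before any LLM call.
try:
    chain.evaluate_strings(
        prediction="Paris",
        input="What is the capital of France?",
    )
except ValueError as err:
    print(err)

# Providing both input and reference runs the evaluation as usual.
result = chain.evaluate_strings(
    prediction="Paris",
    input="What is the capital of France?",
    reference="Paris",
)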

View File

@@ -1,14 +1,63 @@
 """Interfaces to be implemented by general evaluators."""
-from abc import abstractmethod
-from typing import Any, Optional, Protocol, runtime_checkable
+from __future__ import annotations
 
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+from warnings import warn
 
+logger = logging.getLogger(__name__)
 
-@runtime_checkable
-class StringEvaluator(Protocol):
+
+class _EvalArgsMixin:
+    """Mixin for checking evaluation arguments."""
+
+    @property
+    def requires_reference(self) -> bool:
+        """Whether this evaluator requires a reference label."""
+        return False
+
+    @property
+    def requires_input(self) -> bool:
+        """Whether this evaluator requires an input string."""
+        return False
+
+    @property
+    def _skip_input_warning(self) -> str:
+        """Warning to show when input is ignored."""
+        return f"Ignoring input in {self.__class__.__name__}, as it is not expected."
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+        )
+
+    def _check_evaluation_args(
+        self,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+    ) -> None:
+        if self.requires_input and input is None:
+            raise ValueError(f"{self.__class__.__name__} requires an input string.")
+        elif input is not None and not self.requires_input:
+            warn(self._skip_input_warning)
+        else:
+            pass
+        if self.requires_reference and reference is None:
+            raise ValueError(f"{self.__class__.__name__} requires a reference string.")
+        elif reference is not None and not self.requires_reference:
+            warn(self._skip_reference_warning)
+        else:
+            pass
+
+
+class StringEvaluator(_EvalArgsMixin, ABC):
     """Protocol for evaluating strings."""
 
     @abstractmethod
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -28,7 +77,7 @@ class StringEvaluator(Protocol):
             dict: The evaluation results containing the score or value.
         """
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -53,13 +102,61 @@ class StringEvaluator(Protocol):
             "async aevaluate_strings method."
         )
 
+    def evaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate Chain or LLM output, based on optional input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return self._evaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
+    async def aevaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Asynchronously evaluate Chain or LLM output, based on optional
+        input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return await self._aevaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
 
-@runtime_checkable
-class PairwiseStringEvaluator(Protocol):
+class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
     """A protocol for comparing the output of two models."""
 
     @abstractmethod
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
         self,
         *,
         prediction: str,
@@ -84,8 +181,9 @@ class PairwiseStringEvaluator(Protocol):
                 other information.
         """
 
-    async def aevaluate_string_pairs(
+    async def _aevaluate_string_pairs(
         self,
+        *,
         prediction: str,
         prediction_b: str,
         reference: Optional[str] = None,
@@ -111,3 +209,69 @@ class PairwiseStringEvaluator(Protocol):
             f"{self.__class__.__name__} hasn't implemented an async "
             "aevaluate_string_pairs method."
         )
+
+    def evaluate_string_pairs(
+        self,
+        *,
+        prediction: str,
+        prediction_b: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate the output string pairs.
+
+        Args:
+            prediction (str): The output string from the first model.
+            prediction_b (str): The output string from the second model.
+            reference (str, optional): The expected output / reference
+                string. Defaults to None.
+            input (str, optional): The input string. Defaults to None.
+            **kwargs (Any): Additional keyword arguments, such
+                as callbacks and optional reference strings.
+
+        Returns:
+            dict: A dictionary containing the preference, scores, and/or
+                other information.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return self._evaluate_string_pairs(
+            prediction=prediction,
+            prediction_b=prediction_b,
+            reference=reference,
+            input=input,
+            **kwargs,
+        )
+
+    async def aevaluate_string_pairs(
+        self,
+        *,
+        prediction: str,
+        prediction_b: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate the output string pairs.
+
+        Args:
+            prediction (str): The output string from the first model.
+            prediction_b (str): The output string from the second model.
+            reference (str, optional): The expected output / reference
+                string. Defaults to None.
+            input (str, optional): The input string. Defaults to None.
+            **kwargs (Any): Additional keyword arguments, such
+                as callbacks and optional reference strings.
+
+        Returns:
+            dict: A dictionary containing the preference, scores, and/or
+                other information.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return await self._aevaluate_string_pairs(
+            prediction=prediction,
+            prediction_b=prediction_b,
+            reference=reference,
+            input=input,
+            **kwargs,
+        )
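
The mixin centralizes the argument policy: a missing required argument raises `ValueError`, while an unused one only triggers a `UserWarning` and the evaluation proceeds. A rough behavioral sketch using a throwaway subclass (the `_ExactMatchEvaluator` name is illustrative only, not part of this commit):

import warnings
from typing import Any, Optional

from langchain.evaluation.schema import StringEvaluator


class _ExactMatchEvaluator(StringEvaluator):
    """Throwaway evaluator that needs a reference but no input."""

    @property
    def requires_reference(self) -> bool:
        return True

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        return {"score": int(prediction == reference)}


evaluator = _ExactMatchEvaluator()

# Required argument missing -> ValueError from _check_evaluation_args.
try:
    evaluator.evaluate_strings(prediction="foo")
except ValueError as err:
    print(err)

# Unexpected argument -> UserWarning, but the evaluation still runs.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = evaluator.evaluate_strings(prediction="foo", reference="foo", input="bar")
print(result, [str(w.message) for w in caught])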

View File

@@ -1,13 +1,15 @@
 """Test agent trajectory evaluation chain."""
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
+from pydantic import Field
 
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
-from langchain.schema import AgentAction
+from langchain.schema import AgentAction, BaseMessage
 from langchain.tools.base import tool
-from tests.unit_tests.llms.fake_llm import FakeLLM
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 
 
 @pytest.fixture
@@ -30,10 +32,31 @@ def foo(bar: str) -> str:
     return bar
 
+class _FakeTrajectoryChatModel(FakeChatModel):
+    queries: Dict = Field(default_factory=dict)
+    sequential_responses: Optional[bool] = False
+    response_index: int = 0
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if self.sequential_responses:
+            response = self.queries[list(self.queries.keys())[self.response_index]]
+            self.response_index = self.response_index + 1
+            return response
+        else:
+            prompt = messages[0].content
+            return self.queries[prompt]
+
+
 def test_trajectory_eval_chain(
     intermediate_steps: List[Tuple[AgentAction, str]]
 ) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
         queries={
             "a": "Trajectory good\nScore: 5",
             "b": "Trajectory not good\nScore: 1",
@@ -61,7 +84,7 @@ def test_trajectory_eval_chain(
 def test_trajectory_eval_chain_no_tools(
     intermediate_steps: List[Tuple[AgentAction, str]]
 ) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
         queries={
             "a": "Trajectory good\nScore: 5",
             "b": "Trajectory not good\nScore: 1",
@@ -85,7 +108,7 @@ def test_trajectory_eval_chain_no_tools(
 def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
         queries={
             "a": "Trajectory good\nScore: 5",
             "b": "Trajectory not good\nScore: 1",

View File

@@ -1,6 +1,8 @@
 """Test the comparison chains."""
+import pytest
+
 from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -30,10 +32,30 @@ def test_pairwise_string_comparison_chain() -> None:
     )
     assert res["value"] == "A"
     assert res["score"] == 1
-    res = chain.evaluate_string_pairs(
-        prediction="I like pie.",
-        prediction_b="I hate pie.",
-        input="What is your favorite food?",
-    )
+    with pytest.warns(UserWarning, match=chain._skip_reference_warning):
+        res = chain.evaluate_string_pairs(
+            prediction="I like pie.",
+            prediction_b="I hate pie.",
+            input="What is your favorite food?",
+            reference="I enjoy pie.",
+        )
     assert res["value"] == "B"
     assert res["score"] == 0
+
+
+def test_pairwise_string_comparison_chain_missing_ref() -> None:
+    llm = FakeLLM(
+        queries={
+            "a": "The values are the same.\n[[C]]",
+            "b": "A is clearly better than b.\n[[A]]",
+            "c": "B is clearly better than a.\n[[B]]",
+        },
+        sequential_responses=True,
+    )
+    chain = PairwiseStringEvalChain.from_llm(llm=llm, requires_reference=True)
+    with pytest.raises(ValueError):
+        chain.evaluate_string_pairs(
+            prediction="I like pie.",
+            prediction_b="I love pie.",
+            input="What is your favorite food?",
+        )

View File

@@ -1,6 +1,8 @@
 """Test the criteria eval chain."""
+import pytest
+
 from langchain.evaluation.criteria.eval_chain import (
     _SUPPORTED_CRITERIA,
     CriteriaEvalChain,
@@ -25,11 +27,25 @@ def test_criteria_eval_chain() -> None:
         ),
         criteria={"my criterion": "my criterion description"},
     )
-    result = chain.evaluate_strings(
-        prediction="my prediction", reference="my reference", input="my input"
-    )
+    with pytest.warns(UserWarning, match=chain._skip_reference_warning):
+        result = chain.evaluate_strings(
+            prediction="my prediction", reference="my reference", input="my input"
+        )
     assert result["reasoning"] == "The meaning of life"
+
+
+def test_criteria_eval_chain_missing_reference() -> None:
+    chain = CriteriaEvalChain.from_llm(
+        llm=FakeLLM(
+            queries={"text": "The meaning of life\nY"},
+            sequential_responses=True,
+        ),
+        requires_reference=True,
+        criteria={"my criterion": "my criterion description"},
+    )
+    with pytest.raises(ValueError):
+        chain.evaluate_strings(prediction="my prediction", input="my input")
 
 
 def test_implements_string_protocol() -> None:
-    assert isinstance(CriteriaEvalChain, StringEvaluator)
+    assert issubclass(CriteriaEvalChain, StringEvaluator)

View File

@@ -52,7 +52,7 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
 def test_implements_string_evaluator_protocol(
     chain_cls: Type[LLMChain],
 ) -> None:
-    assert isinstance(chain_cls, StringEvaluator)
+    assert issubclass(chain_cls, StringEvaluator)
 
 
 @pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])