Compare commits

...

14 Commits

Author SHA1 Message Date
Ankush Gola
0d8df345f5 add interface for evaluating messages 2023-07-06 22:37:31 -07:00
Nuno Campos
1c650f98a8 Small updates to run_evaluator after adding dataset.data_type 2023-07-06 15:31:18 +01:00
William Fu-Hinthorn
8db65f7434 Merge branch 'wfh/load_evals' into wfh/to_run_evaluator 2023-07-03 14:50:07 -07:00
William Fu-Hinthorn
1b35b29a42 Update Loader 2023-07-03 14:49:40 -07:00
William Fu-Hinthorn
13c6783e8f Add Better Errors for Comparison Chain 2023-07-03 14:48:01 -07:00
William Fu-Hinthorn
c650238d9c Merge branch 'wfh/re-add-qa-reasoning-handling' into wfh/to_run_evaluator 2023-07-03 14:41:31 -07:00
William Fu-Hinthorn
43edaff075 Accept no 'reasoning' response in qa evaluator 2023-07-03 14:39:24 -07:00
William Fu-Hinthorn
e8a4e0b144 Merge branch 'wfh/re-add-errors' into wfh/to_run_evaluator 2023-07-03 14:36:59 -07:00
William Fu-Hinthorn
29876609d0 Log errors 2023-07-03 14:34:51 -07:00
William Fu-Hinthorn
1e94cd60a2 Merge 2023-07-03 14:14:13 -07:00
William Fu-Hinthorn
6fd701df88 Merge 2023-07-03 14:12:48 -07:00
William Fu-Hinthorn
f3de5c4f42 Add Better Errors for Comparison Chain 2023-07-03 11:31:36 -07:00
William Fu-Hinthorn
97841f4cfd docs 2023-07-03 11:30:03 -07:00
William Fu-Hinthorn
8b385861a2 Add evaluator loader 2023-07-03 11:30:01 -07:00
18 changed files with 1180 additions and 53 deletions

View File

@@ -1,29 +1,41 @@
"""Functionality relating to evaluation.
"""Evaluation chains for grading LLM and Chain outputs.
This module contains off-the-shelf evaluation chains for
grading the output of LangChain primitives such as LLMs and Chains.
This module contains off-the-shelf evaluation chains for grading the output of
LangChain primitives such as language models and chains.
To load an evaluator, you can use the :func:`load_evaluators <langchain.evaluation.loading.load_evaluators>` function with the
name of the evaluator to load.
To load one of the LangChain HuggingFace datasets, you can use the :func:`load_dataset <langchain.evaluation.loading.load_dataset>` function with the
name of the dataset to load.
Some common use cases for evaluation include:
- Grading accuracy of a response against ground truth answers: QAEvalChain
- Comparing the output of two models: PairwiseStringEvalChain
- Judging the efficacy of an agent's tool usage: TrajectoryEvalChain
- Checking whether an output complies with a set of criteria: CriteriaEvalChain
- Grading the accuracy of a response against ground truth answers: :class:`QAEvalChain <langchain.evaluation.qa.eval_chain.QAEvalChain>`
- Comparing the output of two models: :class:`PairwiseStringEvalChain <langchain.evaluation.comparison.eval_chain.PairwiseStringEvalChain>`
- Judging the efficacy of an agent's tool usage: :class:`TrajectoryEvalChain <langchain.evaluation.agents.trajectory_eval_chain.TrajectoryEvalChain>`
- Checking whether an output complies with a set of criteria: :class:`CriteriaEvalChain <langchain.evaluation.criteria.eval_chain.CriteriaEvalChain>`
This module also contains low level APIs for making more evaluators for your
custom evaluation task. These include:
- StringEvaluator: Evaluates an output string against a reference and/or
with input context.
- PairwiseStringEvaluator: Evaluates two strings against each other.
"""
This module also contains low-level APIs for creating custom evaluators for
specific evaluation tasks. These include:
- :class:`StringEvaluator <langchain.evaluation.schema.StringEvaluator>`: Evaluates an output string against a reference and/or input context.
- :class:`PairwiseStringEvaluator <langchain.evaluation.schema.PairwiseStringEvaluator>`: Evaluates two strings against each other.
""" # noqa: E501
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.evaluation.comparison import PairwiseStringEvalChain
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
from langchain.evaluation.loading import load_dataset, load_evaluators
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
from langchain.evaluation.schema import (
EvaluatorType,
PairwiseStringEvaluator,
StringEvaluator,
)
__all__ = [
"EvaluatorType",
"PairwiseStringEvalChain",
"QAEvalChain",
"CotQAEvalChain",
@@ -32,4 +44,6 @@ __all__ = [
"PairwiseStringEvaluator",
"TrajectoryEvalChain",
"CriteriaEvalChain",
"load_evaluators",
"load_dataset",
]
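
A minimal sketch of the loader interface this module docstring describes (assumes an OPENAI_API_KEY is configured; the question and the exact result shape are illustrative):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation import EvaluatorType, load_evaluators

    # Model choice is illustrative; any chat model should work here.
    llm = ChatOpenAI(model="gpt-4", temperature=0)
    evaluator = load_evaluators([EvaluatorType.QA], llm=llm)[0]

    result = evaluator.evaluate_strings(
        input="What is the capital of France?",
        prediction="Paris",
        reference="Paris",
    )
    print(result)  # expected shape: {"reasoning": ..., "value": "CORRECT", "score": 1}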

View File

@@ -7,20 +7,21 @@ chain (LLMChain) to generate the reasoning and scores.
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from pydantic import Field
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.evaluation.agents.trajectory_eval_prompt import (
EVAL_CHAT_PROMPT,
TOOL_FREE_EVAL_CHAT_PROMPT,
)
from langchain.evaluation.schema import EvalChain
from langchain.schema import AgentAction, BaseOutputParser, OutputParserException
from langchain.tools.base import BaseTool
@@ -69,7 +70,7 @@ class TrajectoryOutputParser(BaseOutputParser):
return TrajectoryEval(score=int(score_str), reasoning=reasoning)
class TrajectoryEvalChain(Chain):
class TrajectoryEvalChain(EvalChain):
"""A chain for evaluating ReAct style agents.
This chain is used to evaluate ReAct style agents by reasoning about
@@ -123,6 +124,11 @@ class TrajectoryEvalChain(Chain):
return_reasoning: bool = False
"""Whether to return the reasoning along with the score."""
class Config:
"""Configuration for the QAEvalChain."""
extra = Extra.ignore
@property
def _tools_description(self) -> str:
"""Get the description of the agent tools.
@@ -186,10 +192,11 @@ The following is the expected answer. Use this to measure correctness:
@classmethod
def from_llm(
cls,
llm: BaseChatModel,
llm: BaseLanguageModel,
agent_tools: Optional[Sequence[BaseTool]] = None,
output_parser: Optional[TrajectoryOutputParser] = None,
return_reasoning: bool = False,
**kwargs: Any,
) -> "TrajectoryEvalChain":
"""Create a TrajectoryEvalChain object from a language model chain.
@@ -205,6 +212,10 @@ The following is the expected answer. Use this to measure correctness:
Returns:
TrajectoryEvalChain: The TrajectoryEvalChain object.
"""
if not isinstance(llm, BaseChatModel):
raise NotImplementedError(
"Only chat models supported by the current trajectory eval"
)
if agent_tools:
prompt = EVAL_CHAT_PROMPT
else:
@@ -215,6 +226,7 @@ The following is the expected answer. Use this to measure correctness:
return_reasoning=return_reasoning,
eval_chain=eval_chain,
output_parser=output_parser or TrajectoryOutputParser(),
**kwargs,
)
@property
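
A short sketch of the widened from_llm contract above: the annotation now accepts any BaseLanguageModel, but non-chat models are rejected at runtime (models are illustrative; assumes an OPENAI_API_KEY):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
    from langchain.llms import OpenAI

    chat = ChatOpenAI(model="gpt-4", temperature=0)
    eval_chain = TrajectoryEvalChain.from_llm(llm=chat)  # accepted: chat model

    try:
        TrajectoryEvalChain.from_llm(llm=OpenAI(temperature=0))  # completion model
    except NotImplementedError as err:
        print(err)  # only chat models are supported by this evaluator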

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from typing import Any, Optional
from pydantic import Field
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
from langchain.evaluation.schema import EvalChain, PairwiseStringEvaluator
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
@@ -50,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
}
class PairwiseStringEvalChain(LLMChain):
class PairwiseStringEvalChain(PairwiseStringEvaluator, EvalChain, LLMChain):
"""A chain for comparing the output of two models.
Example:
@@ -80,11 +81,16 @@ class PairwiseStringEvalChain(LLMChain):
default_factory=PairwiseStringResultOutputParser
)
class Config:
"""Configuration for the QAEvalChain."""
extra = Extra.ignore
@classmethod
def from_llm(
cls,
*,
llm: BaseLanguageModel,
*,
prompt: Optional[PromptTemplate] = None,
require_reference: bool = False,
**kwargs: Any,
@@ -121,14 +127,23 @@ class PairwiseStringEvalChain(LLMChain):
return cls(llm=llm, prompt=prompt_, **kwargs)
def _prepare_input(
self, prediction: str, prediction_b: str, input: str, reference: Optional[str]
self,
prediction: str,
prediction_b: str,
input: Optional[str],
reference: Optional[str],
) -> dict:
input_ = {
"prediction": prediction,
"prediction_b": prediction_b,
"input": input,
}
if reference is not None and "reference" in self.prompt.input_variables:
if "input" in self.prompt.input_variables:
if not input:
raise ValueError("Input is require for this comparison evaluator")
input_["input"] = input
if "reference" in self.prompt.input_variables:
if reference is None:
raise ValueError("Reference is required for this comparison evaluator")
input_["reference"] = reference
return input_
@@ -137,7 +152,7 @@ class PairwiseStringEvalChain(LLMChain):
*,
prediction: str,
prediction_b: str,
input: str,
input: Optional[str] = None,
reference: Optional[str] = None,
callbacks: Callbacks = None,
**kwargs: Any,
@@ -173,8 +188,8 @@ class PairwiseStringEvalChain(LLMChain):
*,
prediction: str,
prediction_b: str,
input: str,
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> dict:
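
For reference, a hedged sketch of the relaxed pairwise API: with the default prompt, input is still validated, but prompts without an "input" variable can now omit it (strings and result shape illustrative; assumes an OPENAI_API_KEY):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.comparison import PairwiseStringEvalChain

    llm = ChatOpenAI(model="gpt-4", temperature=0)
    chain = PairwiseStringEvalChain.from_llm(llm=llm)

    result = chain.evaluate_string_pairs(
        input="Summarize the article in one sentence.",
        prediction="Summary from model A.",
        prediction_b="Summary from model B.",
    )
    print(result)  # illustrative shape: {"reasoning": ..., "value": "A", "score": 1}

    # With the default prompt, omitting `input` now raises a ValueError instead
    # of silently sending an empty value to the model.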

View File

@@ -2,12 +2,13 @@ from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from pydantic import Field
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
from langchain.evaluation.schema import EvalChain, StringEvaluator
from langchain.schema import BaseOutputParser, BasePromptTemplate
_SUPPORTED_CRITERIA = {
@@ -59,7 +60,7 @@ CRITERIA_TYPE = Union[
]
class CriteriaEvalChain(LLMChain):
class CriteriaEvalChain(StringEvaluator, EvalChain, LLMChain):
"""LLM Chain for evaluating runs against criteria.
Parameters
@@ -96,10 +97,30 @@ class CriteriaEvalChain(LLMChain):
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
"""
requires_reference: bool = False
"""Whether the evaluation template expects a reference text."""
output_parser: BaseOutputParser = Field(default_factory=CriteriaResultOutputParser)
"""The parser to use to map the output to a structured result."""
criteria_names: List[str] = Field(default_factory=list)
class Config:
"""Configuration for the QAEvalChain."""
extra = Extra.ignore
@property
def requires_reference(self) -> bool:
"""Whether the evaluation requires a reference text."""
return "reference" in self.prompt.input_variables
@property
def evaluation_name(self) -> str:
"""Get the name of the evaluation.
Returns
-------
str
The name of the evaluation.
"""
return " ".join(self.criteria_names)
@staticmethod
def get_supported_default_criteria() -> List[str]:
@@ -122,7 +143,7 @@ class CriteriaEvalChain(LLMChain):
@classmethod
def resolve_criteria(
cls,
criteria: CRITERIA_TYPE,
criteria: Optional[CRITERIA_TYPE],
) -> Dict[str, str]:
"""Resolve the criteria to evaluate.
@@ -148,6 +169,10 @@ class CriteriaEvalChain(LLMChain):
{'relevance': 'Is the submission referring to a real quote from the text?',
'coherence': 'Is the submission coherent, well-structured, and organized?'}
""" # noqa: E501
if criteria is None:
return {
"helpfulness": _SUPPORTED_CRITERIA["helpfulness"],
}
if isinstance(criteria, str):
criteria_ = {criteria: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, ConstitutionalPrinciple):
@@ -172,7 +197,7 @@ class CriteriaEvalChain(LLMChain):
def from_llm(
cls,
llm: BaseLanguageModel,
criteria: CRITERIA_TYPE,
criteria: Optional[CRITERIA_TYPE] = None,
*,
prompt: Optional[BasePromptTemplate] = None,
requires_reference: bool = False,
@@ -184,7 +209,7 @@ class CriteriaEvalChain(LLMChain):
----------
llm : BaseLanguageModel
The language model to use for evaluation.
criteria : CRITERIA_TYPE
criteria : Optional[CRITERIA_TYPE], default=None ("helpfulness" if not provided)
The criteria to evaluate the runs against. It can be:
- a mapping of criterion names to descriptions
- a sequence of criterion names
@@ -231,10 +256,14 @@ class CriteriaEvalChain(LLMChain):
else:
prompt = PROMPT
criteria_ = cls.resolve_criteria(criteria)
criteria_names = list(criteria_.keys())
criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
prompt_ = prompt.partial(criteria=criteria_str)
return cls(
llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
llm=llm,
prompt=prompt_,
criteria_names=criteria_names,
**kwargs,
)
def _get_eval_input(
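
A minimal sketch of the new default-criteria behavior (model illustrative; assumes an OPENAI_API_KEY):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

    # criteria=None now resolves to the built-in "helpfulness" criterion.
    print(CriteriaEvalChain.resolve_criteria(None))

    llm = ChatOpenAI(model="gpt-4", temperature=0)
    chain = CriteriaEvalChain.from_llm(llm=llm)  # defaults to helpfulness
    print(chain.evaluation_name)  # "helpfulness", joined from criteria_names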

View File

@@ -1,8 +1,107 @@
from typing import Dict, List
"""Loading datasets and evaluators."""
from typing import Any, Dict, List, Optional, Sequence, Type
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chat_models.openai import ChatOpenAI
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.evaluation.comparison import PairwiseStringEvalChain
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.schema import EvalChain, EvaluatorType
def load_dataset(uri: str) -> List[Dict]:
"""Load a dataset from the LangChainDatasets HuggingFace org."""
from datasets import load_dataset
dataset = load_dataset(f"LangChainDatasets/{uri}")
return [d for d in dataset["train"]]
_EVALUATOR_MAP: Dict[EvaluatorType, Type[EvalChain]] = {
EvaluatorType.QA: QAEvalChain,
EvaluatorType.COT_QA: CotQAEvalChain,
EvaluatorType.CONTEXT_QA: ContextQAEvalChain,
EvaluatorType.PAIRWISE_STRING: PairwiseStringEvalChain,
EvaluatorType.AGENT_TRAJECTORY: TrajectoryEvalChain,
EvaluatorType.CRITERIA: CriteriaEvalChain,
}
def _load_evaluator(
evaluator: EvaluatorType,
*,
llm: Optional[BaseLanguageModel] = None,
**kwargs: Any,
) -> Chain:
"""Load the requested evaluation chain specified by a string.
Parameters
----------
evaluator : EvaluatorType
The type of evaluator to load.
llm : BaseLanguageModel, optional
The language model to use for evaluation, by default None
**kwargs : Any
Additional keyword arguments to pass to the evaluator.
Returns
-------
Chain
The loaded evaluation chain.
Examples
--------
>>> llm = ChatOpenAI(model="gpt-4", temperature=0)
>>> evaluator = _load_evaluator("qa", llm=llm)
"""
llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
if evaluator not in _EVALUATOR_MAP:
raise ValueError(
f"Unknown evaluator type: {evaluator}"
f"Valid types are: {list(_EVALUATOR_MAP.keys())}"
)
return _EVALUATOR_MAP[evaluator].from_llm(llm=llm, **kwargs)
def load_evaluators(
evaluators: Sequence[EvaluatorType],
*,
llm: Optional[BaseLanguageModel] = None,
config: Optional[dict] = None,
**kwargs: Any,
) -> List[Chain]:
"""Load evaluators specified by a list of evaluator types.
Parameters
----------
evaluators : Sequence[EvaluatorType]
The list of evaluator types to load.
llm : BaseLanguageModel, optional
The language model to use for evaluation. If none is provided, a default
ChatOpenAI gpt-4 model is used.
config : dict, optional
A dictionary mapping evaluator types to additional keyword arguments,
by default None
**kwargs : Any
Additional keyword arguments to pass to all evaluators.
Returns
-------
List[Chain]
The loaded evaluators.
Examples
--------
.. code-block:: python
from langchain.evaluation import load_evaluators, EvaluatorType
evaluators = [EvaluatorType.QA, EvaluatorType.CRITERIA]
loaded_evaluators = load_evaluators(evaluators, criteria="helpfulness")
"""
llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
loaded = []
for evaluator in evaluators:
_kwargs = config.get(evaluator, {}) if config else {}
loaded.append(_load_evaluator(evaluator, llm=llm, **{**kwargs, **_kwargs}))
return loaded
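
To illustrate the config parameter, a hedged sketch layering per-evaluator keyword arguments over shared ones (criteria value illustrative; assumes an OPENAI_API_KEY):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.loading import load_evaluators
    from langchain.evaluation.schema import EvaluatorType

    llm = ChatOpenAI(model="gpt-4", temperature=0)
    evaluators = load_evaluators(
        [EvaluatorType.QA, EvaluatorType.CRITERIA],
        llm=llm,
        # Per-evaluator kwargs win over the shared **kwargs for that evaluator only.
        config={EvaluatorType.CRITERIA: {"criteria": "conciseness"}},
    )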

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
from typing import Any, List, Optional, Sequence
from pydantic import Extra
from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.evaluation.schema import EvalChain, StringEvaluator
def _parse_string_eval_output(text: str) -> dict:
@@ -38,9 +41,22 @@ def _parse_string_eval_output(text: str) -> dict:
}
class QAEvalChain(LLMChain):
class QAEvalChain(LLMChain, StringEvaluator, EvalChain):
"""LLM Chain specifically for evaluating question answering."""
class Config:
"""Configuration for the QAEvalChain."""
extra = Extra.ignore
@property
def evaluation_name(self) -> str:
return "correctness"
@property
def requires_reference(self) -> bool:
return True
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
@@ -134,7 +150,7 @@ class QAEvalChain(LLMChain):
return _parse_string_eval_output(result["text"])
class ContextQAEvalChain(LLMChain):
class ContextQAEvalChain(LLMChain, StringEvaluator, EvalChain):
"""LLM Chain specifically for evaluating QA w/o GT based on context"""
@classmethod
@@ -146,6 +162,10 @@ class ContextQAEvalChain(LLMChain):
f"but got {prompt.input_variables}"
)
@property
def evaluation_name(self) -> str:
return "Contextual Accuracy"
@classmethod
def from_llm(
cls,
@@ -226,6 +246,10 @@ class ContextQAEvalChain(LLMChain):
class CotQAEvalChain(ContextQAEvalChain):
"""LLM Chain specifically for evaluating QA using chain of thought reasoning."""
@property
def evaluation_name(self) -> str:
return "COT Contextual Accuracy"
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
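
The new properties on the QA chains can be checked with a quick sketch (model illustrative; assumes an OPENAI_API_KEY):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain

    llm = ChatOpenAI(model="gpt-4", temperature=0)
    qa = QAEvalChain.from_llm(llm=llm)
    print(qa.evaluation_name, qa.requires_reference)  # correctness True

    print(ContextQAEvalChain.from_llm(llm=llm).evaluation_name)  # Contextual Accuracy
    print(CotQAEvalChain.from_llm(llm=llm).evaluation_name)  # COT Contextual Accuracy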

View File

@@ -11,6 +11,9 @@ from langchain.evaluation.run_evaluators.implementations import (
get_qa_evaluator,
get_trajectory_evaluator,
)
from langchain.evaluation.run_evaluators.string_run_evaluator import (
StringRunEvaluatorChain,
)
__all__ = [
"RunEvaluatorChain",
@@ -21,4 +24,5 @@ __all__ = [
"get_trajectory_evaluator",
"StringRunEvaluatorInputMapper",
"ChoicesOutputParser",
"StringRunEvaluatorChain",
]

View File

@@ -21,6 +21,10 @@ class RunEvaluatorInputMapper:
def map(self, run: Run, example: Optional[Example] = None) -> Dict[str, Any]:
"""Maps the Run and Optional[Example] to a dictionary"""
def __call__(self, run: Run, example: Optional[Example] = None) -> Any:
"""Maps the Run and Optional[Example] to a dictionary"""
return self.map(run, example)
class RunEvaluatorOutputParser(BaseOutputParser[EvaluationResult]):
"""Parse the output of a run."""

View File

@@ -0,0 +1,69 @@
""""Loading helpers for run evaluators."""
from typing import Any, List, Optional, Sequence, Union
from langchainplus_sdk import RunEvaluator
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.evaluation.loading import load_evaluators
from langchain.evaluation.run_evaluators.string_run_evaluator import (
StringRunEvaluatorChain,
)
from langchain.evaluation.schema import EvaluatorType, StringEvaluator
from langchain.tools.base import Tool
def load_run_evaluators_for_model(
evaluators: Sequence[EvaluatorType],
model: Union[Chain, BaseLanguageModel, Tool],
*,
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
reference_key: Optional[str] = None,
eval_llm: Optional[BaseLanguageModel] = None,
**kwargs: Any,
) -> List[RunEvaluator]:
"""Load evaluators specified by a list of evaluator types.
Parameters
----------
evaluators : Sequence[EvaluatorType]
The list of evaluator types to load.
model : Union[Chain, BaseLanguageModel, Tool]
The model to evaluate. Used to infer how to parse the run.
input_key : Optional[str]
The chain run's input key to map to the evaluator's input.
prediction_key : Optional[str]
The key in the run's outputs to use as the chain prediction.
reference_key : Optional[str]
The key in the dataset example (row) outputs to use as the reference,
or ground-truth label.
eval_llm : BaseLanguageModel, optional
The language model to use for evaluation. If none is provided, a default
ChatOpenAI gpt-4 model is used.
**kwargs : Any
Additional keyword arguments to pass to all evaluators.
Returns
-------
List[RunEvaluator]
The loaded Run evaluators.
"""
evaluators_ = load_evaluators(evaluators, llm=eval_llm, **kwargs)
run_evaluators = []
for evaluator in evaluators_:
if isinstance(evaluator, StringEvaluator):
run_evaluator = StringRunEvaluatorChain.from_model_and_evaluator(
model,
evaluator,
input_key=input_key,
prediction_key=prediction_key,
reference_key=reference_key,
)
else:
raise NotImplementedError(
f"Run evaluator for {evaluator} is not implemented"
)
run_evaluators.append(run_evaluator)
return run_evaluators
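
An end-to-end sketch of this helper, under stated assumptions: the import path for the helper and the dataset's "answer" reference column are hypothetical, and an OPENAI_API_KEY is configured:

    from langchain import LLMChain, PromptTemplate
    from langchain.chat_models import ChatOpenAI

    # Hypothetical import path for the helper defined above.
    from langchain.evaluation.run_evaluators.loading import load_run_evaluators_for_model
    from langchain.evaluation.schema import EvaluatorType

    chain = LLMChain(
        llm=ChatOpenAI(temperature=0),
        prompt=PromptTemplate.from_template("Answer the question: {question}"),
    )
    run_evaluators = load_run_evaluators_for_model(
        [EvaluatorType.QA],
        chain,
        reference_key="answer",  # hypothetical column in the dataset outputs
    )
    # Each entry implements RunEvaluator.evaluate_run(run, example).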

View File

@@ -0,0 +1,138 @@
"""Run evaluator mapper for message evaluators."""
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union, TypedDict
from langchain.schema import BaseMessage
from langchainplus_sdk.schemas import Example, Run
from langchain.load.serializable import Serializable
from langchain.schema import messages_from_dict
class RunMapping(TypedDict):
prediction: BaseMessage
input: List[BaseMessage]
class ExampleMapping(TypedDict):
reference: BaseMessage
class MessageRunMapper(Serializable):
"""Extract items to evaluate from run object"""
@property
def output_keys(self) -> List[str]:
"""The keys to extract from the run."""
return ["prediction", "input"]
@abstractmethod
def map(self, run: Run) -> RunMapping:
"""Maps the Run to a dictionary."""
def __call__(self, run: Run) -> RunMapping:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
return self.map(run)
class MessageExampleMapper(Serializable):
"""Map an example, or row in the dataset, to the inputs of an evaluation."""
reference_key: Optional[str] = None
@property
def output_keys(self) -> List[str]:
"""The keys to extract from the run."""
return ["reference"]
def map(self, example: Example) -> ExampleMapping:
"""Maps the Example, or dataset row to a dictionary."""
if not example.outputs:
raise ValueError(
f"Example {example.id} has no outputs to use as a reference."
)
if self.reference_key is None:
if len(example.outputs) > 1:
raise ValueError(
f"Example {example.id} has multiple outputs, so you must"
" specify a reference_key."
)
else:
output = list(example.outputs.values())[0]
return {
"reference": output if isinstance(output, BaseMessage) else messages_from_dict([output])[0]
}
elif self.reference_key not in example.outputs:
raise ValueError(
f"Example {example.id} does not have reference key"
f" {self.reference_key}."
)
output = example.outputs[self.reference_key]
return {"reference": output if isinstance(output, BaseMessage) else messages_from_dict([output])[0]}
def __call__(self, example: Example) -> ExampleMapping:
"""Maps the Run and Example to a dictionary."""
if not example.outputs:
raise ValueError(
f"Example {example.id} has no outputs to use as areference label."
)
return self.map(example)
class ChatModelMessageRunMapper(MessageRunMapper):
"""Extract items to evaluate from run object."""
@staticmethod
def extract_inputs(inputs: Dict) -> List[BaseMessage]:
if not inputs.get("messages"):
raise ValueError("Run must have messages as inputs.")
if "messages" in inputs:
if isinstance(inputs["messages"], list) and inputs["messages"]:
if isinstance(inputs["messages"][0], BaseMessage):
return messages_from_dict(inputs["messages"])
elif isinstance(inputs["messages"][0], list):
# Runs from Tracer have messages as a list of lists of dicts
return messages_from_dict(inputs["messages"][0])
raise ValueError(f"Could not extract messages from inputs: {inputs}")
@staticmethod
def extract_outputs(outputs: Dict) -> BaseMessage:
if not outputs.get("generations"):
raise ValueError("LLM Run must have generations as outputs.")
first_generation: Dict = outputs["generations"][0]
if isinstance(first_generation, list):
# Runs from Tracer have generations as a list of lists of dicts
# Whereas Runs from the API have a list of dicts
first_generation = first_generation[0]
if "message" in first_generation:
return messages_from_dict([first_generation["message"]])[0]
def map(self, run: Run) -> RunMapping:
"""Maps the Run to a dictionary."""
if run.run_type != "llm":
raise ValueError("ChatModel RunMapper only supports LangSmith runs of type llm.")
elif not run.outputs:
if run.error:
raise ValueError(
f"Cannot evaluate errored LLM run {run.id}: {run.error}"
)
else:
raise ValueError(
f"Run {run.id} has no outputs. Cannot evaluate this run."
)
else:
try:
inputs = self.extract_inputs(run.inputs)
except Exception as e:
raise ValueError(
f"Could not parse LM input from run inputs {run.inputs}"
) from e
try:
output_ = self.extract_outputs(run.outputs)
except Exception as e:
raise ValueError(
f"Could not parse LM prediction from run outputs {run.outputs}"
) from e
return {"input": inputs, "prediction": output_}

View File

@@ -0,0 +1,392 @@
"""Run evaluator wrapper for string evaluators."""
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union, Protocol
from langchainplus_sdk import EvaluationResult, RunEvaluator
from langchainplus_sdk.schemas import Example, Run
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.evaluation.schema import StringEvaluator, MessageEvaluator
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, get_buffer_string, messages_from_dict
from langchain.tools.base import Tool
class StringRunMapper(Serializable):
"""Extract items to evaluate from the run object."""
@property
def output_keys(self) -> List[str]:
"""The keys to extract from the run."""
return ["prediction", "input"]
@abstractmethod
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
def __call__(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
return self.map(run)
class LLMStringRunMapper(StringRunMapper):
"""Extract items to evaluate from the run object."""
def serialize_chat_messages(self, messages: List[Dict]) -> str:
"""Extract the input messages from the run."""
chat_messages = messages_from_dict(messages)
return get_buffer_string(chat_messages)
def serialize_inputs(self, inputs: Dict) -> str:
if "prompts" in inputs: # Should we even accept this?
input_ = "\n\n".join(inputs["prompts"])
elif "prompt" in inputs:
input_ = inputs["prompt"]
elif "messages" in inputs:
input_ = self.serialize_chat_messages(inputs["messages"])
else:
raise ValueError("LLM Run must have either messages or prompts as inputs.")
return input_
def serialize_outputs(self, outputs: Dict) -> str:
if not outputs.get("generations"):
raise ValueError("LLM Run must have generations as outputs.")
first_generation: Dict = outputs["generations"][0]
if isinstance(first_generation, list):
# Runs from Tracer have generations as a list of lists of dicts
# Whereas Runs from the API have a list of dicts
first_generation = first_generation[0]
if "message" in first_generation:
output_ = self.serialize_chat_messages([first_generation["message"]])
else:
output_ = first_generation["text"]
return output_
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if run.run_type != "llm":
raise ValueError("LLM RunMapper only supports LLM runs.")
elif not run.outputs:
if run.error:
raise ValueError(
f"Cannot evaluate errored LLM run {run.id}: {run.error}"
)
else:
raise ValueError(
f"Run {run.id} has no outputs. Cannot evaluate this run."
)
else:
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:
raise ValueError(
f"Could not parse LM input from run inputs {run.inputs}"
) from e
try:
output_ = self.serialize_outputs(run.outputs)
except Exception as e:
raise ValueError(
f"Could not parse LM prediction from run outputs {run.outputs}"
) from e
return {"input": inputs, "prediction": output_}
class ChainStringRunMapper(StringRunMapper):
"""Extract items to evaluate from the run object from a chain."""
input_key: str
"""The key from the chain Run's inputs to use as the eval input."""
prediction_key: str
"""The key from the chain Run's outputs to use as the eval prediction."""
@classmethod
def from_chain(
cls,
model: Chain,
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
) -> ChainStringRunMapper:
"""Create a RunMapper from a chain."""
error_messages = []
if input_key is None:
if len(model.input_keys) > 1:
error_messages.append(
f"Chain {model.lc_namespace} has multiple input"
" keys. Please specify 'input_key' when loading."
)
else:
input_key = model.input_keys[0]
elif input_key not in model.input_keys:
error_messages.append(
f"Chain {model.lc_namespace} does not have specified"
f" input key {input_key}."
)
if prediction_key is None:
if len(model.output_keys) > 1:
error_messages.append(
f"Chain {model.lc_namespace} has multiple"
" output keys. Please specify 'prediction_key' when loading."
)
else:
prediction_key = model.output_keys[0]
elif prediction_key not in model.output_keys:
error_messages.append(
f"Chain {model.lc_namespace} does not have specified"
f" prediction_key {prediction_key}."
)
if error_messages:
raise ValueError("\n".join(error_messages))
if input_key is None or prediction_key is None:
# This should never happen, but mypy doesn't know that.
raise ValueError(f"Chain {model.lc_namespace} has no input or output keys.")
return cls(input_key=input_key, prediction_key=prediction_key)
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
if run.run_type != "chain":
raise ValueError("Chain RunMapper only supports Chain runs.")
if self.input_key not in run.inputs:
raise ValueError(f"Run {run.id} does not have input key {self.input_key}.")
elif self.prediction_key not in run.outputs:
raise ValueError(
f"Run {run.id} does not have prediction key {self.prediction_key}."
)
else:
return {
"input": run.inputs[self.input_key],
"prediction": run.outputs[self.prediction_key],
}
class ToolStringRunMapper(StringRunMapper):
"""Map an input to the tool."""
def map(self, run: Run) -> Dict[str, str]:
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
return {"input": run.inputs["input"], "prediction": run.outputs["output"]}
class StringExampleMapper(Serializable):
"""Map an example, or row in the dataset, to the inputs of an evaluation."""
reference_key: Optional[str] = None
@property
def output_keys(self) -> List[str]:
"""The keys to extract from the run."""
return ["reference"]
def serialize_chat_messages(self, messages: List[Dict]) -> str:
"""Extract the input messages from the run."""
chat_messages = messages_from_dict(messages)
return get_buffer_string(chat_messages)
def map(self, example: Example) -> Dict[str, str]:
"""Maps the Example, or dataset row to a dictionary."""
if not example.outputs:
raise ValueError(
f"Example {example.id} has no outputs to use as a reference."
)
if self.reference_key is None:
if len(example.outputs) > 1:
raise ValueError(
f"Example {example.id} has multiple outputs, so you must"
" specify a reference_key."
)
else:
output = list(example.outputs.values())[0]
return {
"reference": output
if isinstance(output, str)
else self.serialize_chat_messages([output])
}
elif self.reference_key not in example.outputs:
raise ValueError(
f"Example {example.id} does not have reference key"
f" {self.reference_key}."
)
return {"reference": example.outputs[self.reference_key]}
def __call__(self, example: Example) -> Dict[str, Any]:
"""Maps the Run and Example to a dictionary."""
if not example.outputs:
raise ValueError(
f"Example {example.id} has no outputs to use as areference label."
)
return self.map(example)
# TODO(agola11) can make these abstract classes
class BaseRunMapper(Protocol):
def map(self, run: Run) -> Dict[str, Any]: ...
class BaseExampleMapper(Protocol):
def map(self, example: Example) -> Dict[str, Any]: ...
class SimpleRunEvaluatorChain(Chain, RunEvaluator):
"""Evaluate Run and optional examples."""
run_mapper: BaseRunMapper
"""Maps the Run to a dictionary with 'input' and 'prediction' strings."""
example_mapper: Optional[BaseExampleMapper] = None
"""Maps the Example (dataset row) to a dictionary
with a 'reference' string."""
name: str
"""The name of the evaluation metric."""
evaluator: Union[StringEvaluator, MessageEvaluator]
"""The evaluation chain."""
@property
def input_keys(self) -> List[str]:
return ["run", "example"]
@property
def output_keys(self) -> List[str]:
return ["feedback"]
def _prepare_input(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
run: Run = inputs["run"]
example: Optional[Example] = inputs.get("example")
evaluate_inputs = self.run_mapper.map(run)
if self.example_mapper:
if not example:
raise ValueError(
f"Evaluator {self.name} requires an reference"
" example from the dataset,"
f" but none was provided for run {run.id}."
)
evaluate_inputs.update(self.example_mapper.map(example))
return evaluate_inputs
def _prepare_output(self, output: Dict[str, Any]) -> EvaluationResult:
evaluation_result = EvaluationResult(key=self.name, **output)
if RUN_KEY in output:
# TODO: Not currently surfaced. Update
evaluation_result.evaluator_info[RUN_KEY] = output[RUN_KEY]
return evaluation_result
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
evaluate_inputs = self._prepare_input(inputs)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
if isinstance(self.evaluator, StringEvaluator):
chain_output = self.evaluator.evaluate_strings(
**evaluate_inputs,
callbacks=callbacks,
)
elif isinstance(self.evaluator, MessageEvaluator):
chain_output = self.evaluator.evaluate_messages(
**evaluate_inputs,
callbacks=callbacks,
)
else:
raise ValueError("Unsupported evaluator type")
evaluation_result = self._prepare_output(chain_output)
return {"feedback": evaluation_result}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun | None = None,
) -> Dict[str, Any]:
evaluate_inputs = self._prepare_input(inputs)
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
if isinstance(self.evaluator, StringEvaluator):
chain_output = await self.evaluator.aevaluate_strings(
**evaluate_inputs,
callbacks=callbacks,
)
elif isinstance(self.evaluator, MessageEvaluator):
chain_output = await self.evaluator.aevaluate_messages(
**evaluate_inputs,
callbacks=callbacks,
)
else:
raise ValueError("Unsupported evaluator type")
evaluation_result = self._prepare_output(chain_output)
return {"feedback": evaluation_result}
def evaluate_run(
self, run: Run, example: Optional[Example] = None
) -> EvaluationResult:
"""Evaluate an example."""
return self({"run": run, "example": example})["feedback"]
async def aevaluate_run(
self, run: Run, example: Optional[Example] = None
) -> EvaluationResult:
"""Evaluate an example."""
result = await self.acall({"run": run, "example": example})
return result["feedback"]
# TODO: Add ability to load message evaluators
@classmethod
def from_model_and_evaluator(
cls,
model: Union[Chain, BaseLanguageModel, Tool],
evaluator: Union[StringEvaluator, MessageEvaluator],
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
reference_key: Optional[str] = None,
) -> SimpleRunEvaluatorChain:
"""Create a StringRunEvaluatorChain from a model and evaluator."""
if isinstance(evaluator, StringEvaluator):
if isinstance(model, BaseLanguageModel):
run_mapper: StringRunMapper = LLMStringRunMapper()
elif isinstance(model, Chain):
run_mapper = ChainStringRunMapper.from_chain(
model, input_key=input_key, prediction_key=prediction_key
)
elif isinstance(model, Tool):
run_mapper = ToolStringRunMapper()
else:
raise NotImplementedError(
f"{cls.__name__}.from_model_and_evaluator({type(model)})"
" not yet implemented."
"Expected one of [BaseLanguageModel, Chain, Tool]."
)
if reference_key is not None or isinstance(model, BaseLanguageModel):
example_mapper = StringExampleMapper(reference_key=reference_key)
elif evaluator.requires_reference:
raise ValueError(
f"Evaluator {evaluator.evaluation_name} requires a reference"
" example from the dataset. Please specify the reference key from"
" amongst the dataset outputs keys."
)
else:
example_mapper = None
elif isinstance(evaluator, MessageEvaluator):
raise NotImplementedError()
else:
raise NotImplementedError()
return cls(
name=evaluator.evaluation_name,
run_mapper=run_mapper,
example_mapper=example_mapper,
evaluator=evaluator,
)
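
Tying the pieces together, a hedged sketch of wiring an off-the-shelf string evaluator to traced runs, mirroring the unit tests further down (the class name is still in flux on this branch, so the import is an assumption; the reference key is illustrative):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation.loading import load_evaluators

    # Assumed import; other call sites on this branch still refer to this
    # class as StringRunEvaluatorChain.
    from langchain.evaluation.run_evaluators.string_run_evaluator import (
        SimpleRunEvaluatorChain,
    )
    from langchain.evaluation.schema import EvaluatorType

    evaluator = load_evaluators(
        [EvaluatorType.QA], llm=ChatOpenAI(model="gpt-4", temperature=0)
    )[0]
    model = ChatOpenAI(temperature=0)  # the model whose runs we want graded
    run_evaluator = SimpleRunEvaluatorChain.from_model_and_evaluator(
        model,
        evaluator,
        reference_key="answer",  # hypothetical reference column in the dataset
    )
    # feedback = run_evaluator.evaluate_run(run, example)  # run/example via LangSmith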

View File

@@ -1,12 +1,56 @@
"""Interfaces to be implemented by general evaluators."""
from abc import abstractmethod
from typing import Any, Optional, Protocol, runtime_checkable
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional, List
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.schema import BaseMessage, get_buffer_string
@runtime_checkable
class StringEvaluator(Protocol):
class EvaluatorType(str, Enum):
"""The types of the evaluators."""
QA = "qa"
"""Question answering evaluator, which grades answers to questions
directly using an LLM."""
COT_QA = "cot_qa"
"""Chain of thought question answering evaluator, which grades
answers to questions using
chain of thought 'reasoning'."""
CONTEXT_QA = "context_qa"
"""Question answering evaluator that incorporates 'context' in the response."""
PAIRWISE_STRING = "pairwise_string"
"""The pairwise string evaluator, which compares the output of two models."""
AGENT_TRAJECTORY = "trajectory"
"""The agent trajectory evaluator, which grades the agent's intermediate steps."""
CRITERIA = "criteria"
"""The criteria evaluator, which evaluates a model based on a
custom set of criteria."""
class EvalChain(Chain):
"""A base class for evaluators that use an LLM."""
@classmethod
@abstractmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> EvalChain:
"""Create a new evaluator from an LLM."""
class StringEvaluator(ABC):
"""Protocol for evaluating strings."""
@property
def evaluation_name(self) -> str:
raise NotImplementedError()
@property
def requires_reference(self) -> bool:
return False
@abstractmethod
def evaluate_strings(
self,
@@ -26,6 +70,10 @@ class StringEvaluator(Protocol):
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
"""
async def aevaluate_strings(
@@ -47,6 +95,10 @@ class StringEvaluator(Protocol):
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
"""
raise NotImplementedError(
f"{self.__class__.__name__} hasn't implemented an "
@@ -54,8 +106,114 @@ class StringEvaluator(Protocol):
)
@runtime_checkable
class PairwiseStringEvaluator(Protocol):
class MessageEvaluator(ABC):
"""Protocol for evaluating messages."""
@property
def evaluation_name(self) -> str:
raise NotImplementedError()
@property
def requires_reference(self) -> bool:
return False
@abstractmethod
def evaluate_messages(
self,
*,
prediction: BaseMessage,
reference: Optional[BaseMessage] = None,
input: Optional[List[BaseMessage]] = None,
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (BaseMessage): the prediction to evaluate.
reference (Optional[BaseMessage], optional): the reference label
to evaluate against.
input (Optional[List[BaseMessage]], optional): the input to consider during
evaluation
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
"""
async def aevaluate_messages(
self,
*,
prediction: BaseMessage,
reference: Optional[BaseMessage] = None,
input: Optional[List[BaseMessage]] = None,
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional
input and label.
Args:
prediction (BaseMessage): the prediction to evaluate.
reference (Optional[BaseMessage], optional): the reference label
to evaluate against.
input (Optional[List[BaseMessage]], optional): the input to consider during
evaluation
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
"""
raise NotImplementedError(
f"{self.__class__.__name__} hasn't implemented an "
"async aevaluate_messages method."
)
# TODO(agola11): move this out of schema
class SimpleMessageEvaluator(MessageEvaluator):
"""Simple implementation of MessageEvaluator that delegates to a StringEvaluator."""
def __init__(self, string_evaluator: StringEvaluator):
self.string_evaluator = string_evaluator
def evaluate_messages(
self,
*,
prediction: BaseMessage,
reference: Optional[BaseMessage] = None,
input: Optional[List[BaseMessage]] = None,
**kwargs: Any,
) -> dict:
return self.string_evaluator.evaluate_strings(
prediction=get_buffer_string([prediction]),
reference=get_buffer_string([reference]) if reference else None,
input=get_buffer_string(input) if input else None,
**kwargs,
)
async def aevaluate_messages(
self,
*,
prediction: BaseMessage,
reference: Optional[BaseMessage] = None,
input: Optional[List[BaseMessage]] = None,
**kwargs: Any,
) -> dict:
return await self.string_evaluator.aevaluate_strings(
prediction=get_buffer_string([prediction]),
reference=get_buffer_string([reference]) if reference else None,
input=get_buffer_string(input) if input else None,
**kwargs,
)
class PairwiseStringEvaluator(ABC):
"""A protocol for comparing the output of two models."""
@abstractmethod
@@ -86,6 +244,7 @@ class PairwiseStringEvaluator(Protocol):
async def aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
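
To show what the move from typing.Protocol to ABC implies for custom evaluators, a minimal subclass sketch (the matching logic is illustrative):

    from typing import Any, Optional

    from langchain.evaluation.schema import StringEvaluator


    class ExactMatchEvaluator(StringEvaluator):
        """Toy evaluator: exact string match against the reference."""

        @property
        def evaluation_name(self) -> str:
            return "exact_match"

        @property
        def requires_reference(self) -> bool:
            return True

        def evaluate_strings(
            self,
            *,
            prediction: str,
            reference: Optional[str] = None,
            input: Optional[str] = None,
            **kwargs: Any,
        ) -> dict:
            return {"score": int(prediction == reference)}


    assert issubclass(ExactMatchEvaluator, StringEvaluator)  # cf. the test updates below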

View File

@@ -168,7 +168,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
elif _type == "chat":
return ChatMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:

View File

@@ -1,13 +1,15 @@
"""Test agent trajectory evaluation chain."""
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import pytest
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.schema import AgentAction
from langchain.schema import AgentAction, BaseMessage
from langchain.tools.base import tool
from tests.unit_tests.llms.fake_llm import FakeLLM
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
@pytest.fixture
@@ -30,10 +32,31 @@ def foo(bar: str) -> str:
return bar
class _FakeTrajectoryChatModel(FakeChatModel):
queries: Dict = Field(default_factory=dict)
sequential_responses: Optional[bool] = False
response_index: int = 0
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.sequential_responses:
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
else:
prompt = messages[0].content
return self.queries[prompt]
def test_trajectory_eval_chain(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",
@@ -61,7 +84,7 @@ def test_trajectory_eval_chain(
def test_trajectory_eval_chain_no_tools(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",
@@ -85,7 +108,7 @@ def test_trajectory_eval_chain_no_tools(
def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",

View File

@@ -32,4 +32,4 @@ def test_criteria_eval_chain() -> None:
def test_implements_string_protocol() -> None:
assert isinstance(CriteriaEvalChain, StringEvaluator)
assert issubclass(CriteriaEvalChain, StringEvaluator)

View File

@@ -52,7 +52,7 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
def test_implements_string_evaluator_protocol(
chain_cls: Type[LLMChain],
) -> None:
assert isinstance(chain_cls, StringEvaluator)
assert issubclass(chain_cls, StringEvaluator)
@pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])

View File

@@ -0,0 +1,114 @@
"""Test the loading function for evalutors."""
from unittest.mock import MagicMock
import pytest
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.evaluation.loading import load_evaluators
from langchain.evaluation.run_evaluators.string_run_evaluator import (
StringRunEvaluatorChain,
)
from langchain.evaluation.schema import StringEvaluator
from tests.unit_tests.chains.test_base import FakeChain
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.parametrize("evaluator_type", ["qa", "cot_qa", "context_qa", "criteria"])
def test_load_string_run_evaluators_with_llm(evaluator_type: str) -> None:
"""Test loading evaluators."""
fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
)
evaluator = load_evaluators([evaluator_type], llm=fake_llm)[0] # type: ignore
if not isinstance(evaluator, StringEvaluator):
raise ValueError("Evaluator is not a string evaluator")
model = FakeLLM(queries={"text": "Foo output"}, sequential_responses=True)
kwargs = {}
if evaluator.requires_reference:
kwargs["reference_key"] = "generations"
run_evaluator = StringRunEvaluatorChain.from_model_and_evaluator(
model, evaluator, **kwargs
)
callback = RunCollectorCallbackHandler()
model.predict("Foo input", callbacks=[callback])
run = callback.traced_runs[0]
example = MagicMock()
example.inputs = {}
example.outputs = {"generations": "Foo output"}
result = run_evaluator._prepare_input({"run": run, "example": example})
assert result["input"] == "Foo input"
assert result["prediction"] == "Foo output"
if evaluator.requires_reference:
assert "reference" in result
assert result["reference"] == "Foo output"
@pytest.mark.parametrize("evaluator_type", ["qa", "cot_qa", "context_qa", "criteria"])
def test_load_string_run_evaluators_with_chat_model(evaluator_type: str) -> None:
"""Test loading evaluators."""
fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
)
evaluator = load_evaluators([evaluator_type], llm=fake_llm)[0] # type: ignore
if not isinstance(evaluator, StringEvaluator):
raise ValueError("Evaluator is not a string evaluator")
model = FakeChatModel()
kwargs = {}
if evaluator.requires_reference:
kwargs["reference_key"] = "generations"
run_evaluator = StringRunEvaluatorChain.from_model_and_evaluator(
model, evaluator, **kwargs
)
callback = RunCollectorCallbackHandler()
model.predict("Foo input", callbacks=[callback])
run = callback.traced_runs[0]
example = MagicMock()
example.inputs = {}
example.outputs = {"generations": "Another fake response"}
result = run_evaluator._prepare_input({"run": run, "example": example})
assert result["input"] == "Human: Foo input"
assert result["prediction"] == "fake response"
if evaluator.requires_reference:
assert "reference" in result
assert result["reference"] == "Another fake response"
@pytest.mark.parametrize("evaluator_type", ["qa", "cot_qa", "context_qa", "criteria"])
def test_load_string_run_evaluators_with_chain(evaluator_type: str) -> None:
model = FakeChain(
the_input_keys=["an_input", "another_input"],
)
fake_llm = FakeChatModel()
evaluator = load_evaluators([evaluator_type], llm=fake_llm)[0] # type: ignore
if not isinstance(evaluator, StringEvaluator):
raise ValueError("Evaluator is not a string evaluator")
# No input key
with pytest.raises(ValueError, match="multiple input keys"):
StringRunEvaluatorChain.from_model_and_evaluator(model, evaluator)
with pytest.raises(ValueError, match="does not have specified"):
StringRunEvaluatorChain.from_model_and_evaluator(
model, evaluator, input_key="some_input"
)
kwargs = {}
if evaluator.requires_reference:
kwargs["reference_key"] = "label_column"
run_evaluator = StringRunEvaluatorChain.from_model_and_evaluator(
model, evaluator, input_key="an_input", **kwargs
)
callback = RunCollectorCallbackHandler()
model(
{"an_input": "Foo input", "another_input": "Another fake response"},
callbacks=[callback],
)
run = callback.traced_runs[0]
example = MagicMock()
example.inputs = {}
example.outputs = {"label_column": "Another fake response"}
result = run_evaluator._prepare_input({"run": run, "example": example})
assert result["input"] == "Foo input"
assert result["prediction"] == "baz"
if evaluator.requires_reference:
assert "reference" in result
assert result["reference"] == "Another fake response"

View File

@@ -0,0 +1,31 @@
"""Test the loading function for evalutors."""
import pytest
from langchain.evaluation.loading import EvaluatorType, load_evaluators
from langchain.evaluation.schema import StringEvaluator
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.parametrize("evaluator_type", EvaluatorType)
def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
"""Test loading evaluators."""
fake_llm = FakeChatModel()
load_evaluators([evaluator_type], llm=fake_llm)
# Test as string
load_evaluators([evaluator_type.value], llm=fake_llm) # type: ignore
def test_criteria_eval_chain_requires_reference() -> None:
"""Test loading evaluators."""
fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
)
evaluator = load_evaluators(
[EvaluatorType.CRITERIA], llm=fake_llm, requires_reference=True
)[0]
if not isinstance(evaluator, StringEvaluator):
raise ValueError("Evaluator is not a string evaluator")
assert evaluator.requires_reference