Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-05 11:12:47 +00:00)
Load Evaluator (#6942)
Create a `load_evaluators()` function so you don't have to import all the individual evaluator classes
This commit is contained in: parent 12d14f8947, commit e736d60516
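In practice, the new loader collapses several imports into a single call. A minimal sketch based on the docstrings added in this commit (it assumes an OpenAI API key is configured, since the default evaluation LLM is `ChatOpenAI` with `gpt-4`):

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType, load_evaluator, load_evaluators

llm = ChatOpenAI(model="gpt-4", temperature=0)

# Load a single evaluator...
qa_evaluator = load_evaluator(EvaluatorType.QA, llm=llm)

# ...or several at once, without importing QAEvalChain, CriteriaEvalChain, etc.
evaluators = load_evaluators(
    [EvaluatorType.QA, EvaluatorType.CRITERIA], llm=llm, criteria="helpfulness"
)
```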
langchain/evaluation/__init__.py
@@ -1,33 +1,45 @@
-"""Functionality relating to evaluation.
+"""Evaluation chains for grading LLM and Chain outputs.
 
-This module contains off-the-shelf evaluation chains for
-grading the output of LangChain primitives such as LLMs and Chains.
+This module contains off-the-shelf evaluation chains for grading the output of
+LangChain primitives such as language models and chains.
+
+To load an evaluator, you can use the :func:`load_evaluators <langchain.evaluation.loading.load_evaluators>` function with the
+names of the evaluators to load.
+
+To load one of the LangChain HuggingFace datasets, you can use the :func:`load_dataset <langchain.evaluation.loading.load_dataset>` function with the
+name of the dataset to load.
 
 Some common use cases for evaluation include:
 
-- Grading accuracy of a response against ground truth answers: QAEvalChain
-- Comparing the output of two models: PairwiseStringEvalChain
-- Judging the efficacy of an agent's tool usage: TrajectoryEvalChain
-- Checking whether an output complies with a set of criteria: CriteriaEvalChain
+- Grading the accuracy of a response against ground truth answers: :class:`QAEvalChain <langchain.evaluation.qa.eval_chain.QAEvalChain>`
+- Comparing the output of two models: :class:`PairwiseStringEvalChain <langchain.evaluation.comparison.eval_chain.PairwiseStringEvalChain>`
+- Judging the efficacy of an agent's tool usage: :class:`TrajectoryEvalChain <langchain.evaluation.agents.trajectory_eval_chain.TrajectoryEvalChain>`
+- Checking whether an output complies with a set of criteria: :class:`CriteriaEvalChain <langchain.evaluation.criteria.eval_chain.CriteriaEvalChain>`
 
-This module also contains low level APIs for making more evaluators for your
-custom evaluation task. These include:
-- StringEvaluator: Evaluates an output string against a reference and/or
-    with input context.
-- PairwiseStringEvaluator: Evaluates two strings against each other.
-"""
+This module also contains low-level APIs for creating custom evaluators for
+specific evaluation tasks. These include:
+
+- :class:`StringEvaluator <langchain.evaluation.schema.StringEvaluator>`: Evaluate a prediction string against a reference label and/or input context.
+- :class:`PairwiseStringEvaluator <langchain.evaluation.schema.PairwiseStringEvaluator>`: Evaluate two prediction strings against each other.
+    Useful for scoring preferences, measuring similarity between two chain or llm agents, or comparing outputs on similar inputs.
+- :class:`AgentTrajectoryEvaluator <langchain.evaluation.schema.AgentTrajectoryEvaluator>`: Evaluate the full sequence of actions
+    taken by an agent.
+
+"""  # noqa: E501
 
-from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
+from langchain.evaluation.agents import TrajectoryEvalChain
 from langchain.evaluation.comparison import PairwiseStringEvalChain
-from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
+from langchain.evaluation.criteria import CriteriaEvalChain
+from langchain.evaluation.loading import load_dataset, load_evaluator, load_evaluators
 from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
 from langchain.evaluation.schema import (
     AgentTrajectoryEvaluator,
+    EvaluatorType,
     PairwiseStringEvaluator,
     StringEvaluator,
 )
 
 __all__ = [
+    "EvaluatorType",
     "PairwiseStringEvalChain",
     "QAEvalChain",
     "CotQAEvalChain",
@@ -36,5 +48,8 @@ __all__ = [
     "PairwiseStringEvaluator",
     "TrajectoryEvalChain",
     "CriteriaEvalChain",
+    "load_evaluators",
+    "load_evaluator",
+    "load_dataset",
     "AgentTrajectoryEvaluator",
 ]
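The docstring above also points at `load_dataset` for pulling evaluation data from the `LangChainDatasets` HuggingFace organization. A small sketch of what that looks like (the dataset name is illustrative, and the optional `datasets` package must be installed):

```python
# Requires `pip install datasets`. "llm-math" is an illustrative dataset name;
# load_dataset resolves it to the HuggingFace repo LangChainDatasets/llm-math.
from langchain.evaluation import load_dataset

examples = load_dataset("llm-math")
print(examples[0])  # each example is a dict taken from the dataset's "train" split
```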
langchain/evaluation/agents/trajectory_eval_chain.py
@@ -7,7 +7,7 @@ chain (LLMChain) to generate the reasoning and scores.
 
 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
@@ -15,14 +15,13 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chat_models.base import BaseChatModel
 from langchain.evaluation.agents.trajectory_eval_prompt import (
     EVAL_CHAT_PROMPT,
     TOOL_FREE_EVAL_CHAT_PROMPT,
 )
-from langchain.evaluation.schema import AgentTrajectoryEvaluator
+from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
 from langchain.schema import AgentAction, BaseOutputParser, OutputParserException
 from langchain.tools.base import BaseTool
 
@@ -71,7 +70,7 @@ class TrajectoryOutputParser(BaseOutputParser):
         return TrajectoryEval(score=int(score_str), reasoning=reasoning)
 
 
-class TrajectoryEvalChain(AgentTrajectoryEvaluator, Chain):
+class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
     """A chain for evaluating ReAct style agents.
 
     This chain is used to evaluate ReAct style agents by reasoning about
@@ -125,6 +124,11 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, Chain):
     return_reasoning: bool = False
     """Whether to return the reasoning along with the score."""
 
+    class Config:
+        """Configuration for the TrajectoryEvalChain."""
+
+        extra = Extra.ignore
+
     @property
     def _tools_description(self) -> str:
         """Get the description of the agent tools.
langchain/evaluation/comparison/eval_chain.py
@@ -3,13 +3,13 @@ from __future__ import annotations
 
 from typing import Any, Optional
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
-from langchain.evaluation.schema import PairwiseStringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, PairwiseStringEvaluator
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseOutputParser
 
@@ -51,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
         }
 
 
-class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
+class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
     """A chain for comparing the output of two models.
 
     Example:
@@ -81,6 +81,11 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
         default_factory=PairwiseStringResultOutputParser
     )
 
+    class Config:
+        """Configuration for the PairwiseStringEvalChain."""
+
+        extra = Extra.ignore
+
     @property
     def requires_reference(self) -> bool:
         return "reference" in self.prompt.input_variables
langchain/evaluation/criteria/eval_chain.py
@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
-from langchain.evaluation.schema import StringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
@@ -60,7 +60,7 @@ CRITERIA_TYPE = Union[
 ]
 
 
-class CriteriaEvalChain(StringEvaluator, LLMChain):
+class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
     """LLM Chain for evaluating runs against criteria.
 
     Parameters
langchain/evaluation/loading.py
@@ -1,8 +1,108 @@
-from typing import Dict, List
+"""Loading datasets and evaluators."""
+from typing import Any, Dict, List, Optional, Sequence, Type
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
+from langchain.chat_models.openai import ChatOpenAI
+from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
+from langchain.evaluation.comparison import PairwiseStringEvalChain
+from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
+from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
+from langchain.evaluation.schema import EvaluatorType, LLMEvalChain
 
 
 def load_dataset(uri: str) -> List[Dict]:
-    from datasets import load_dataset
+    """Load a dataset from the LangChainDatasets HuggingFace org."""
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError(
+            "load_dataset requires the `datasets` package."
+            " Please install with `pip install datasets`"
+        )
 
     dataset = load_dataset(f"LangChainDatasets/{uri}")
     return [d for d in dataset["train"]]
+
+
+_EVALUATOR_MAP: Dict[EvaluatorType, Type[LLMEvalChain]] = {
+    EvaluatorType.QA: QAEvalChain,
+    EvaluatorType.COT_QA: CotQAEvalChain,
+    EvaluatorType.CONTEXT_QA: ContextQAEvalChain,
+    EvaluatorType.PAIRWISE_STRING: PairwiseStringEvalChain,
+    EvaluatorType.AGENT_TRAJECTORY: TrajectoryEvalChain,
+    EvaluatorType.CRITERIA: CriteriaEvalChain,
+}
+
+
+def load_evaluator(
+    evaluator: EvaluatorType,
+    *,
+    llm: Optional[BaseLanguageModel] = None,
+    **kwargs: Any,
+) -> Chain:
+    """Load the requested evaluation chain specified by a string.
+
+    Parameters
+    ----------
+    evaluator : EvaluatorType
+        The type of evaluator to load.
+    llm : BaseLanguageModel, optional
+        The language model to use for evaluation, by default None
+    **kwargs : Any
+        Additional keyword arguments to pass to the evaluator.
+
+    Returns
+    -------
+    Chain
+        The loaded evaluation chain.
+
+    Examples
+    --------
+    >>> llm = ChatOpenAI(model="gpt-4", temperature=0)
+    >>> evaluator = load_evaluator(EvaluatorType.QA, llm=llm)
+    """
+    llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
+    return _EVALUATOR_MAP[evaluator].from_llm(llm=llm, **kwargs)
+
+
+def load_evaluators(
+    evaluators: Sequence[EvaluatorType],
+    *,
+    llm: Optional[BaseLanguageModel] = None,
+    config: Optional[dict] = None,
+    **kwargs: Any,
+) -> List[Chain]:
+    """Load evaluators specified by a list of evaluator types.
+
+    Parameters
+    ----------
+    evaluators : Sequence[EvaluatorType]
+        The list of evaluator types to load.
+    llm : BaseLanguageModel, optional
+        The language model to use for evaluation, if none is provided, a default
+        ChatOpenAI gpt-4 model will be used.
+    config : dict, optional
+        A dictionary mapping evaluator types to additional keyword arguments,
+        by default None
+    **kwargs : Any
+        Additional keyword arguments to pass to all evaluators.
+
+    Returns
+    -------
+    List[Chain]
+        The loaded evaluators.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        from langchain.evaluation import load_evaluators, EvaluatorType
+
+        evaluators = [EvaluatorType.QA, EvaluatorType.CRITERIA]
+        loaded_evaluators = load_evaluators(evaluators, criteria="helpfulness")
+    """
+    llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
+    loaded = []
+    for evaluator in evaluators:
+        _kwargs = config.get(evaluator, {}) if config else {}
+        loaded.append(load_evaluator(evaluator, llm=llm, **{**kwargs, **_kwargs}))
+    return loaded
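The `config` argument above is the hook for per-evaluator options: its keys are evaluator types and its values are keyword arguments merged on top of the shared `**kwargs` for that evaluator only. A minimal sketch of how that might be used (assuming an OpenAI API key is configured, since the default LLM is `ChatOpenAI` with `gpt-4`):

```python
from langchain.evaluation import EvaluatorType, load_evaluators

# Shared kwargs apply to every evaluator; config entries apply only to the
# matching evaluator type and override the shared kwargs on conflict.
evaluators = load_evaluators(
    [EvaluatorType.QA, EvaluatorType.CRITERIA],
    config={EvaluatorType.CRITERIA: {"criteria": "helpfulness"}},
)
```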
langchain/evaluation/qa/eval_chain.py
@@ -3,12 +3,14 @@ from __future__ import annotations
 
 from typing import Any, List, Optional, Sequence
 
+from pydantic import Extra
+
 from langchain import PromptTemplate
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
-from langchain.evaluation.schema import StringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 
 
 def _parse_string_eval_output(text: str) -> dict:
@@ -39,9 +41,14 @@ def _parse_string_eval_output(text: str) -> dict:
     }
 
 
-class QAEvalChain(LLMChain, StringEvaluator):
+class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     """LLM Chain specifically for evaluating question answering."""
 
+    class Config:
+        """Configuration for the QAEvalChain."""
+
+        extra = Extra.ignore
+
     @property
     def requires_reference(self) -> bool:
         return True
@@ -143,7 +150,7 @@ class QAEvalChain(LLMChain, StringEvaluator):
         return _parse_string_eval_output(result["text"])
 
 
-class ContextQAEvalChain(LLMChain, StringEvaluator):
+class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     """LLM Chain specifically for evaluating QA w/o GT based on context"""
 
     @property
langchain/evaluation/schema.py
@@ -3,14 +3,47 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Optional, Sequence, Tuple
 from warnings import warn
 
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
 from langchain.schema.agent import AgentAction
 
 logger = logging.getLogger(__name__)
 
 
+class EvaluatorType(str, Enum):
+    """The types of the evaluators."""
+
+    QA = "qa"
+    """Question answering evaluator, which grades answers to questions
+    directly using an LLM."""
+    COT_QA = "cot_qa"
+    """Chain of thought question answering evaluator, which grades
+    answers to questions using
+    chain of thought 'reasoning'."""
+    CONTEXT_QA = "context_qa"
+    """Question answering evaluator that incorporates 'context' in the response."""
+    PAIRWISE_STRING = "pairwise_string"
+    """The pairwise string evaluator, which compares the output of two models."""
+    AGENT_TRAJECTORY = "trajectory"
+    """The agent trajectory evaluator, which grades the agent's intermediate steps."""
+    CRITERIA = "criteria"
+    """The criteria evaluator, which evaluates a model based on a
+    custom set of criteria."""
+
+
+class LLMEvalChain(Chain):
+    """A base class for evaluators that use an LLM."""
+
+    @classmethod
+    @abstractmethod
+    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> LLMEvalChain:
+        """Create a new evaluator from an LLM."""
+
+
 class _EvalArgsMixin:
     """Mixin for checking evaluation arguments."""
 
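`LLMEvalChain` is what lets `load_evaluator` construct every evaluator the same way: each entry in `_EVALUATOR_MAP` only needs a `from_llm` classmethod. As a rough illustration of that contract (the class and prompt below are hypothetical and not part of this commit), a new evaluator could plug into the same pattern like this:

```python
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.evaluation.schema import LLMEvalChain
from langchain.prompts import PromptTemplate

# Hypothetical prompt and chain, shown only to illustrate the from_llm contract.
_GRADE_PROMPT = PromptTemplate(
    input_variables=["input", "prediction"],
    template=(
        "Question: {input}\nAnswer: {prediction}\n"
        "Grade the answer as CORRECT or INCORRECT:"
    ),
)


class HypotheticalGradeChain(LLMEvalChain, LLMChain):
    """Illustrative evaluator that satisfies the LLMEvalChain interface."""

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, **kwargs) -> "HypotheticalGradeChain":
        # The built-in evaluators in this commit follow the same shape, so
        # load_evaluator can instantiate any of them via _EVALUATOR_MAP[...].from_llm(...).
        return cls(llm=llm, prompt=_GRADE_PROMPT, **kwargs)
```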
tests/unit_tests/evaluation/test_loading.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+"""Test the loading function for evaluators."""
+
+import pytest
+
+from langchain.evaluation.loading import EvaluatorType, load_evaluators
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
+
+
+@pytest.mark.parametrize("evaluator_type", EvaluatorType)
+def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
+    """Test loading evaluators."""
+    fake_llm = FakeChatModel()
+    load_evaluators([evaluator_type], llm=fake_llm)
+
+    # Test as string
+    load_evaluators([evaluator_type.value], llm=fake_llm)  # type: ignore