Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
75fd543b33 Warn if reference passed but evaluator doesn't require it 2023-07-03 16:44:22 -07:00
2 changed files with 38 additions and 5 deletions

View File

@@ -1,6 +1,8 @@
"""Base classes for comparing the output of two models."""
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Any, Optional
from pydantic import Field
@@ -12,6 +14,13 @@ from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def warn_once(message: str) -> None:
logger.warning(message)
class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
"""A parser for the output of the PairwiseStringEvalChain."""
@@ -86,7 +95,7 @@ class PairwiseStringEvalChain(LLMChain):
*,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
require_reference: bool = False,
requires_reference: bool = False,
**kwargs: Any,
) -> PairwiseStringEvalChain:
"""Initialize the PairwiseStringEvalChain from an LLM.
@@ -94,7 +103,7 @@ class PairwiseStringEvalChain(LLMChain):
Args:
llm (BaseLanguageModel): The LLM to use.
prompt (PromptTemplate, optional): The prompt to use.
require_reference (bool, optional): Whether to require a reference
requires_reference (bool, optional): Whether to require a reference
string. Defaults to False.
**kwargs (Any): Additional keyword arguments.
@@ -103,13 +112,13 @@ class PairwiseStringEvalChain(LLMChain):
"""
expected_input_vars = {"prediction", "prediction_b", "input"}
if prompt is None:
if require_reference:
if requires_reference:
expected_input_vars.add("reference")
prompt_ = PROMPT_WITH_REFERENCE
else:
prompt_ = PROMPT
else:
if require_reference:
if requires_reference:
expected_input_vars.add("reference")
prompt_ = prompt
@@ -128,8 +137,18 @@ class PairwiseStringEvalChain(LLMChain):
"prediction_b": prediction_b,
"input": input,
}
if reference is not None and "reference" in self.prompt.input_variables:
if "reference" in self.prompt.input_variables:
if reference is None:
raise ValueError(
"Prompt requires a reference string, but none was provided."
)
input_["reference"] = reference
elif reference is not None:
warn_once(
"Ignoring reference string in PairwiseStringEvalChain."
" To use references, initialize with argument `requires_reference=True`"
' or use a prompt that included "reference" as an input variable.'
)
return input_
def evaluate_string_pairs(

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from pydantic import Field
@@ -10,6 +12,8 @@ from langchain.chains.llm import LLMChain
from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
from langchain.schema import BaseOutputParser, BasePromptTemplate
logger = logging.getLogger(__name__)
_SUPPORTED_CRITERIA = {
"conciseness": "Is the submission concise and to the point?",
"relevance": "Is the submission referring to a real quote from the text?",
@@ -25,6 +29,11 @@ _SUPPORTED_CRITERIA = {
}
@lru_cache(maxsize=1)
def warn_once(message: str) -> None:
logger.warning(message)
class CriteriaResultOutputParser(BaseOutputParser[dict]):
"""A parser for the output of the CriteriaEvalChain."""
@@ -250,6 +259,11 @@ class CriteriaEvalChain(LLMChain):
}
if self.requires_reference:
input_["reference"] = reference
elif reference is not None:
warn_once(
"The reference text will be ignored because this Criteria evaluator"
" does not require a reference."
)
return input_
def evaluate_strings(