Explicitly list requires_reference in function (#7357)
commit 38ca5c84cb
parent 49b2b0e3c0
@@ -251,17 +251,24 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
                 requires_reference=True,
             )
         """
         expected_input_vars = {"input", "output", "criteria"}
         if prompt is None:
             if requires_reference:
                 prompt = PROMPT_WITH_REFERENCES
             else:
                 prompt = PROMPT
         if requires_reference:
             expected_input_vars.add("reference")
         if expected_input_vars != set(prompt.input_variables):
             raise ValueError(
                 f"Input variables should be {expected_input_vars}, "
                 f"but got {prompt.input_variables}"
             )

         criteria_ = cls.resolve_criteria(criteria)
         criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
         prompt_ = prompt.partial(criteria=criteria_str)
-        return cls(llm=llm, prompt=prompt_, **kwargs)
+        return cls(
+            llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
+        )

     def _get_eval_input(
         self,
@@ -14,7 +14,6 @@ from langchain.evaluation.criteria.eval_chain import (
     CriteriaEvalChain,
     CriteriaResultOutputParser,
 )
-from langchain.evaluation.criteria.prompt import PROMPT as CRITERIA_PROMPT
 from langchain.evaluation.qa.eval_chain import QAEvalChain
 from langchain.evaluation.qa.eval_prompt import PROMPT as QA_DEFAULT_PROMPT
 from langchain.evaluation.qa.eval_prompt import SQL_PROMPT
@@ -152,8 +151,9 @@ def get_criteria_evaluator(
     *,
     input_key: str = "input",
     prediction_key: str = "output",
-    prompt: BasePromptTemplate = CRITERIA_PROMPT,
+    prompt: Optional[BasePromptTemplate] = None,
     evaluation_name: Optional[str] = None,
+    requires_reference: bool = False,
     **kwargs: Any,
 ) -> RunEvaluatorChain:
     """Get an eval chain for grading a model's response against a map of criteria."""
@@ -174,7 +174,11 @@ def get_criteria_evaluator(
     )
     tags = kwargs.pop("tags", [])
     eval_chain = CriteriaEvalChain.from_llm(
-        llm=llm, criteria=criteria_, prompt=prompt, **kwargs
+        llm=llm,
+        criteria=criteria_,
+        prompt=prompt,
+        requires_reference=requires_reference,
+        **kwargs,
     )
     return RunEvaluatorChain(
         eval_chain=eval_chain,
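For context, a minimal usage sketch (not part of the commit) of the behavior this change exposes: requires_reference is now an explicit keyword argument that get_criteria_evaluator forwards to CriteriaEvalChain.from_llm, which selects PROMPT_WITH_REFERENCES and adds "reference" to the expected input variables. The grading model (ChatOpenAI) and the criteria mapping below are illustrative assumptions, not taken from the diff.

# Sketch only, assuming the langchain version this commit targets and an
# OPENAI_API_KEY in the environment.
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

llm = ChatOpenAI(temperature=0)
chain = CriteriaEvalChain.from_llm(
    llm=llm,
    # A hand-written criteria mapping (assumed), resolved by cls.resolve_criteria.
    criteria={
        "correctness": "Is the submission factually consistent with the reference?"
    },
    requires_reference=True,  # picks the reference-aware prompt and expects a "reference" input
)

# The criteria string is already partialed into the prompt, so the chain is
# called with the remaining input variables: input, output, and reference.
result = chain(
    {
        "input": "What is 2 + 2?",
        "output": "2 + 2 equals 5.",
        "reference": "2 + 2 equals 4.",
    }
)
print(result)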