Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-04 04:07:54 +00:00)
Explicitly list requires_reference in function (#7357)

parent 49b2b0e3c0
commit 38ca5c84cb
@@ -251,17 +251,24 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
                 requires_reference=True,
             )
         """
+        expected_input_vars = {"input", "output", "criteria"}
         if prompt is None:
             if requires_reference:
                 prompt = PROMPT_WITH_REFERENCES
             else:
                 prompt = PROMPT
+        if requires_reference:
+            expected_input_vars.add("reference")
+        if expected_input_vars != set(prompt.input_variables):
+            raise ValueError(
+                f"Input variables should be {expected_input_vars}, "
+                f"but got {prompt.input_variables}"
+            )
+
         criteria_ = cls.resolve_criteria(criteria)
         criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
         prompt_ = prompt.partial(criteria=criteria_str)
-        return cls(llm=llm, prompt=prompt_, **kwargs)
+        return cls(
+            llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
+        )

     def _get_eval_input(
         self,
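For context, a minimal usage sketch of the patched from_llm (not part of the diff). The ChatOpenAI model, the criteria dict, and the evaluate_strings call are illustrative assumptions; evaluate_strings comes from the StringEvaluator interface the class inherits, per the hunk header above.

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

llm = ChatOpenAI(temperature=0)  # illustrative model choice

# With requires_reference=True and no explicit prompt, from_llm selects
# PROMPT_WITH_REFERENCES, and the new guard checks that the prompt's input
# variables are exactly {"input", "output", "criteria", "reference"};
# a custom prompt missing "reference" would raise ValueError.
chain = CriteriaEvalChain.from_llm(
    llm=llm,
    criteria={"correctness": "Is the submission factually accurate?"},  # assumed criterion
    requires_reference=True,
)

# requires_reference is now forwarded to the constructor, so the chain
# expects a reference string at evaluation time.
result = chain.evaluate_strings(
    input="What is the capital of France?",
    prediction="The capital of France is Paris.",
    reference="Paris",
)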
@@ -14,7 +14,6 @@ from langchain.evaluation.criteria.eval_chain import (
     CriteriaEvalChain,
     CriteriaResultOutputParser,
 )
-from langchain.evaluation.criteria.prompt import PROMPT as CRITERIA_PROMPT
 from langchain.evaluation.qa.eval_chain import QAEvalChain
 from langchain.evaluation.qa.eval_prompt import PROMPT as QA_DEFAULT_PROMPT
 from langchain.evaluation.qa.eval_prompt import SQL_PROMPT
@@ -152,8 +151,9 @@ def get_criteria_evaluator(
     *,
     input_key: str = "input",
     prediction_key: str = "output",
-    prompt: BasePromptTemplate = CRITERIA_PROMPT,
+    prompt: Optional[BasePromptTemplate] = None,
     evaluation_name: Optional[str] = None,
+    requires_reference: bool = False,
     **kwargs: Any,
 ) -> RunEvaluatorChain:
     """Get an eval chain for grading a model's response against a map of criteria."""
@@ -174,7 +174,11 @@ def get_criteria_evaluator(
     )
     tags = kwargs.pop("tags", [])
     eval_chain = CriteriaEvalChain.from_llm(
-        llm=llm, criteria=criteria_, prompt=prompt, **kwargs
+        llm=llm,
+        criteria=criteria_,
+        prompt=prompt,
+        requires_reference=requires_reference,
+        **kwargs,
     )
     return RunEvaluatorChain(
         eval_chain=eval_chain,
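And a hedged sketch of the helper after this change; the module path for get_criteria_evaluator and its leading positional parameters (llm, criteria) are inferred from the hunks above rather than shown in this commit, so treat them as assumptions.

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.run_evaluators.implementations import (
    get_criteria_evaluator,  # assumed module path
)

llm = ChatOpenAI(temperature=0)

# prompt defaults to None now, so prompt selection moves into
# CriteriaEvalChain.from_llm (PROMPT vs. PROMPT_WITH_REFERENCES) instead
# of this helper pinning CRITERIA_PROMPT at import time, and the new
# requires_reference flag is threaded through explicitly.
evaluator = get_criteria_evaluator(
    llm,
    "conciseness",  # illustrative criterion; dicts are also accepted
    requires_reference=False,
)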