diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py
index ca9caca1a83..37e939efceb 100644
--- a/libs/langchain/langchain/evaluation/qa/eval_chain.py
+++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py
@@ -13,7 +13,6 @@
 from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.runnables import RunnableConfig
-from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
 from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 from langchain.schema import RUN_KEY
@@ -227,6 +226,7 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
             config=config,
+            include_run_info=include_run_info,
         )
         return self._prepare_output(result)
 
@@ -251,13 +251,20 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
             config=config,
+            include_run_info=include_run_info,
         )
         return self._prepare_output(result)
 
 
-class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
+class ContextQAEvalChain(StringEvaluator, LLMEvalChain):
     """LLM Chain for evaluating QA w/o GT based on context"""
 
+    output_key: str = "text"  #: :meta private:
+    llm: BaseLanguageModel
+    """The language model to use for scoring."""
+    prompt: BasePromptTemplate
+    """The prompt to use for scoring."""
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return False
@@ -272,6 +279,22 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         """Whether the chain requires an input string."""
         return True
 
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return self.prompt.input_variables
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
     class Config:
         extra = "ignore"
 
@@ -288,6 +311,19 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     def evaluation_name(self) -> str:
         return "Contextual Accuracy"
 
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if run_manager:
+            config = RunnableConfig(callbacks=run_manager.get_child())
+        else:
+            config = None
+        chain = self.prompt | self.llm | StrOutputParser()
+        response = chain.invoke(inputs, config=config)
+        return {self.output_key: response}
+
     @classmethod
     def from_llm(
         cls,
@@ -333,8 +369,13 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
             }
             for i, example in enumerate(examples)
         ]
 
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        outputs = self.batch(inputs, config=config)
-        return self.apply(inputs, callbacks=callbacks)
+        return [{self.output_key: output[self.output_key]} for output in outputs]
 
     def _prepare_output(self, result: dict) -> dict:
         parsed_result = _parse_string_eval_output(result[self.output_key])
@@ -352,13 +393,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = self(
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = self.invoke(
             {
                 "query": input,
                 "context": reference,
                 "result": prediction,
             },
-            callbacks=callbacks,
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
@@ -373,9 +418,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = await self.acall(
-            inputs={"query": input, "context": reference, "result": prediction},
-            callbacks=callbacks,
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = await self.ainvoke(
+            {
+                "query": input,
+                "context": reference,
+                "result": prediction,
+            },
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
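
For reviewers, a minimal usage sketch of the refactored ContextQAEvalChain (not part of this diff). The ChatOpenAI class, the "gpt-4o-mini" model name, and the example strings are illustrative assumptions; any language model accepted by from_llm should behave the same, since the chain now routes through the LCEL pipeline prompt | llm | StrOutputParser() in _call.

# Illustrative sketch only: requires the langchain-openai package and an
# OPENAI_API_KEY; swap in any other model object accepted by from_llm.
from langchain_openai import ChatOpenAI

from langchain.evaluation.qa.eval_chain import ContextQAEvalChain

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
eval_chain = ContextQAEvalChain.from_llm(llm=llm)

# Grades the prediction against a reference context instead of a ground-truth
# answer; input/reference/prediction are mapped to the prompt's
# query/context/result variables by _evaluate_strings.
graded = eval_chain.evaluate_strings(
    input="What is the capital of France?",
    prediction="The capital of France is Paris.",
    reference="Paris has been the capital city of France since the 10th century.",
)
print(graded)  # e.g. {"reasoning": ..., "value": "CORRECT", "score": 1}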