ContextQAEvalChain
@@ -13,7 +13,6 @@ from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.runnables import RunnableConfig
 
-from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
 from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 from langchain.schema import RUN_KEY
@@ -227,6 +226,7 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
 
@@ -251,13 +251,20 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
 
 
-class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
+class ContextQAEvalChain(StringEvaluator, LLMEvalChain):
     """LLM Chain for evaluating QA w/o GT based on context"""
 
+    output_key: str = "text"  #: :meta private:
+    llm: BaseLanguageModel
+    """The language model to use for scoring."""
+    prompt: BasePromptTemplate
+    """The prompt to use for scoring."""
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return False
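
A minimal construction sketch (not part of this commit): with LLMChain dropped from the bases, the evaluator now carries explicit llm and prompt fields, so an instance can be built directly; the existing CONTEXT_PROMPT is used here, and a FakeListLLM stands in for a real model so the snippet runs offline. Import paths assume a recent langchain-core.

from langchain_core.language_models import FakeListLLM
from langchain.evaluation.qa.eval_chain import ContextQAEvalChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT

# Direct construction with the two new fields; from_llm() (shown later in the
# diff) remains the usual entry point and uses CONTEXT_PROMPT by default.
evaluator = ContextQAEvalChain(
    llm=FakeListLLM(responses=["CORRECT"]),
    prompt=CONTEXT_PROMPT,
)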
@@ -272,6 +279,22 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         """Whether the chain requires an input string."""
         return True
 
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return self.prompt.input_variables
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
     class Config:
         extra = "ignore"
 
@@ -288,6 +311,19 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     def evaluation_name(self) -> str:
         return "Contextual Accuracy"
 
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if run_manager:
+            config = RunnableConfig(callbacks=run_manager.get_child())
+        else:
+            config = None
+        chain = self.prompt | self.llm | StrOutputParser()
+        response = chain.invoke(inputs, config=config)
+        return {self.output_key: response}
+
     @classmethod
     def from_llm(
         cls,
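
A minimal sketch (not part of this commit) of the runnable composition the new _call builds in place of LLMChain, using a FakeListLLM so it runs without credentials; the prompt text is illustrative only, and import paths assume a recent langchain-core.

from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(
    "Context: {context}\nQuestion: {query}\nStudent answer: {result}\nGRADE:"
)
llm = FakeListLLM(responses=["CORRECT"])

# Equivalent of `self.prompt | self.llm | StrOutputParser()` inside _call.
chain = prompt | llm | StrOutputParser()
print(chain.invoke({"query": "2+2?", "context": "Basic arithmetic.", "result": "4"}))
# -> "CORRECT"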
@@ -333,8 +369,13 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
             }
             for i, example in enumerate(examples)
         ]
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        outputs = self.batch(inputs, config=config)
 
-        return self.apply(inputs, callbacks=callbacks)
+        return [{self.output_key: output[self.output_key]} for output in outputs]
 
     def _prepare_output(self, result: dict) -> dict:
         parsed_result = _parse_string_eval_output(result[self.output_key])
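
A small illustration (not part of this commit) of the callbacks-to-config pattern the diff repeats: user-supplied callbacks are wrapped in a RunnableConfig only when present, and the resulting config is passed to Runnable methods such as batch() or invoke(); a RunnableLambda stands in for the evaluation chain here.

from langchain_core.runnables import RunnableConfig, RunnableLambda

runnable = RunnableLambda(lambda x: x["result"].upper())
callbacks = None  # or a list of BaseCallbackHandler instances
config = RunnableConfig(callbacks=callbacks) if callbacks else None
print(runnable.batch([{"result": "correct"}, {"result": "incorrect"}], config=config))
# -> ['CORRECT', 'INCORRECT']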
@@ -352,13 +393,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = self(
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = self.invoke(
             {
                 "query": input,
                 "context": reference,
                 "result": prediction,
             },
-            callbacks=callbacks,
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
@@ -373,9 +418,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = await self.acall(
-            inputs={"query": input, "context": reference, "result": prediction},
-            callbacks=callbacks,
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = await self.ainvoke(
+            {
+                "query": input,
+                "context": reference,
+                "result": prediction,
+            },
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
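
End-to-end usage sketch (not part of this commit), assuming the evaluator's public API is unchanged by this refactor: from_llm builds the chain with the default CONTEXT_PROMPT, and evaluate_strings grades a prediction against a context; a FakeListLLM keeps the example offline.

from langchain_core.language_models import FakeListLLM
from langchain.evaluation.qa.eval_chain import ContextQAEvalChain

llm = FakeListLLM(responses=["CORRECT"])
evaluator = ContextQAEvalChain.from_llm(llm=llm)

graded = evaluator.evaluate_strings(
    input="Who wrote the report?",
    reference="The 2023 report was written by the audit team.",
    prediction="The audit team wrote it.",
)
print(graded)  # -> something like {'value': 'CORRECT', 'score': 1, ...}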