ContextQAEvalChain

Chester Curme
2024-08-15 12:21:11 -04:00
parent 9bd4459f9a
commit 8a70754dfe


@@ -13,7 +13,6 @@ from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.runnables import RunnableConfig
-from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
 from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 from langchain.schema import RUN_KEY
@@ -227,6 +226,7 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
             config=config,
+            include_run_info=include_run_info,
         )
         return self._prepare_output(result)
@@ -251,13 +251,20 @@ class QAEvalChain(StringEvaluator, LLMEvalChain):
                 "result": prediction,
             },
             config=config,
+            include_run_info=include_run_info,
         )
         return self._prepare_output(result)
 
 
-class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
+class ContextQAEvalChain(StringEvaluator, LLMEvalChain):
     """LLM Chain for evaluating QA w/o GT based on context"""
 
     output_key: str = "text"  #: :meta private:
+    llm: BaseLanguageModel
+    """The language model to use for scoring."""
+    prompt: BasePromptTemplate
+    """The prompt to use for scoring."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return False
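
With LLMChain gone from the bases, llm and prompt become ordinary pydantic fields on the evaluator itself, so it can be constructed directly. A minimal sketch; the FakeListLLM stand-in and its canned grade are illustrative, not part of this commit:

from langchain_core.language_models import FakeListLLM

from langchain.evaluation.qa.eval_chain import ContextQAEvalChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT

# llm and prompt are plain model fields now, so no LLMChain plumbing is needed.
fake_llm = FakeListLLM(responses=["GRADE: CORRECT"])
evaluator = ContextQAEvalChain(llm=fake_llm, prompt=CONTEXT_PROMPT)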
@@ -272,6 +279,22 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         """Whether the chain requires an input string."""
         return True
 
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return self.prompt.input_variables
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
     class Config:
         extra = "ignore"
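
Since input_keys is now derived from the prompt rather than inherited from LLMChain, the default CONTEXT_PROMPT should yield the query/context/result keys. Continuing the sketch above:

# The keys come straight from prompt.input_variables.
assert sorted(evaluator.input_keys) == ["context", "query", "result"]
assert evaluator.output_keys == ["text"]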
@@ -288,6 +311,19 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     def evaluation_name(self) -> str:
         return "Contextual Accuracy"
 
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if run_manager:
+            config = RunnableConfig(callbacks=run_manager.get_child())
+        else:
+            config = None
+        chain = self.prompt | self.llm | StrOutputParser()
+        response = chain.invoke(inputs, config=config)
+        return {self.output_key: response}
+
     @classmethod
     def from_llm(
         cls,
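
The new _call body is a plain LCEL composition, so the same pipeline can be run free-standing. An equivalent sketch, again with a fake model standing in for a real one:

from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import StrOutputParser

from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT

# prompt | llm | StrOutputParser() is exactly the composition _call now runs.
pipeline = CONTEXT_PROMPT | FakeListLLM(responses=["GRADE: CORRECT"]) | StrOutputParser()
text = pipeline.invoke(
    {"query": "What color is the sky?", "context": "The sky is blue.", "result": "blue"}
)
print(text)  # the fake model's canned reply: "GRADE: CORRECT"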
@@ -333,8 +369,13 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
             }
             for i, example in enumerate(examples)
         ]
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        outputs = self.batch(inputs, config=config)
 
-        return self.apply(inputs, callbacks=callbacks)
+        return [{self.output_key: output[self.output_key]} for output in outputs]
 
     def _prepare_output(self, result: dict) -> dict:
         parsed_result = _parse_string_eval_output(result[self.output_key])
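
evaluate() now fans the prepared inputs through Runnable.batch instead of LLMChain.apply, with any callbacks carried in a RunnableConfig. A usage sketch with hypothetical example data, continuing from the evaluator above:

examples = [
    {"query": "What color is the sky?", "context": "The sky is blue."},
    {"query": "What is 2 + 2?", "context": "2 + 2 = 4."},
]
predictions = [{"result": "blue"}, {"result": "4"}]

# Each example/prediction pair becomes one batch input; the return value
# keeps its old shape, a list of {"text": <grader output>} dicts.
graded = evaluator.evaluate(examples, predictions)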
@@ -352,13 +393,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = self(
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = self.invoke(
             {
                 "query": input,
                 "context": reference,
                 "result": prediction,
             },
-            callbacks=callbacks,
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
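
Callbacks handed to _evaluate_strings are now wrapped in a RunnableConfig rather than forwarded as a keyword to __call__. A sketch of the public entry point that exercises this path, with the same fake evaluator:

result = evaluator.evaluate_strings(
    input="What color is the sky?",
    reference="The sky is blue.",  # mapped to the "context" key
    prediction="blue",
)
print(result)  # roughly {"reasoning": "GRADE: CORRECT", "value": "CORRECT", "score": 1}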
@@ -373,9 +418,17 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
-        result = await self.acall(
-            inputs={"query": input, "context": reference, "result": prediction},
-            callbacks=callbacks,
+        if callbacks:
+            config = RunnableConfig(callbacks=callbacks)
+        else:
+            config = None
+        result = await self.ainvoke(
+            {
+                "query": input,
+                "context": reference,
+                "result": prediction,
+            },
+            config=config,
             include_run_info=include_run_info,
         )
         return self._prepare_output(result)
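
The async path mirrors the sync one, with acall swapped for ainvoke. A sketch of the awaitable entry point, reusing the evaluator from the first snippet:

import asyncio

async def main() -> None:
    result = await evaluator.aevaluate_strings(
        input="What is 2 + 2?",
        reference="2 + 2 = 4.",
        prediction="4",
    )
    print(result)

asyncio.run(main())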