mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
Adding an in-context QA evaluation chain + chain of thought reasoning chain for improved accuracy (#2444)
Right now, eval chains require an answer for every question. It's cumbersome to collect this ground truth so getting around this issue with 2 things: * Adding a context param in `ContextQAEvalChain` and simply evaluating if the question is answered accurately from context * Adding chain of though explanation prompting to improve the accuracy of this w/o GT. This also gets to feature parity with openai/evals which has the same contextual eval w/o GT. TODO in follow-up: * Better prompt inheritance. No need for seperate prompt for CoT reasoning. How can we merge them together --------- Co-authored-by: Vashisht Madhavan <vashishtmadhavan@Vashs-MacBook-Pro.local>
This commit is contained in:
parent
e131156805
commit
aa439ac2ff
@ -234,6 +234,93 @@
|
||||
"evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cb1cf335",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluation without Ground Truth\n",
|
||||
"Its possible to evaluate question answering systems without ground truth. You would need a `\"context\"` input that reflects what the information the LLM uses to answer the question. This context can be obtained by any retreival system. Here's an example of how it works:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6c59293f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"context_examples = [\n",
|
||||
" {\n",
|
||||
" \"question\": \"How old am I?\",\n",
|
||||
" \"context\": \"I am 30 years old. I live in New York and take the train to work everyday.\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"question\": 'Who won the NFC championship game in 2023?\"',\n",
|
||||
" \"context\": \"NFC Championship Game 2023: Philadelphia Eagles 31, San Francisco 49ers 7\"\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"QA_PROMPT = \"Answer the question based on the context\\nContext:{context}\\nQuestion:{question}\\nAnswer:\"\n",
|
||||
"template = PromptTemplate(input_variables=[\"context\", \"question\"], template=QA_PROMPT)\n",
|
||||
"qa_chain = LLMChain(llm=llm, prompt=template)\n",
|
||||
"predictions = qa_chain.apply(context_examples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "e500d0cc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'text': 'You are 30 years old.'},\n",
|
||||
" {'text': ' The Philadelphia Eagles won the NFC championship game in 2023.'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "6d8cbc1d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.evaluation.qa import ContextQAEvalChain\n",
|
||||
"eval_chain = ContextQAEvalChain.from_llm(llm)\n",
|
||||
"graded_outputs = eval_chain.evaluate(context_examples, predictions, question_key=\"question\", prediction_key=\"text\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "6c5262d0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'text': ' CORRECT'}, {'text': ' CORRECT'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"graded_outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aaa61f0c",
|
||||
@ -329,7 +416,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -1,5 +1,9 @@
|
||||
"""Chains and utils related to evaluating question answering functionality."""
|
||||
from langchain.evaluation.qa.eval_chain import QAEvalChain
|
||||
from langchain.evaluation.qa.eval_chain import (
|
||||
ContextQAEvalChain,
|
||||
CotQAEvalChain,
|
||||
QAEvalChain,
|
||||
)
|
||||
from langchain.evaluation.qa.generate_chain import QAGenerateChain
|
||||
|
||||
__all__ = ["QAEvalChain", "QAGenerateChain"]
|
||||
__all__ = ["QAEvalChain", "QAGenerateChain", "ContextQAEvalChain", "CotQAEvalChain"]
|
||||
|
@ -5,7 +5,7 @@ from typing import Any, List
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.evaluation.qa.eval_prompt import PROMPT
|
||||
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
@ -58,3 +58,69 @@ class QAEvalChain(LLMChain):
|
||||
]
|
||||
|
||||
return self.apply(inputs)
|
||||
|
||||
|
||||
class ContextQAEvalChain(LLMChain):
|
||||
"""LLM Chain specifically for evaluating QA w/o GT based on context"""
|
||||
|
||||
@classmethod
|
||||
def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
|
||||
expected_input_vars = {"query", "context", "result"}
|
||||
if expected_input_vars != set(prompt.input_variables):
|
||||
raise ValueError(
|
||||
f"Input variables should be {expected_input_vars}, "
|
||||
f"but got {prompt.input_variables}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any
|
||||
) -> ContextQAEvalChain:
|
||||
"""Load QA Eval Chain from LLM.
|
||||
|
||||
Args:
|
||||
llm (BaseLLM): the base language model to use.
|
||||
|
||||
prompt (PromptTemplate): A prompt template containing the input_variables:
|
||||
'query', 'context' and 'result' that will be used as the prompt
|
||||
for evaluation.
|
||||
Defaults to PROMPT.
|
||||
|
||||
**kwargs: additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
ContextQAEvalChain: the loaded QA eval chain.
|
||||
"""
|
||||
cls._validate_input_vars(prompt)
|
||||
return cls(llm=llm, prompt=prompt, **kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
examples: List[dict],
|
||||
predictions: List[dict],
|
||||
question_key: str = "query",
|
||||
context_key: str = "context",
|
||||
prediction_key: str = "result",
|
||||
) -> List[dict]:
|
||||
"""Evaluate question answering examples and predictions."""
|
||||
inputs = [
|
||||
{
|
||||
"query": example[question_key],
|
||||
"context": example[context_key],
|
||||
"result": predictions[i][prediction_key],
|
||||
}
|
||||
for i, example in enumerate(examples)
|
||||
]
|
||||
|
||||
return self.apply(inputs)
|
||||
|
||||
|
||||
class CotQAEvalChain(ContextQAEvalChain):
|
||||
"""LLM Chain specifically for evaluating QA using chain of thought reasoning."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
|
||||
) -> CotQAEvalChain:
|
||||
cls._validate_input_vars(prompt)
|
||||
return cls(llm=llm, prompt=prompt, **kwargs)
|
||||
|
@ -19,3 +19,44 @@ GRADE:"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["query", "result", "answer"], template=template
|
||||
)
|
||||
|
||||
context_template = """You are a teacher grading a quiz.
|
||||
You are given a question, the contex the question is about, and the student's answer You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.
|
||||
|
||||
Example Format:
|
||||
QUESTION: question here
|
||||
CONTEXT: context the question is about here
|
||||
STUDENT ANSWER: student's answer here
|
||||
GRADE: CORRECT or INCORRECT here
|
||||
|
||||
Please remember to grade them based on being factually accurate. Begin!
|
||||
|
||||
QUESTION: {query}
|
||||
CONTEXT: {context}
|
||||
STUDENT ANSWER: {result}
|
||||
GRADE:"""
|
||||
CONTEXT_PROMPT = PromptTemplate(
|
||||
input_variables=["query", "context", "result"], template=context_template
|
||||
)
|
||||
|
||||
|
||||
cot_template = """You are a teacher grading a quiz.
|
||||
You are given a question, the contex the question is about, and the student's answer You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.
|
||||
Write out in a step by step manner your reasoning to be sure that your conclusion is correct. Avoid simply stating the correct answer at the outset.
|
||||
|
||||
Example Format:
|
||||
QUESTION: question here
|
||||
CONTEXT: context the question is about here
|
||||
STUDENT ANSWER: student's answer here
|
||||
EXPLANATION: step by step reasoning here
|
||||
GRADE: CORRECT or INCORRECT here
|
||||
|
||||
Please remember to grade them based on being factually accurate. Begin!
|
||||
|
||||
QUESTION: {query}
|
||||
CONTEXT: {context}
|
||||
STUDENT ANSWER: {result}
|
||||
EXPLANATION:"""
|
||||
COT_PROMPT = PromptTemplate(
|
||||
input_variables=["query", "context", "result"], template=cot_template
|
||||
)
|
||||
|
1
tests/unit_tests/evaluation/__init__.py
Normal file
1
tests/unit_tests/evaluation/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""New unit tests for the evaluation module."""
|
1
tests/unit_tests/evaluation/qa/__init__.py
Normal file
1
tests/unit_tests/evaluation/qa/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for QA evaluation chains."""
|
46
tests/unit_tests/evaluation/qa/test_eval_chain.py
Normal file
46
tests/unit_tests/evaluation/qa/test_eval_chain.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Test LLM Bash functionality."""
|
||||
import sys
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.evaluation.qa.eval_chain import (
|
||||
ContextQAEvalChain,
|
||||
CotQAEvalChain,
|
||||
QAEvalChain,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_eval_chain() -> None:
|
||||
"""Test a simple eval chain."""
|
||||
example = {"query": "What's my name", "answer": "John Doe"}
|
||||
prediction = {"result": "John Doe"}
|
||||
fake_qa_eval_chain = QAEvalChain.from_llm(FakeLLM())
|
||||
|
||||
outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
|
||||
assert outputs[0] == outputs[1]
|
||||
assert "text" in outputs[0]
|
||||
assert outputs[0]["text"] == "foo"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize("chain_cls", [ContextQAEvalChain, CotQAEvalChain])
|
||||
def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
|
||||
"""Test a simple eval chain."""
|
||||
example = {
|
||||
"query": "What's my name",
|
||||
"context": "The name of this person is John Doe",
|
||||
}
|
||||
prediction = {"result": "John Doe"}
|
||||
fake_qa_eval_chain = chain_cls.from_llm(FakeLLM())
|
||||
|
||||
outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
|
||||
assert outputs[0] == outputs[1]
|
||||
assert "text" in outputs[0]
|
||||
assert outputs[0]["text"] == "foo"
|
Loading…
Reference in New Issue
Block a user