Make pairwise comparison chain more like LLM as a judge (#11013)
- **Description:** Adds LLM-as-a-judge as an eval chain, updating the pairwise comparison chain to follow the judging prompt from Zheng, et al. (https://arxiv.org/abs/2306.05685).
- **Tag maintainer:** @hwchase17

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
Commit 64385c4eae (parent 175ef0a55d)
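For context, a minimal usage sketch of the pairwise (LLM-as-a-judge) evaluator touched by this change. It assumes an OpenAI API key is configured and uses GPT-4, as the chain now recommends; the example question and answers are made up and are not part of this diff.

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain

# GPT-4 is recommended; other models trigger the warning added in this PR.
judge = PairwiseStringEvalChain.from_llm(
    llm=ChatOpenAI(model_name="gpt-4", temperature=0)
)

result = judge.evaluate_string_pairs(
    input="Why is the sky blue?",
    prediction="Because of Rayleigh scattering of sunlight by air molecules.",
    prediction_b="Because the ocean reflects onto it.",
)
print(result["value"], result["score"])  # e.g. "A", 1
print(result["reasoning"])               # full judge output, ending in [[A]], [[B]], or [[C]]
```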
langchain/evaluation/comparison/eval_chain.py

@@ -1,12 +1,20 @@
 """Base classes for comparing the output of two models."""
 from __future__ import annotations
 
+import logging
+import re
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
-from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
+from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.openai import ChatOpenAI
+from langchain.evaluation.comparison.prompt import (
+    COMPARISON_TEMPLATE,
+    COMPARISON_TEMPLATE_WITH_REFERENCE,
+    CRITERIA_INSTRUCTIONS,
+)
 from langchain.evaluation.criteria.eval_chain import (
     CRITERIA_TYPE,
     Criteria,
@@ -17,6 +25,10 @@ from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import RUN_KEY, BaseOutputParser
 from langchain.schema.language_model import BaseLanguageModel
 
+logger = logging.getLogger(__name__)
+
+_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")
+
 _SUPPORTED_CRITERIA = {
     Criteria.CONCISENESS: "Is the submission concise and to the point?",
     Criteria.RELEVANCE: "Is the submission referring to a real quote from the text?",
@@ -112,27 +124,26 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
             ValueError: If the verdict is invalid.
 
         """
-        parsed = text.strip().rsplit("\n", maxsplit=1)
-        if len(parsed) == 1:
-            reasoning = ""
-            verdict = parsed[0]
-        else:
-            reasoning, verdict = parsed
-        verdict = verdict.strip("[").strip("]")
-        if verdict not in {"A", "B", "C"}:
+        match = _FIND_DOUBLE_BRACKETS.search(text)
+
+        if match:
+            verdict = match.group(1)
+
+        if not match or verdict not in {"A", "B", "C"}:
             raise ValueError(
-                f"Invalid verdict: {verdict}. "
-                "Verdict must be one of 'A', 'B', or 'C'."
+                f"Invalid output: {text}. "
+                "Output must contain a double bracketed string\
+ with the verdict 'A', 'B', or 'C'."
             )
         # C means the models are tied. Return 'None' meaning no preference
         verdict_ = None if verdict == "C" else verdict
         score = {
             "A": 1,
             "B": 0,
-            None: 0.5,
-        }.get(verdict_)
+            "C": 0.5,
+        }[verdict]
         return {
-            "reasoning": reasoning,
+            "reasoning": text,
             "value": verdict_,
             "score": score,
         }
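To make the new parsing behavior concrete, here is a standalone sketch that mirrors the logic in the hunk above; the `parse_verdict` helper and the sample text are hypothetical and not part of the diff.

```python
import re
from typing import Any, Dict

_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")  # same pattern as the diff

def parse_verdict(text: str) -> Dict[str, Any]:
    """Extract the [[A]]/[[B]]/[[C]] verdict and map it to a score."""
    match = _FIND_DOUBLE_BRACKETS.search(text)
    verdict = match.group(1) if match else None
    if verdict not in {"A", "B", "C"}:
        raise ValueError(f"Invalid output: {text}.")
    return {
        "reasoning": text,  # the full judge output is now kept as the rationale
        "value": None if verdict == "C" else verdict,  # "C" (tie) -> no preference
        "score": {"A": 1, "B": 0, "C": 0.5}[verdict],
    }

print(parse_verdict("Assistant A cites the source correctly.\n\n[[A]]"))
# {'reasoning': '...', 'value': 'A', 'score': 1}
```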
@@ -225,7 +236,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
         """Initialize the PairwiseStringEvalChain from an LLM.
 
         Args:
-            llm (BaseLanguageModel): The LLM to use.
+            llm (BaseChatModel): The LLM to use (GPT-4 recommended).
             prompt (PromptTemplate, optional): The prompt to use.
             **kwargs (Any): Additional keyword arguments.
 
@@ -236,8 +247,17 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
             ValueError: If the input variables are not as expected.
 
         """
+        if not (
+            isinstance(llm, (ChatOpenAI, AzureChatOpenAI))
+            and llm.model_name.startswith("gpt-4")
+        ):
+            logger.warning(
+                "This chain was only tested with GPT-4. \
+Performance may be significantly worse with other models."
+            )
+
         expected_input_vars = {"prediction", "prediction_b", "input", "criteria"}
-        prompt_ = prompt or PROMPT
+        prompt_ = prompt or COMPARISON_TEMPLATE.partial(reference="")
         if expected_input_vars != set(prompt_.input_variables):
             raise ValueError(
                 f"Input variables should be {expected_input_vars}, "
@@ -245,6 +265,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
             )
         criteria_ = resolve_pairwise_criteria(criteria)
         criteria_str = "\n".join(f"{k}: {v}" if v else k for k, v in criteria_.items())
+        criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
         return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)
 
     def _prepare_input(
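The new `criteria_str` line above (repeated in the labeled chain below) prepends the `CRITERIA_INSTRUCTIONS` header from `prompt.py`. A small standalone sketch of the resulting string, using a hand-written criteria dict in place of `resolve_pairwise_criteria`:

```python
# Copied from prompt.py in this PR.
CRITERIA_INSTRUCTIONS = (
    "For this evaluation, you should primarily consider the following criteria:\n"
)

# Hypothetical resolved criteria; resolve_pairwise_criteria would normally build this dict.
criteria_ = {
    "conciseness": "Is the submission concise and to the point?",
    "relevance": "Is the submission referring to a real quote from the text?",
}

criteria_str = "\n".join(f"{k}: {v}" if v else k for k, v in criteria_.items())
criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
print(criteria_str)
# For this evaluation, you should primarily consider the following criteria:
# conciseness: Is the submission concise and to the point?
# relevance: Is the submission referring to a real quote from the text?
```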
@@ -418,7 +439,7 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
             "reference",
             "criteria",
         }
-        prompt_ = prompt or PROMPT_WITH_REFERENCE
+        prompt_ = prompt or COMPARISON_TEMPLATE_WITH_REFERENCE
         if expected_input_vars != set(prompt_.input_variables):
             raise ValueError(
                 f"Input variables should be {expected_input_vars}, "
@@ -426,4 +447,5 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
             )
         criteria_ = resolve_pairwise_criteria(criteria)
         criteria_str = "\n".join(f"{k}: {v}" for k, v in criteria_.items())
+        criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
         return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)
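The labeled variant works the same way but also takes a ground-truth reference. A hedged sketch under the same assumptions as the earlier example (made-up inputs, OpenAI key configured):

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.comparison.eval_chain import LabeledPairwiseStringEvalChain

judge = LabeledPairwiseStringEvalChain.from_llm(
    llm=ChatOpenAI(model_name="gpt-4", temperature=0)
)
result = judge.evaluate_string_pairs(
    input="When was the printing press invented?",
    reference="Around 1440, by Johannes Gutenberg.",
    prediction="Gutenberg introduced it around 1440.",
    prediction_b="It was invented in the 1800s.",
)
print(result["score"])  # 1 if A is judged better, 0 if B, 0.5 for a tie
```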
langchain/evaluation/comparison/prompt.py

@@ -5,64 +5,55 @@ and answers the question. The prompt is based on the paper from
 Zheng, et. al. https://arxiv.org/abs/2306.05685
 """
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
 
-template = """Act as a fair judge and rate the two responses to the question below.\
- Choose the response that best followed the instructions and answered the question.\
- Your assessment should weigh the following criteria:
-{criteria}\
- Start by comparing both responses and give a brief rationale.\
- Avoid bias from the order of presentation or response length.
-After giving your rationale, make your final decision using this format:\
- "[[A]]" if assistant A is better, "[[B]]" if assistant B is better,\
- and "[[C]]" for a tie. Finally, repeat the decision again on its own on a new line.
-
-[QUESTION]
-{input}
-[/QUESTION]
-
-[RESPONSE A]
-{prediction}
-[/RESPONSE A]
-
-[RESPONSE B]
-{prediction_b}
-[/RESPONSE B]"""
-PROMPT = PromptTemplate(
-    input_variables=["input", "prediction", "prediction_b", "criteria"],
-    template=template,
+SYSTEM_MESSAGE = 'Please act as an impartial judge and evaluate the quality \
+of the responses provided by two AI assistants to the user question displayed below. \
+You should choose the assistant that follows the user\'s instructions \
+and answers \the user\'s question better. \
+Your evaluation should consider factors such as the \
+helpfulness, relevance, accuracy, depth, creativity, \
+and level of detail of their responses. \
+Begin your evaluation by comparing the two responses and provide a short explanation. \
+Avoid any position biases and ensure that the order in which \
+the responses were presented does not influence your decision. \
+Do not allow the length of the responses to influence your evaluation. \
+Do not favor certain names of the assistants. Be as objective as possible. \
+After providing your explanation, output your final verdict by strictly following \
+this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, \
+and "[[C]]" for a tie.'
+
+CRITERIA_INSTRUCTIONS = (
+    "For this evaluation, you should primarily consider the following criteria:\n"
 )
 
-template = """Act as a fair judge and rate the two responses to the question below.\
- Choose the response that best followed the instructions and answered the question.\
- Your assessment should weigh the following criteria:
-{criteria}\
- Start by comparing both responses and give a brief rationale.\
- Avoid bias from the order of presentation or response length.\
- Weigh accuracy based on the following ground truth reference\
- answer to the question:
-
-[REFERENCE]
-{reference}
-[/REFERENCE]
-
-After giving your rationale, make your final decision using this format:\
- "[[A]]" if assistant A is better, "[[B]]" if assistant B is better,\
- and "[[C]]" for a tie. Finally, repeat the decision again on its own on a new line.
-
-[QUESTION]
-{input}
-[/QUESTION]
-
-[RESPONSE A]
-{prediction}
-[/RESPONSE A]
-
-[RESPONSE B]
-{prediction_b}
-[/RESPONSE B]"""
-PROMPT_WITH_REFERENCE = PromptTemplate(
-    input_variables=["input", "prediction", "prediction_b", "reference", "criteria"],
-    template=template,
+COMPARISON_TEMPLATE = ChatPromptTemplate.from_messages(
+    [
+        ("system", SYSTEM_MESSAGE),
+        (
+            "human",
+            "{criteria}[User Question]\n{input}\n\n\
+[The Start of Assistant A's Answer]\n{prediction}\n\
+[The End of Assistant A's Answer]\
+\n\n[The Start of Assistant B's Answer]\n{prediction_b}\n\
+[The End of Assistant B's Answer]",
+        ),
+    ]
+)
+
+COMPARISON_TEMPLATE_WITH_REFERENCE = ChatPromptTemplate.from_messages(
+    [
+        ("system", SYSTEM_MESSAGE),
+        (
+            "human",
+            "{criteria}\n\nTo help you evaluate the responses, \
+here is a reference answer to the user's question:\n\
+{reference}\
+[User Question]\n{input}\n\n\
+[The Start of Assistant A's Answer]\n{prediction}\n\
+[The End of Assistant A's Answer]\
+\n\n[The Start of Assistant B's Answer]\n{prediction_b}\n\
+[The End of Assistant B's Answer]",
+        ),
+    ]
 )
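To see what the new chat prompt actually sends to the judge, here is a hedged sketch that formats `COMPARISON_TEMPLATE` directly. The example question and answers are made up; the import path follows the module shown in the eval_chain imports above.

```python
from langchain.evaluation.comparison.prompt import COMPARISON_TEMPLATE

messages = COMPARISON_TEMPLATE.format_messages(
    criteria="",  # or CRITERIA_INSTRUCTIONS plus a criteria listing
    input="What is the capital of France?",
    prediction="Paris.",
    prediction_b="The capital of France is Paris, on the Seine.",
)
for message in messages:
    # Expect a system message with the judge instructions and a human message
    # wrapping the question and both answers in the bracketed sections above.
    print(message.type, ":", message.content[:80])
```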
tests/unit_tests/evaluation/comparison/test_eval_chain.py

@@ -34,7 +34,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[A]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like pie better than cake.",
+        "reasoning": text,
         "value": "A",
         "score": 1,
     }
@@ -46,7 +46,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[B]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like cake better than pie.",
+        "reasoning": text,
        "value": "B",
         "score": 0,
     }
@@ -58,7 +58,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[C]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like cake and pie.",
+        "reasoning": text,
         "value": None,
         "score": 0.5,
     }
@@ -84,7 +84,7 @@ def test_pairwise_string_comparison_chain() -> None:
     )
     assert res["value"] is None
     assert res["score"] == 0.5
-    assert res["reasoning"] == "The values are the same."
+    assert res["reasoning"] == "The values are the same.\n[[C]]"
     res = chain.evaluate_string_pairs(
         prediction="I like pie.",
         prediction_b="I like pie.",