mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
Add LLMCheckerChain (#281)
Implementation of https://github.com/jagilley/fact-checker. Works pretty well. <img width="993" alt="Screenshot 2022-12-07 at 4 41 47 PM" src="https://user-images.githubusercontent.com/101075607/206302751-356a19ff-d000-4798-9aee-9c38b7f532b9.png"> Verifying this manually: 1. "Only two kinds of egg-laying mammals are left on the planet today—the duck-billed platypus and the echidna, or spiny anteater." https://www.scientificamerican.com/article/extreme-monotremes/ 2. "An [Echidna] egg weighs 1.5 to 2 grams (0.05 to 0.07 oz)[[19]](https://en.wikipedia.org/wiki/Echidna#cite_note-19) and is about 1.4 centimetres (0.55 in) long." https://en.wikipedia.org/wiki/Echidna#:~:text=sleep%20is%20suppressed.-,Reproduction,a%20reptile%2Dlike%20egg%20tooth. 3. "A [platypus] lays one to three (usually two) small, leathery eggs (similar to those of reptiles), about 11 mm (7⁄16 in) in diameter and slightly rounder than bird eggs." https://en.wikipedia.org/wiki/Platypus#:~:text=It%20lays%20one%20to%20three,slightly%20rounder%20than%20bird%20eggs. 4. Therefore, an Echidna is the mammal that lays the biggest eggs. cc @hwchase17
This commit is contained in:
parent
43c9bd869f
commit
5267ebce2d
58
docs/examples/chains/llm_checker.ipynb
Normal file
58
docs/examples/chains/llm_checker.ipynb
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# LLMCheckerChain\n",
|
||||||
|
"This notebook showcases how to use LLMCheckerChain."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import LLMCheckerChain\n",
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI(temperature=0.7)\n",
|
||||||
|
"\n",
|
||||||
|
"text = \"What type of mammal lays the biggest eggs?\"\n",
|
||||||
|
"\n",
|
||||||
|
"checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n",
|
||||||
|
"\n",
|
||||||
|
"checker_chain.run(text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -5,6 +5,7 @@ from langchain.chains import (
|
|||||||
ConversationChain,
|
ConversationChain,
|
||||||
LLMBashChain,
|
LLMBashChain,
|
||||||
LLMChain,
|
LLMChain,
|
||||||
|
LLMCheckerChain,
|
||||||
LLMMathChain,
|
LLMMathChain,
|
||||||
PALChain,
|
PALChain,
|
||||||
QAWithSourcesChain,
|
QAWithSourcesChain,
|
||||||
@ -27,6 +28,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"LLMChain",
|
"LLMChain",
|
||||||
"LLMBashChain",
|
"LLMBashChain",
|
||||||
|
"LLMCheckerChain",
|
||||||
"LLMMathChain",
|
"LLMMathChain",
|
||||||
"SelfAskWithSearchChain",
|
"SelfAskWithSearchChain",
|
||||||
"SerpAPIWrapper",
|
"SerpAPIWrapper",
|
||||||
|
@ -3,6 +3,7 @@ from langchain.chains.api.base import APIChain
|
|||||||
from langchain.chains.conversation.base import ConversationChain
|
from langchain.chains.conversation.base import ConversationChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_bash.base import LLMBashChain
|
from langchain.chains.llm_bash.base import LLMBashChain
|
||||||
|
from langchain.chains.llm_checker.base import LLMCheckerChain
|
||||||
from langchain.chains.llm_math.base import LLMMathChain
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
from langchain.chains.llm_requests import LLMRequestsChain
|
from langchain.chains.llm_requests import LLMRequestsChain
|
||||||
from langchain.chains.mapreduce import MapReduceChain
|
from langchain.chains.mapreduce import MapReduceChain
|
||||||
@ -19,6 +20,7 @@ __all__ = [
|
|||||||
"ConversationChain",
|
"ConversationChain",
|
||||||
"LLMChain",
|
"LLMChain",
|
||||||
"LLMBashChain",
|
"LLMBashChain",
|
||||||
|
"LLMCheckerChain",
|
||||||
"LLMMathChain",
|
"LLMMathChain",
|
||||||
"PALChain",
|
"PALChain",
|
||||||
"QAWithSourcesChain",
|
"QAWithSourcesChain",
|
||||||
|
4
langchain/chains/llm_checker/__init__.py
Normal file
4
langchain/chains/llm_checker/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
"""Chain that tries to verify assumptions before answering a question.
|
||||||
|
|
||||||
|
Heavily borrowed from https://github.com/jagilley/fact-checker
|
||||||
|
"""
|
98
langchain/chains/llm_checker/base.py
Normal file
98
langchain/chains/llm_checker/base.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""Chain for question-answering with self-verification."""
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.llm_checker.prompt import (
|
||||||
|
CHECK_ASSERTIONS_PROMPT,
|
||||||
|
CREATE_DRAFT_ANSWER_PROMPT,
|
||||||
|
LIST_ASSERTIONS_PROMPT,
|
||||||
|
REVISED_ANSWER_PROMPT,
|
||||||
|
)
|
||||||
|
from langchain.chains.sequential import SequentialChain
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCheckerChain(Chain, BaseModel):
    """Chain for question-answering with self-verification.

    Runs four LLM steps in sequence with the same ``llm``:

    1. Draft an answer to the question.
    2. List the assertions the draft relies on.
    3. Check each assertion for truth.
    4. Produce a revised answer in light of the checked assertions.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMCheckerChain
            llm = OpenAI(temperature=0.7)
            checker_chain = LLMCheckerChain(llm=llm)
    """

    llm: LLM
    """LLM wrapper to use."""
    create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
    list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
    """Prompt to use when questioning the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        question = inputs[self.input_key]

        # Step 1: draft an initial answer to the question.
        create_draft_answer_chain = LLMChain(
            llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
        )
        # Step 2: enumerate the assumptions behind the draft.
        list_assertions_chain = LLMChain(
            llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
        )
        # Step 3: fact-check each assumption.
        check_assertions_chain = LLMChain(
            llm=self.llm,
            prompt=self.check_assertions_prompt,
            output_key="checked_assertions",
        )
        # Step 4: answer again in light of the checked assertions.
        revised_answer_chain = LLMChain(
            llm=self.llm,
            prompt=self.revised_answer_prompt,
            output_key="revised_statement",
        )

        chains = [
            create_draft_answer_chain,
            list_assertions_chain,
            check_assertions_chain,
            revised_answer_chain,
        ]

        question_to_checked_assertions_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["revised_statement"],
            # Propagate this chain's verbosity rather than hard-coding
            # verbose=True, so LLMCheckerChain(..., verbose=False) is honored.
            verbose=self.verbose,
        )
        output = question_to_checked_assertions_chain({"question": question})
        return {self.output_key: output["revised_statement"]}
|
31
langchain/chains/llm_checker/prompt.py
Normal file
31
langchain/chains/llm_checker/prompt.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# flake8: noqa
"""Prompts for each step of the LLMCheckerChain self-verification pipeline."""
from langchain.prompts.prompt import PromptTemplate

# Step 1: pass the user's question straight through so the LLM drafts an answer.
_CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n"""
CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate(
    input_variables=["question"], template=_CREATE_DRAFT_ANSWER_TEMPLATE
)

# Step 2: ask the LLM to enumerate the assumptions behind its draft answer.
_LIST_ASSERTIONS_TEMPLATE = """Here is a statement:
{statement}
Make a bullet point list of the assumptions you made when producing the above statement.\n\n"""
LIST_ASSERTIONS_PROMPT = PromptTemplate(
    input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE
)

# Step 3: ask the LLM to fact-check each listed assumption.
_CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions:
{assertions}
For each assertion, determine whether it is true or false. If it is false, explain why.\n\n"""
CHECK_ASSERTIONS_PROMPT = PromptTemplate(
    input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE
)

# Step 4: ask the LLM for a revised answer given the checked assertions.
_REVISED_ANSWER_TEMPLATE = """{checked_assertions}

Question: In light of the above assertions and checks, how would you answer the question '{question}'?

Answer:"""
REVISED_ANSWER_PROMPT = PromptTemplate(
    input_variables=["checked_assertions", "question"],
    template=_REVISED_ANSWER_TEMPLATE,
)
|
43
tests/unit_tests/chains/test_llm_checker.py
Normal file
43
tests/unit_tests/chains/test_llm_checker.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# flake8: noqa E501
|
||||||
|
|
||||||
|
"""Test LLMCheckerChain functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.chains.llm_checker.base import LLMCheckerChain
|
||||||
|
from langchain.chains.llm_checker.prompt import (
|
||||||
|
_CHECK_ASSERTIONS_TEMPLATE,
|
||||||
|
_CREATE_DRAFT_ANSWER_TEMPLATE,
|
||||||
|
_LIST_ASSERTIONS_TEMPLATE,
|
||||||
|
_REVISED_ANSWER_TEMPLATE,
|
||||||
|
)
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_llm_checker_chain() -> LLMCheckerChain:
    """Fake LLMCheckerChain for testing."""
    # Map every fully rendered prompt to the canned response the fake LLM
    # should produce, covering all four steps of the checker pipeline.
    canned_responses = {
        _CREATE_DRAFT_ANSWER_TEMPLATE.format(
            question="Which mammal lays the biggest eggs?"
        ): "I don't know which mammal layers the biggest eggs.",
        _LIST_ASSERTIONS_TEMPLATE.format(
            statement="I don't know which mammal layers the biggest eggs.",
        ): "1) I know that mammals lay eggs.\n2) I know that birds lay eggs.\n3) I know that birds are mammals.",
        _CHECK_ASSERTIONS_TEMPLATE.format(
            assertions="1) I know that mammals lay eggs.\n2) I know that birds lay eggs.\n3) I know that birds are mammals.",
        ): "1) I know that mammals lay eggs. TRUE\n2) I know that birds lay eggs. TRUE\n3) I know that birds are mammals. TRUE",
        _REVISED_ANSWER_TEMPLATE.format(
            checked_assertions="1) I know that mammals lay eggs. TRUE\n2) I know that birds lay eggs. TRUE\n3) I know that birds are mammals. TRUE",
            question="Which mammal lays the biggest eggs?",
        ): "I still don't know.",
    }
    # Build the chain directly on a FakeLLM seeded with the canned responses.
    return LLMCheckerChain(
        llm=FakeLLM(queries=canned_responses), input_key="q", output_key="a"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
    """Run the checker chain end-to-end against the fake LLM's canned responses."""
    # The question must match the one the fixture's canned prompts were built from.
    question = "Which mammal lays the biggest eggs?"
    output = fake_llm_checker_chain.run(question)
    assert output == "I still don't know."
|
Loading…
Reference in New Issue
Block a user