From 5267ebce2d9144ef0b530ae04434008fc5618d8d Mon Sep 17 00:00:00 2001 From: andersenchen <101075607+andersenchen@users.noreply.github.com> Date: Fri, 9 Dec 2022 15:49:05 -0500 Subject: [PATCH] Add LLMCheckerChain (#281) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implementation of https://github.com/jagilley/fact-checker. Works pretty well. Screenshot 2022-12-07 at 4 41 47 PM Verifying this manually: 1. "Only two kinds of egg-laying mammals are left on the planet today—the duck-billed platypus and the echidna, or spiny anteater." https://www.scientificamerican.com/article/extreme-monotremes/ 2. "An [Echidna] egg weighs 1.5 to 2 grams (0.05 to 0.07 oz)[[19]](https://en.wikipedia.org/wiki/Echidna#cite_note-19) and is about 1.4 centimetres (0.55 in) long." https://en.wikipedia.org/wiki/Echidna#:~:text=sleep%20is%20suppressed.-,Reproduction,a%20reptile%2Dlike%20egg%20tooth. 3. "A [platypus] lays one to three (usually two) small, leathery eggs (similar to those of reptiles), about 11 mm (7⁄16 in) in diameter and slightly rounder than bird eggs." https://en.wikipedia.org/wiki/Platypus#:~:text=It%20lays%20one%20to%20three,slightly%20rounder%20than%20bird%20eggs. 4. Therefore, an Echidna is the mammal that lays the biggest eggs. cc @hwchase17 --- docs/examples/chains/llm_checker.ipynb | 58 ++++++++++++ langchain/__init__.py | 2 + langchain/chains/__init__.py | 2 + langchain/chains/llm_checker/__init__.py | 4 + langchain/chains/llm_checker/base.py | 98 +++++++++++++++++++++ langchain/chains/llm_checker/prompt.py | 31 +++++++ tests/unit_tests/chains/test_llm_checker.py | 43 +++++++++ 7 files changed, 238 insertions(+) create mode 100644 docs/examples/chains/llm_checker.ipynb create mode 100644 langchain/chains/llm_checker/__init__.py create mode 100644 langchain/chains/llm_checker/base.py create mode 100644 langchain/chains/llm_checker/prompt.py create mode 100644 tests/unit_tests/chains/test_llm_checker.py diff --git a/docs/examples/chains/llm_checker.ipynb b/docs/examples/chains/llm_checker.ipynb new file mode 100644 index 00000000000..43aa0d1f59a --- /dev/null +++ b/docs/examples/chains/llm_checker.ipynb @@ -0,0 +1,58 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLMCheckerChain\n", + "This notebook showcases how to use LLMCheckerChain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import LLMCheckerChain\n", + "from langchain.llms import OpenAI\n", + "\n", + "llm = OpenAI(temperature=0.7)\n", + "\n", + "text = \"What type of mammal lays the biggest eggs?\"\n", + "\n", + "checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n", + "\n", + "checker_chain.run(text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index 8d8179b76bd..647ab939019 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -5,6 +5,7 @@ from langchain.chains import ( ConversationChain, LLMBashChain, LLMChain, + LLMCheckerChain, LLMMathChain, PALChain, QAWithSourcesChain, @@ -27,6 +28,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch __all__ = [ "LLMChain", "LLMBashChain", + "LLMCheckerChain", "LLMMathChain", "SelfAskWithSearchChain", "SerpAPIWrapper", diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 5e1e551004b..eb65ba1d51b 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -3,6 +3,7 @@ from langchain.chains.api.base import APIChain from langchain.chains.conversation.base import ConversationChain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain +from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_requests import LLMRequestsChain from langchain.chains.mapreduce import MapReduceChain @@ -19,6 +20,7 @@ __all__ = [ "ConversationChain", "LLMChain", "LLMBashChain", + "LLMCheckerChain", "LLMMathChain", "PALChain", "QAWithSourcesChain", diff --git a/langchain/chains/llm_checker/__init__.py b/langchain/chains/llm_checker/__init__.py new file mode 100644 index 00000000000..95516d81e6a --- /dev/null +++ b/langchain/chains/llm_checker/__init__.py @@ -0,0 +1,4 @@ +"""Chain that tries to verify assumptions before answering a question. + +Heavily borrowed from https://github.com/jagilley/fact-checker +""" diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py new file mode 100644 index 00000000000..4dfc1ba38b4 --- /dev/null +++ b/langchain/chains/llm_checker/base.py @@ -0,0 +1,98 @@ +"""Chain for question-answering with self-verification.""" + + +from typing import Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.llm_checker.prompt import ( + CHECK_ASSERTIONS_PROMPT, + CREATE_DRAFT_ANSWER_PROMPT, + LIST_ASSERTIONS_PROMPT, + REVISED_ANSWER_PROMPT, +) +from langchain.chains.sequential import SequentialChain +from langchain.llms.base import LLM +from langchain.prompts import PromptTemplate + + +class LLMCheckerChain(Chain, BaseModel): + """Chain for question-answering with self-verification. + + Example: + .. code-block:: python + from langchain import OpenAI, LLMCheckerChain + llm = OpenAI(temperature=0.7) + checker_chain = LLMCheckerChain(llm=llm) + """ + + llm: LLM + """LLM wrapper to use.""" + create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT + list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT + check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT + revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT + """Prompt to use when questioning the documents.""" + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Return the singular input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key. + + :meta private: + """ + return [self.output_key] + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + question = inputs[self.input_key] + + create_draft_answer_chain = LLMChain( + llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement" + ) + list_assertions_chain = LLMChain( + llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions" + ) + check_assertions_chain = LLMChain( + llm=self.llm, + prompt=self.check_assertions_prompt, + output_key="checked_assertions", + ) + + revised_answer_chain = LLMChain( + llm=self.llm, + prompt=self.revised_answer_prompt, + output_key="revised_statement", + ) + + chains = [ + create_draft_answer_chain, + list_assertions_chain, + check_assertions_chain, + revised_answer_chain, + ] + + question_to_checked_assertions_chain = SequentialChain( + chains=chains, + input_variables=["question"], + output_variables=["revised_statement"], + verbose=True, + ) + output = question_to_checked_assertions_chain({"question": question}) + return {self.output_key: output["revised_statement"]} diff --git a/langchain/chains/llm_checker/prompt.py b/langchain/chains/llm_checker/prompt.py new file mode 100644 index 00000000000..73c883d0c20 --- /dev/null +++ b/langchain/chains/llm_checker/prompt.py @@ -0,0 +1,31 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +_CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n""" +CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate( + input_variables=["question"], template=_CREATE_DRAFT_ANSWER_TEMPLATE +) + +_LIST_ASSERTIONS_TEMPLATE = """Here is a statement: +{statement} +Make a bullet point list of the assumptions you made when producing the above statement.\n\n""" +LIST_ASSERTIONS_PROMPT = PromptTemplate( + input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE +) + +_CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions: +{assertions} +For each assertion, determine whether it is true or false. If it is false, explain why.\n\n""" +CHECK_ASSERTIONS_PROMPT = PromptTemplate( + input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE +) + +_REVISED_ANSWER_TEMPLATE = """{checked_assertions} + +Question: In light of the above assertions and checks, how would you answer the question '{question}'? + +Answer:""" +REVISED_ANSWER_PROMPT = PromptTemplate( + input_variables=["checked_assertions", "question"], + template=_REVISED_ANSWER_TEMPLATE, +) diff --git a/tests/unit_tests/chains/test_llm_checker.py b/tests/unit_tests/chains/test_llm_checker.py new file mode 100644 index 00000000000..0c9b9343550 --- /dev/null +++ b/tests/unit_tests/chains/test_llm_checker.py @@ -0,0 +1,43 @@ +# flake8: noqa E501 + +"""Test LLMCheckerChain functionality.""" + +import pytest + +from langchain.chains.llm_checker.base import LLMCheckerChain +from langchain.chains.llm_checker.prompt import ( + _CHECK_ASSERTIONS_TEMPLATE, + _CREATE_DRAFT_ANSWER_TEMPLATE, + _LIST_ASSERTIONS_TEMPLATE, + _REVISED_ANSWER_TEMPLATE, +) +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@pytest.fixture +def fake_llm_checker_chain() -> LLMCheckerChain: + """Fake LLMCheckerChain for testing.""" + queries = { + _CREATE_DRAFT_ANSWER_TEMPLATE.format( + question="Which mammal lays the biggest eggs?" + ): "I don't know which mammal layers the biggest eggs.", + _LIST_ASSERTIONS_TEMPLATE.format( + statement="I don't know which mammal layers the biggest eggs.", + ): "1) I know that mammals lay eggs.\n2) I know that birds lay eggs.\n3) I know that birds are mammals.", + _CHECK_ASSERTIONS_TEMPLATE.format( + assertions="1) I know that mammals lay eggs.\n2) I know that birds lay eggs.\n3) I know that birds are mammals.", + ): "1) I know that mammals lay eggs. TRUE\n2) I know that birds lay eggs. TRUE\n3) I know that birds are mammals. TRUE", + _REVISED_ANSWER_TEMPLATE.format( + checked_assertions="1) I know that mammals lay eggs. TRUE\n2) I know that birds lay eggs. TRUE\n3) I know that birds are mammals. TRUE", + question="Which mammal lays the biggest eggs?", + ): "I still don't know.", + } + fake_llm = FakeLLM(queries=queries) + return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a") + + +def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None: + """Test simple question that should not need python.""" + question = "Which mammal lays the biggest eggs?" + output = fake_llm_checker_chain.run(question) + assert output == "I still don't know."