From f36ef0739dbb548cabdb4453e6819fc3d826414f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 26 Dec 2023 17:28:10 -0800 Subject: [PATCH] Add create_conv_retrieval_chain func (#15084) ``` +----------+ | MapInput | **+----------+**** **** **** **** *** ** **** +------------------------------------+ ** | Lambda(itemgetter('chat_history')) | * +------------------------------------+ * * * * * * * +---------------------------+ +--------------------------------+ | Lambda(_get_chat_history) | | Lambda(itemgetter('question')) | +---------------------------+ +--------------------------------+ * * * * * * +----------------------------+ +------------------------+ | ContextSet('chat_history') | | ContextSet('question') | +----------------------------+ +------------------------+ **** **** **** **** ** ** +-----------+ | MapOutput | +-----------+ * * * +----------------+ | PromptTemplate | +----------------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +-----------------+ | StrOutputParser | +-----------------+ * * * +----------------------------+ | ContextSet('new_question') | +----------------------------+ * * * +---------------------+ | SequentialRetriever | +---------------------+ * * * +------------------------------------+ | Lambda(_reduce_tokens_below_limit) | +------------------------------------+ * * * +-------------------------------+ | ContextSet('input_documents') | +-------------------------------+ * * * +----------+ ***| MapInput |**** ******* +----------+ ******** ******** * ******* ******* * ******** **** * **** +-------------------------------+ +----------------------------+ +----------------------------+ | ContextGet('input_documents') | | ContextGet('chat_history') | | ContextGet('new_question') | +-------------------------------+**** +----------------------------+ +----------------------------+ ********* * ******* ******** * ****** ***** * **** +-----------+ | MapOutput | +-----------+ * * * +-------------+ | FakeListLLM | +-------------+ 
* * * +----------+ ***| MapInput |*** ******** +----------+ ****** ******* * ***** ******** * ****** **** * *** +-------------------------------+ +----------------------------+ +-------------+ | ContextGet('input_documents') | | ContextGet('new_question') | **| Passthrough | +-------------------------------+ +----------------------------+ ******* +-------------+ ******* * ****** ****** * ******* **** * **** +-----------+ | MapOutput | +-----------+ ``` --------- Co-authored-by: Harrison Chase --- libs/core/langchain_core/retrievers.py | 9 ++- libs/langchain/langchain/chains/__init__.py | 9 ++- .../chains/conversational_retrieval/base.py | 2 +- .../chains/history_aware_retriever.py | 67 +++++++++++++++++ libs/langchain/langchain/chains/retrieval.py | 71 +++++++++++++++++++ libs/langchain/poetry.lock | 44 +++++++++--- libs/langchain/pyproject.toml | 2 +- .../chains/test_conversation_retrieval.py | 11 ++- .../chains/test_history_aware_retriever.py | 26 +++++++ .../tests/unit_tests/chains/test_imports.py | 2 + .../tests/unit_tests/chains/test_retrieval.py | 32 +++++++++ .../unit_tests/retrievers/parrot_retriever.py | 20 ++++++ 12 files changed, 275 insertions(+), 20 deletions(-) create mode 100644 libs/langchain/langchain/chains/history_aware_retriever.py create mode 100644 libs/langchain/langchain/chains/retrieval.py create mode 100644 libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py create mode 100644 libs/langchain/tests/unit_tests/chains/test_retrieval.py create mode 100644 libs/langchain/tests/unit_tests/retrievers/parrot_retriever.py diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 7da99ae56e6..5fe912e0321 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain_core.documents import Document from langchain_core.load.dump import dumpd -from 
langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -18,8 +18,13 @@ if TYPE_CHECKING: Callbacks, ) +RetrieverInput = str +RetrieverOutput = List[Document] +RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] +RetrieverOutputLike = Runnable[Any, RetrieverOutput] -class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): + +class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): """Abstract base class for a Document retrieval system. A retrieval system is defined as something that can take string queries and return diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index f03927512fa..9df02c6d0d3 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -42,6 +42,7 @@ from langchain.chains.graph_qa.kuzu import KuzuQAChain from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain from langchain.chains.graph_qa.sparql import GraphSparqlQAChain +from langchain.chains.history_aware_retriever import create_history_aware_retriever from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_checker.base import LLMCheckerChain @@ -65,7 +66,11 @@ from langchain.chains.qa_generation.base import QAGenerationChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain -from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA +from langchain.chains.retrieval import create_retrieval_chain +from 
langchain.chains.retrieval_qa.base import ( + RetrievalQA, + VectorDBQA, +) from langchain.chains.router import ( LLMRouterChain, MultiPromptChain, @@ -133,4 +138,6 @@ __all__ = [ "generate_example", "load_chain", "create_sql_query_chain", + "create_retrieval_chain", + "create_history_aware_retriever", ] diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 0b15f4e41c9..304fc95e34c 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -13,7 +13,7 @@ from langchain_core.messages import BaseMessage from langchain_core.prompts import BasePromptTemplate from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables.config import RunnableConfig +from langchain_core.runnables import RunnableConfig from langchain_core.vectorstores import VectorStore from langchain.callbacks.manager import ( diff --git a/libs/langchain/langchain/chains/history_aware_retriever.py b/libs/langchain/langchain/chains/history_aware_retriever.py new file mode 100644 index 00000000000..6e4704e7a89 --- /dev/null +++ b/libs/langchain/langchain/chains/history_aware_retriever.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from langchain_core.language_models import LanguageModelLike +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import BasePromptTemplate +from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike +from langchain_core.runnables import RunnableBranch + + +def create_history_aware_retriever( + llm: LanguageModelLike, + retriever: RetrieverLike, + prompt: BasePromptTemplate, +) -> RetrieverOutputLike: + """Create a chain that takes conversation history and returns documents. 
+ + If there is no `chat_history`, then the `input` is just passed directly to the + retriever. If there is `chat_history`, then the prompt and LLM will be used + to generate a search query. That search query is then passed to the retriever. + + Args: + llm: Language model to use for generating a search term given chat history + retriever: RetrieverLike object that takes a string as input and outputs + a list of Documents. + prompt: The prompt used to generate the search query for the retriever. + + Returns: + An LCEL Runnable. The runnable input must take in `input`, and if there + is chat history should take it in the form of `chat_history`. + The Runnable output is a list of Documents + + Example: + .. code-block:: python + + # pip install -U langchain langchain-community + + from langchain_community.chat_models import ChatOpenAI + from langchain.chains import create_history_aware_retriever + from langchain import hub + + rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase") + llm = ChatOpenAI() + retriever = ... 
+ chat_retriever_chain = create_history_aware_retriever( + llm, retriever, rephrase_prompt + ) + + chat_retriever_chain.invoke({"input": "...", "chat_history": []}) + + """ + if "input" not in prompt.input_variables: + raise ValueError( + "Expected `input` to be a prompt variable, " + f"but got {prompt.input_variables}" + ) + + retrieve_documents: RetrieverOutputLike = RunnableBranch( ( + # Both empty string and empty list evaluate to False + lambda x: not x.get("chat_history", False), + # If no chat history, then we just pass input to retriever + (lambda x: x["input"]) | retriever, + ), + # If chat history, then we pass inputs to LLM chain, then to retriever + prompt | llm | StrOutputParser() | retriever, + ).with_config(run_name="chat_retriever_chain") + return retrieve_documents diff --git a/libs/langchain/langchain/chains/retrieval.py b/libs/langchain/langchain/chains/retrieval.py new file mode 100644 index 00000000000..ea53ff99ece --- /dev/null +++ b/libs/langchain/langchain/chains/retrieval.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Any, Dict, Union + +from langchain_core.retrievers import ( + BaseRetriever, + RetrieverOutput, +) +from langchain_core.runnables import Runnable, RunnablePassthrough + + +def create_retrieval_chain( + retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]], + combine_docs_chain: Runnable[Dict[str, Any], str], +) -> Runnable: + """Create retrieval chain that retrieves documents and then passes them on. + + Args: + retriever: Retriever-like object that returns list of documents. Should + either be a subclass of BaseRetriever or a Runnable that returns + a list of documents. If a subclass of BaseRetriever, then it + is expected that an `input` key be passed in - this is what + will be used to pass into the retriever. If this is NOT a + subclass of BaseRetriever, then all the inputs will be passed + into this runnable, meaning that runnable should take a dictionary + as input. 
+ combine_docs_chain: Runnable that takes inputs and produces a string output. + The inputs to this will be any original inputs to this chain, a new + context key with the retrieved documents, and chat_history (if not present + in the inputs) with a value of `[]` (to easily enable conversational + retrieval). + + Returns: + An LCEL Runnable. The Runnable return is a dictionary containing at the very + least a `context` and `answer` key. + + Example: + .. code-block:: python + + # pip install -U langchain langchain-community + + from langchain_community.chat_models import ChatOpenAI + from langchain.chains.combine_documents import create_stuff_documents_chain + from langchain.chains import create_retrieval_chain + from langchain import hub + + retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat") + llm = ChatOpenAI() + retriever = ... + combine_docs_chain = create_stuff_documents_chain( + llm, retrieval_qa_chat_prompt + ) + retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain) + + retrieval_chain.invoke({"input": "..."}) + + """ + if not isinstance(retriever, BaseRetriever): + retrieval_docs: Runnable[dict, RetrieverOutput] = retriever + else: + retrieval_docs = (lambda x: x["input"]) | retriever + + retrieval_chain = ( + RunnablePassthrough.assign( + context=retrieval_docs.with_config(run_name="retrieve_documents"), + chat_history=lambda x: x.get("chat_history", []), + ) + | RunnablePassthrough.assign(answer=combine_docs_chain) + ).with_config(run_name="retrieval_chain") + + return retrieval_chain diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 2409e6a8d58..6bf802ed96b 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiodns" @@ -3049,7 +3049,6 @@ files = [ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"}, {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"}, {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"}, - {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -3133,6 +3132,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -3726,6 +3726,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = 
"MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -5251,6 +5261,8 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -5293,6 +5305,7 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = 
"psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -5301,6 +5314,8 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -5773,6 +5788,7 @@ files = [ {file = "pymongo-4.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6422b6763b016f2ef2beedded0e546d6aa6ba87910f9244d86e0ac7690f75c96"}, {file = "pymongo-4.5.0-cp312-cp312-win32.whl", hash = "sha256:77cfff95c1fafd09e940b3fdcb7b65f11442662fad611d0e69b4dd5d17a81c60"}, {file = 
"pymongo-4.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:e57d859b972c75ee44ea2ef4758f12821243e99de814030f69a3decb2aa86807"}, + {file = "pymongo-4.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8443f3a8ab2d929efa761c6ebce39a6c1dca1c9ac186ebf11b62c8fe1aef53f4"}, {file = "pymongo-4.5.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:2b0176f9233a5927084c79ff80b51bd70bfd57e4f3d564f50f80238e797f0c8a"}, {file = "pymongo-4.5.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:89b3f2da57a27913d15d2a07d58482f33d0a5b28abd20b8e643ab4d625e36257"}, {file = "pymongo-4.5.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:5caee7bd08c3d36ec54617832b44985bd70c4cbd77c5b313de6f7fce0bb34f93"}, @@ -6065,21 +6081,21 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no [[package]] name = "pytest-asyncio" -version = "0.20.3" +version = "0.23.2" description = "Pytest support for asyncio" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-asyncio-0.20.3.tar.gz", hash = "sha256:83cbf01169ce3e8eb71c6c278ccb0574d1a7a3bb8eaaf5e50e0ad342afb33b36"}, - {file = "pytest_asyncio-0.20.3-py3-none-any.whl", hash = "sha256:f129998b209d04fcc65c96fc85c11e5316738358909a8399e93be553d7656442"}, + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, ] [package.dependencies] -pytest = ">=6.1.0" +pytest = ">=7.0.0" [package.extras] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] -testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "pytest-cov" @@ -6289,6 +6305,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -6296,8 +6313,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -6314,6 +6338,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = 
"sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -6321,6 +6346,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -9077,4 +9103,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "e93141191088db7b4aec1a976ebd8cb20075e26d4a987bf97c0495ad865b7460" +content-hash = "65a21aaeb20f13601e11567bdb582c8b486b23e12f63c0efa72df7675a299c52" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index b2f4e43bf77..3f048501475 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -125,7 +125,7 @@ duckdb-engine = "^0.9.2" pytest-watcher = "^0.2.6" freezegun = "^1.2.2" responses = "^0.22.0" -pytest-asyncio = "^0.20.3" 
+pytest-asyncio = "^0.23.2" lark = "^1.1.5" pandas = "^2.0.0" pytest-mock = "^3.10.0" diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py index d7a56603bbb..fbc61116579 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py @@ -1,15 +1,15 @@ """Test conversation chain and memory.""" -import pytest from langchain_core.documents import Document -from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain +from langchain.chains.conversational_retrieval.base import ( + ConversationalRetrievalChain, +) from langchain.llms.fake import FakeListLLM from langchain.memory.buffer import ConversationBufferMemory from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever -@pytest.mark.asyncio -async def atest_simple() -> None: +async def test_simplea() -> None: fixed_resp = "I don't know" answer = "I know the answer!" llm = FakeListLLM(responses=[answer]) @@ -31,8 +31,7 @@ async def atest_simple() -> None: assert got["answer"] == fixed_resp -@pytest.mark.asyncio -async def atest_fixed_message_response_when_docs_found() -> None: +async def test_fixed_message_response_when_docs_founda() -> None: fixed_resp = "I don't know" answer = "I know the answer!" 
llm = FakeListLLM(responses=[answer]) diff --git a/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py b/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py new file mode 100644 index 00000000000..ebb8a3686cb --- /dev/null +++ b/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py @@ -0,0 +1,26 @@ +from langchain_core.documents import Document +from langchain_core.prompts import PromptTemplate + +from langchain.chains import create_history_aware_retriever +from langchain.llms.fake import FakeListLLM +from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever + + +def test_create() -> None: + answer = "I know the answer!" + llm = FakeListLLM(responses=[answer]) + retriever = FakeParrotRetriever() + question_gen_prompt = PromptTemplate.from_template("hi! {input} {chat_history}") + chain = create_history_aware_retriever(llm, retriever, question_gen_prompt) + expected_output = [Document(page_content="What is the answer?")] + output = chain.invoke({"input": "What is the answer?", "chat_history": []}) + assert output == expected_output + + output = chain.invoke({"input": "What is the answer?"}) + assert output == expected_output + + expected_output = [Document(page_content="I know the answer!")] + output = chain.invoke( + {"input": "What is the answer?", "chat_history": ["hi", "hi"]} + ) + assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/chains/test_imports.py b/libs/langchain/tests/unit_tests/chains/test_imports.py index 39b57b3124c..a790c19793c 100644 --- a/libs/langchain/tests/unit_tests/chains/test_imports.py +++ b/libs/langchain/tests/unit_tests/chains/test_imports.py @@ -56,6 +56,8 @@ EXPECTED_ALL = [ "generate_example", "load_chain", "create_sql_query_chain", + "create_history_aware_retriever", + "create_retrieval_chain", ] diff --git a/libs/langchain/tests/unit_tests/chains/test_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_retrieval.py new file 
mode 100644 index 00000000000..fcda44b07cc --- /dev/null +++ b/libs/langchain/tests/unit_tests/chains/test_retrieval.py @@ -0,0 +1,32 @@ +"""Test conversation chain and memory.""" +from langchain_core.documents import Document +from langchain_core.prompts.prompt import PromptTemplate + +from langchain.chains import create_retrieval_chain +from langchain.llms.fake import FakeListLLM +from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever + + +def test_create() -> None: + answer = "I know the answer!" + llm = FakeListLLM(responses=[answer]) + retriever = FakeParrotRetriever() + question_gen_prompt = PromptTemplate.from_template("hi! {input} {chat_history}") + chain = create_retrieval_chain(retriever, question_gen_prompt | llm) + expected_output = { + "answer": "I know the answer!", + "chat_history": [], + "context": [Document(page_content="What is the answer?")], + "input": "What is the answer?", + } + output = chain.invoke({"input": "What is the answer?"}) + assert output == expected_output + + expected_output = { + "answer": "I know the answer!", + "chat_history": "foo", + "context": [Document(page_content="What is the answer?")], + "input": "What is the answer?", + } + output = chain.invoke({"input": "What is the answer?", "chat_history": "foo"}) + assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/retrievers/parrot_retriever.py b/libs/langchain/tests/unit_tests/retrievers/parrot_retriever.py new file mode 100644 index 00000000000..536908979d6 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/parrot_retriever.py @@ -0,0 +1,20 @@ +from typing import List + +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + + +class FakeParrotRetriever(BaseRetriever): + """Test util that parrots the query back as documents.""" + + def _get_relevant_documents( # type: ignore[override] + self, + query: str, + ) -> List[Document]: + return [Document(page_content=query)] + + 
async def _aget_relevant_documents( # type: ignore[override] + self, + query: str, + ) -> List[Document]: + return [Document(page_content=query)]