mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 22:04:37 +00:00
langchain[patch]: deprecate various chains (#25310)
- [x] NatbotChain: move to community, deprecate langchain version. Update to use `prompt | llm | output_parser` instead of LLMChain.
- [x] LLMMathChain: deprecate + add langgraph replacement example to API ref
- [x] HypotheticalDocumentEmbedder (retriever): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] FlareChain: update to use `prompt | llm | output_parser` instead of LLMChain
- [x] ConstitutionalChain: deprecate + add langgraph replacement example to API ref
- [x] LLMChainExtractor (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] LLMChainFilter (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] RePhraseQueryRetriever (retriever): update to use `prompt | llm | output_parser` instead of LLMChain
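Most of the non-deprecation updates in this PR follow the same mechanical pattern: an `LLMChain(llm=llm, prompt=prompt)` becomes a composed runnable. A minimal sketch of the before/after, assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI

prompt = PromptTemplate.from_template("Q: {question} A:")
llm = OpenAI()

# Legacy (deprecated): chain = LLMChain(llm=llm, prompt=prompt)
# Replacement: compose prompt, model, and output parser with the | operator.
chain = prompt | llm | StrOutputParser()

result = chain.invoke({"question": "What is the meaning of life?"})
```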
This commit is contained in: parent 66e30efa61, commit 8afbab4cf6
332
docs/docs/versions/migrating_chains/constitutional_chain.ipynb
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b57124cc-60a0-4c18-b7ce-3e483d1024a2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"---\n",
|
||||||
|
"title: Migrating from ConstitutionalChain\n",
|
||||||
|
"---"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ce8457ed-c0b1-4a74-abbd-9d3d2211270f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"[ConstitutionalChain](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html) allowed for a LLM to critique and revise generations based on [principles](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.models.ConstitutionalPrinciple.html), structured as combinations of critique and revision requests. For example, a principle might include a request to identify harmful content, and a request to rewrite the content.\n",
|
||||||
|
"\n",
|
||||||
|
"In `ConstitutionalChain`, this structure of critique requests and associated revisions was formatted into a LLM prompt and parsed out of string responses. This is more naturally achieved via [structured output](/docs/how_to/structured_output/) features of chat models. We can construct a simple chain in [LangGraph](https://langchain-ai.github.io/langgraph/) for this purpose. Some advantages of this approach include:\n",
|
||||||
|
"\n",
|
||||||
|
"- Leverage tool-calling capabilities of chat models that have been fine-tuned for this purpose;\n",
|
||||||
|
"- Reduce parsing errors from extracting expression from a string LLM response;\n",
|
||||||
|
"- Delegation of instructions to [message roles](/docs/concepts/#messages) (e.g., chat models can understand what a `ToolMessage` represents without the need for additional prompting);\n",
|
||||||
|
"- Support for streaming, both of individual tokens and chain steps."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b99b47ec",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install --upgrade --quiet langchain-openai"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "717c8673",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from getpass import getpass\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"OPENAI_API_KEY\"] = getpass()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e3621b62-a037-42b8-8faa-59575608bb8b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Legacy\n",
|
||||||
|
"\n",
|
||||||
|
"<details open>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "f91c9809-8ee7-4e38-881d-0ace4f6ea883",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import ConstitutionalChain, LLMChain\n",
|
||||||
|
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
|
||||||
|
"from langchain_core.prompts import PromptTemplate\n",
|
||||||
|
"from langchain_openai import OpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI()\n",
|
||||||
|
"\n",
|
||||||
|
"qa_prompt = PromptTemplate(\n",
|
||||||
|
" template=\"Q: {question} A:\",\n",
|
||||||
|
" input_variables=[\"question\"],\n",
|
||||||
|
")\n",
|
||||||
|
"qa_chain = LLMChain(llm=llm, prompt=qa_prompt)\n",
|
||||||
|
"\n",
|
||||||
|
"constitutional_chain = ConstitutionalChain.from_llm(\n",
|
||||||
|
" llm=llm,\n",
|
||||||
|
" chain=qa_chain,\n",
|
||||||
|
" constitutional_principles=[\n",
|
||||||
|
" ConstitutionalPrinciple(\n",
|
||||||
|
" critique_request=\"Tell if this answer is good.\",\n",
|
||||||
|
" revision_request=\"Give a better answer.\",\n",
|
||||||
|
" )\n",
|
||||||
|
" ],\n",
|
||||||
|
" return_intermediate_steps=True,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"result = constitutional_chain.invoke(\"What is the meaning of life?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "fa3d11a1-ac1f-4a9a-9ab3-b7b244daa506",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'question': 'What is the meaning of life?',\n",
|
||||||
|
" 'output': 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.',\n",
|
||||||
|
" 'initial_output': ' The meaning of life is a subjective concept that can vary from person to person. Some may believe that the purpose of life is to find happiness and fulfillment, while others may see it as a journey of self-discovery and personal growth. Ultimately, the meaning of life is something that each individual must determine for themselves.',\n",
|
||||||
|
" 'critiques_and_revisions': [('This answer is good in that it recognizes and acknowledges the subjective nature of the question and provides a valid and thoughtful response. However, it could have also mentioned that the meaning of life is a complex and deeply personal concept that can also change and evolve over time for each individual. Critique Needed.',\n",
|
||||||
|
" 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.')]}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "374ae108-f1a0-4723-9237-5259c8123c04",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Above, we've returned intermediate steps showing:\n",
|
||||||
|
"\n",
|
||||||
|
"- The original question;\n",
|
||||||
|
"- The initial output;\n",
|
||||||
|
"- Critiques and revisions;\n",
|
||||||
|
"- The final output (matching a revision)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "cdc3b527-c09e-4c77-9711-c3cc4506cd95",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"</details>\n",
|
||||||
|
"\n",
|
||||||
|
"## LangGraph\n",
|
||||||
|
"\n",
|
||||||
|
"<details open>\n",
|
||||||
|
"\n",
|
||||||
|
"Below, we use the [.with_structured_output](/docs/how_to/structured_output/) method to simultaneously generate (1) a judgment of whether a critique is needed, and (2) the critique. We surface all prompts involved for clarity and ease of customizability.\n",
|
||||||
|
"\n",
|
||||||
|
"Note that we are also able to stream intermediate steps with this implementation, so we can monitor and if needed intervene during its execution."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "917fdb73-2411-4fcc-9add-c32dc5c745da",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import List, Optional, Tuple\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
|
||||||
|
"from langchain.chains.constitutional_ai.prompts import (\n",
|
||||||
|
" CRITIQUE_PROMPT,\n",
|
||||||
|
" REVISION_PROMPT,\n",
|
||||||
|
")\n",
|
||||||
|
"from langchain_core.output_parsers import StrOutputParser\n",
|
||||||
|
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||||
|
"from langchain_openai import ChatOpenAI\n",
|
||||||
|
"from langgraph.graph import END, START, StateGraph\n",
|
||||||
|
"from typing_extensions import Annotated, TypedDict\n",
|
||||||
|
"\n",
|
||||||
|
"llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class Critique(TypedDict):\n",
|
||||||
|
" \"\"\"Generate a critique, if needed.\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
" critique_needed: Annotated[bool, ..., \"Whether or not a critique is needed.\"]\n",
|
||||||
|
" critique: Annotated[str, ..., \"If needed, the critique.\"]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"critique_prompt = ChatPromptTemplate.from_template(\n",
|
||||||
|
" \"Critique this response according to the critique request. \"\n",
|
||||||
|
" \"If no critique is needed, specify that.\\n\\n\"\n",
|
||||||
|
" \"Query: {query}\\n\\n\"\n",
|
||||||
|
" \"Response: {response}\\n\\n\"\n",
|
||||||
|
" \"Critique request: {critique_request}\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"revision_prompt = ChatPromptTemplate.from_template(\n",
|
||||||
|
" \"Revise this response according to the critique and reivsion request.\\n\\n\"\n",
|
||||||
|
" \"Query: {query}\\n\\n\"\n",
|
||||||
|
" \"Response: {response}\\n\\n\"\n",
|
||||||
|
" \"Critique request: {critique_request}\\n\\n\"\n",
|
||||||
|
" \"Critique: {critique}\\n\\n\"\n",
|
||||||
|
" \"If the critique does not identify anything worth changing, ignore the \"\n",
|
||||||
|
" \"revision request and return 'No revisions needed'. If the critique \"\n",
|
||||||
|
" \"does identify something worth changing, revise the response based on \"\n",
|
||||||
|
" \"the revision request.\\n\\n\"\n",
|
||||||
|
" \"Revision Request: {revision_request}\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"chain = llm | StrOutputParser()\n",
|
||||||
|
"critique_chain = critique_prompt | llm.with_structured_output(Critique)\n",
|
||||||
|
"revision_chain = revision_prompt | llm | StrOutputParser()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class State(TypedDict):\n",
|
||||||
|
" query: str\n",
|
||||||
|
" constitutional_principles: List[ConstitutionalPrinciple]\n",
|
||||||
|
" initial_response: str\n",
|
||||||
|
" critiques_and_revisions: List[Tuple[str, str]]\n",
|
||||||
|
" response: str\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"async def generate_response(state: State):\n",
|
||||||
|
" \"\"\"Generate initial response.\"\"\"\n",
|
||||||
|
" response = await chain.ainvoke(state[\"query\"])\n",
|
||||||
|
" return {\"response\": response, \"initial_response\": response}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"async def critique_and_revise(state: State):\n",
|
||||||
|
" \"\"\"Critique and revise response according to principles.\"\"\"\n",
|
||||||
|
" critiques_and_revisions = []\n",
|
||||||
|
" response = state[\"initial_response\"]\n",
|
||||||
|
" for principle in state[\"constitutional_principles\"]:\n",
|
||||||
|
" critique = await critique_chain.ainvoke(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"query\": state[\"query\"],\n",
|
||||||
|
" \"response\": response,\n",
|
||||||
|
" \"critique_request\": principle.critique_request,\n",
|
||||||
|
" }\n",
|
||||||
|
" )\n",
|
||||||
|
" if critique[\"critique_needed\"]:\n",
|
||||||
|
" revision = await revision_chain.ainvoke(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"query\": state[\"query\"],\n",
|
||||||
|
" \"response\": response,\n",
|
||||||
|
" \"critique_request\": principle.critique_request,\n",
|
||||||
|
" \"critique\": critique[\"critique\"],\n",
|
||||||
|
" \"revision_request\": principle.revision_request,\n",
|
||||||
|
" }\n",
|
||||||
|
" )\n",
|
||||||
|
" response = revision\n",
|
||||||
|
" critiques_and_revisions.append((critique[\"critique\"], revision))\n",
|
||||||
|
" else:\n",
|
||||||
|
" critiques_and_revisions.append((critique[\"critique\"], \"\"))\n",
|
||||||
|
" return {\n",
|
||||||
|
" \"critiques_and_revisions\": critiques_and_revisions,\n",
|
||||||
|
" \"response\": response,\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"graph = StateGraph(State)\n",
|
||||||
|
"graph.add_node(\"generate_response\", generate_response)\n",
|
||||||
|
"graph.add_node(\"critique_and_revise\", critique_and_revise)\n",
|
||||||
|
"\n",
|
||||||
|
"graph.add_edge(START, \"generate_response\")\n",
|
||||||
|
"graph.add_edge(\"generate_response\", \"critique_and_revise\")\n",
|
||||||
|
"graph.add_edge(\"critique_and_revise\", END)\n",
|
||||||
|
"app = graph.compile()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "01aac88d-464e-431f-b92e-746dcb743e1b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{}\n",
|
||||||
|
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'response': 'Finding purpose, connection, and joy in our experiences and relationships.'}\n",
|
||||||
|
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'critiques_and_revisions': [(\"The response exceeds the 10-word limit, providing a more elaborate answer than requested. A concise response, such as 'To seek purpose and joy in life,' would better align with the query.\", 'To seek purpose and joy in life.')], 'response': 'To seek purpose and joy in life.'}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"constitutional_principles = [\n",
|
||||||
|
" ConstitutionalPrinciple(\n",
|
||||||
|
" critique_request=\"Tell if this answer is good.\",\n",
|
||||||
|
" revision_request=\"Give a better answer.\",\n",
|
||||||
|
" )\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"query = \"What is the meaning of life? Answer in 10 words or fewer.\"\n",
|
||||||
|
"\n",
|
||||||
|
"async for step in app.astream(\n",
|
||||||
|
" {\"query\": query, \"constitutional_principles\": constitutional_principles},\n",
|
||||||
|
" stream_mode=\"values\",\n",
|
||||||
|
"):\n",
|
||||||
|
" subset = [\"initial_response\", \"critiques_and_revisions\", \"response\"]\n",
|
||||||
|
" print({k: v for k, v in step.items() if k in subset})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b2717810",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"</details>\n",
|
||||||
|
"\n",
|
||||||
|
"## Next steps\n",
|
||||||
|
"\n",
|
||||||
|
"See guides for generating structured output [here](/docs/how_to/structured_output/).\n",
|
||||||
|
"\n",
|
||||||
|
"Check out the [LangGraph documentation](https://langchain-ai.github.io/langgraph/) for detail on building with LangGraph."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
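The notebook above streams each step of the compiled LangGraph app. As a sketch of an alternative, the final state can also be fetched in one call, reusing the `app` and `constitutional_principles` defined in the notebook cells (the top-level `await` assumes an async context such as a notebook):

```python
# Sketch: fetch only the final state instead of streaming each step.
# Assumes `app` and `constitutional_principles` are defined as in the notebook above.
result = await app.ainvoke(
    {
        "query": "What is the meaning of life? Answer in 10 words or fewer.",
        "constitutional_principles": constitutional_principles,
    }
)
print(result["response"])
print(result["critiques_and_revisions"])
```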
@ -45,5 +45,7 @@ The below pages assist with migration from various specific chains to LCEL and L
- [RefineDocumentsChain](/docs/versions/migrating_chains/refine_docs_chain)
- [LLMRouterChain](/docs/versions/migrating_chains/llm_router_chain)
- [MultiPromptChain](/docs/versions/migrating_chains/multi_prompt_chain)
- [LLMMathChain](/docs/versions/migrating_chains/llm_math_chain)
- [ConstitutionalChain](/docs/versions/migrating_chains/constitutional_chain)

Check out the [LCEL conceptual docs](/docs/concepts/#langchain-expression-language-lcel) and [LangGraph docs](https://langchain-ai.github.io/langgraph/) for more background information.
281
docs/docs/versions/migrating_chains/llm_math_chain.ipynb
Normal file
File diff suppressed because one or more lines are too long
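The suppressed notebook covers the LLMMathChain migration; its replacement mirrors the LangGraph example added to the API reference later in this diff, built around a `numexpr`-backed calculator tool. A sketch of just that tool (assumes `numexpr` is installed):

```python
import math

import numexpr
from langchain_core.tools import tool


@tool
def calculator(expression: str) -> str:
    """Calculate a single-line mathematical expression using numexpr."""
    local_dict = {"pi": math.pi, "e": math.e}  # common mathematical constants
    return str(
        numexpr.evaluate(
            expression.strip(),
            global_dict={},  # restrict access to globals
            local_dict=local_dict,
        )
    )


# Tools created with @tool are invoked with a dict of arguments.
print(calculator.invoke({"expression": "551368 / 82"}))  # 6724.0
```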
@ -0,0 +1,8 @@
"""Implement a GPT-3 driven browser.

Heavily influenced from https://github.com/nat/natbot
"""

from langchain_community.chains.natbot.base import NatBotChain

__all__ = ["NatBotChain"]
3
libs/community/langchain_community/chains/natbot/base.py
Normal file
@ -0,0 +1,3 @@
from langchain.chains import NatBotChain

__all__ = ["NatBotChain"]
@ -0,0 +1,7 @@
from langchain.chains.natbot.crawler import (
    Crawler,
    ElementInViewPort,
    black_listed_elements,
)

__all__ = ["ElementInViewPort", "Crawler", "black_listed_elements"]
@ -0,0 +1,3 @@
from langchain.chains.natbot.prompt import PROMPT

__all__ = ["PROMPT"]
|
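With NatBotChain moved, downstream code should switch its import to the community package, per the deprecation message added later in this diff; a one-line sketch:

```python
# You may need: pip install -U langchain-community
from langchain_community.chains.natbot import NatBotChain
```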
@ -6,14 +6,6 @@ from langchain_core.documents import Document
|
|||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
def test_llm_construction_with_kwargs() -> None:
|
|
||||||
llm_chain_kwargs = {"verbose": True}
|
|
||||||
compressor = LLMChainExtractor.from_llm(
|
|
||||||
ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
|
|
||||||
)
|
|
||||||
assert compressor.llm_chain.verbose is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_llm_chain_extractor() -> None:
|
def test_llm_chain_extractor() -> None:
|
||||||
texts = [
|
texts = [
|
||||||
"The Roman Empire followed the Roman Republic.",
|
"The Roman Empire followed the Roman Republic.",
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.chains.natbot.base import NatBotChain
|
||||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
|
|
||||||
from langchain.chains.natbot.base import NatBotChain
|
|
||||||
|
|
||||||
|
|
||||||
class FakeLLM(LLM):
|
class FakeLLM(LLM):
|
||||||
"""Fake LLM wrapper for testing purposes."""
|
"""Fake LLM wrapper for testing purposes."""
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
since="0.2.13",
|
||||||
|
message=(
|
||||||
|
"This class is deprecated and will be removed in langchain 1.0. "
|
||||||
|
"See API reference for replacement: "
|
||||||
|
"https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
|
||||||
|
),
|
||||||
|
removal="1.0",
|
||||||
|
)
|
||||||
class ConstitutionalChain(Chain):
|
class ConstitutionalChain(Chain):
|
||||||
"""Chain for applying constitutional principles.
|
"""Chain for applying constitutional principles.
|
||||||
|
|
||||||
|
Note: this class is deprecated. See below for a replacement implementation
|
||||||
|
using LangGraph. The benefits of this implementation are:
|
||||||
|
|
||||||
|
- Uses LLM tool calling features instead of parsing string responses;
|
||||||
|
- Support for both token-by-token and step-by-step streaming;
|
||||||
|
- Support for checkpointing and memory of chat history;
|
||||||
|
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
|
||||||
|
|
||||||
|
Install LangGraph with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install -U langgraph
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from langchain.chains.constitutional_ai.prompts import (
|
||||||
|
CRITIQUE_PROMPT,
|
||||||
|
REVISION_PROMPT,
|
||||||
|
)
|
||||||
|
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from typing_extensions import Annotated, TypedDict
|
||||||
|
|
||||||
|
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||||
|
|
||||||
|
class Critique(TypedDict):
|
||||||
|
\"\"\"Generate a critique, if needed.\"\"\"
|
||||||
|
critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
|
||||||
|
critique: Annotated[str, ..., "If needed, the critique."]
|
||||||
|
|
||||||
|
critique_prompt = ChatPromptTemplate.from_template(
|
||||||
|
"Critique this response according to the critique request. "
|
||||||
|
"If no critique is needed, specify that.\\n\\n"
|
||||||
|
"Query: {query}\\n\\n"
|
||||||
|
"Response: {response}\\n\\n"
|
||||||
|
"Critique request: {critique_request}"
|
||||||
|
)
|
||||||
|
|
||||||
|
revision_prompt = ChatPromptTemplate.from_template(
|
||||||
|
"Revise this response according to the critique and reivsion request.\\n\\n"
|
||||||
|
"Query: {query}\\n\\n"
|
||||||
|
"Response: {response}\\n\\n"
|
||||||
|
"Critique request: {critique_request}\\n\\n"
|
||||||
|
"Critique: {critique}\\n\\n"
|
||||||
|
"If the critique does not identify anything worth changing, ignore the "
|
||||||
|
"revision request and return 'No revisions needed'. If the critique "
|
||||||
|
"does identify something worth changing, revise the response based on "
|
||||||
|
"the revision request.\\n\\n"
|
||||||
|
"Revision Request: {revision_request}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = llm | StrOutputParser()
|
||||||
|
critique_chain = critique_prompt | llm.with_structured_output(Critique)
|
||||||
|
revision_chain = revision_prompt | llm | StrOutputParser()
|
||||||
|
|
||||||
|
|
||||||
|
class State(TypedDict):
|
||||||
|
query: str
|
||||||
|
constitutional_principles: List[ConstitutionalPrinciple]
|
||||||
|
initial_response: str
|
||||||
|
critiques_and_revisions: List[Tuple[str, str]]
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_response(state: State):
|
||||||
|
\"\"\"Generate initial response.\"\"\"
|
||||||
|
response = await chain.ainvoke(state["query"])
|
||||||
|
return {"response": response, "initial_response": response}
|
||||||
|
|
||||||
|
async def critique_and_revise(state: State):
|
||||||
|
\"\"\"Critique and revise response according to principles.\"\"\"
|
||||||
|
critiques_and_revisions = []
|
||||||
|
response = state["initial_response"]
|
||||||
|
for principle in state["constitutional_principles"]:
|
||||||
|
critique = await critique_chain.ainvoke(
|
||||||
|
{
|
||||||
|
"query": state["query"],
|
||||||
|
"response": response,
|
||||||
|
"critique_request": principle.critique_request,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if critique["critique_needed"]:
|
||||||
|
revision = await revision_chain.ainvoke(
|
||||||
|
{
|
||||||
|
"query": state["query"],
|
||||||
|
"response": response,
|
||||||
|
"critique_request": principle.critique_request,
|
||||||
|
"critique": critique["critique"],
|
||||||
|
"revision_request": principle.revision_request,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = revision
|
||||||
|
critiques_and_revisions.append((critique["critique"], revision))
|
||||||
|
else:
|
||||||
|
critiques_and_revisions.append((critique["critique"], ""))
|
||||||
|
return {
|
||||||
|
"critiques_and_revisions": critiques_and_revisions,
|
||||||
|
"response": response,
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = StateGraph(State)
|
||||||
|
graph.add_node("generate_response", generate_response)
|
||||||
|
graph.add_node("critique_and_revise", critique_and_revise)
|
||||||
|
|
||||||
|
graph.add_edge(START, "generate_response")
|
||||||
|
graph.add_edge("generate_response", "critique_and_revise")
|
||||||
|
graph.add_edge("critique_and_revise", END)
|
||||||
|
app = graph.compile()
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
constitutional_principles=[
|
||||||
|
ConstitutionalPrinciple(
|
||||||
|
critique_request="Tell if this answer is good.",
|
||||||
|
revision_request="Give a better answer.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
query = "What is the meaning of life? Answer in 10 words or fewer."
|
||||||
|
|
||||||
|
async for step in app.astream(
|
||||||
|
{"query": query, "constitutional_principles": constitutional_principles},
|
||||||
|
stream_mode="values",
|
||||||
|
):
|
||||||
|
subset = ["initial_response", "critiques_and_revisions", "response"]
|
||||||
|
print({k: v for k, v in step.items() if k in subset})
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
|
|||||||
)
|
)
|
||||||
|
|
||||||
constitutional_chain.run(question="What is the meaning of life?")
|
constitutional_chain.run(question="What is the meaning of life?")
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
chain: LLMChain
|
chain: LLMChain
|
||||||
constitutional_principles: List[ConstitutionalPrinciple]
|
constitutional_principles: List[ConstitutionalPrinciple]
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -9,10 +8,12 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.flare.prompts import (
|
from langchain.chains.flare.prompts import (
|
||||||
@ -23,50 +24,13 @@ from langchain.chains.flare.prompts import (
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
|
||||||
class _ResponseChain(LLMChain):
|
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
|
||||||
"""Base class for chains that generate responses."""
|
"""Extract tokens and log probabilities from chat model response."""
|
||||||
|
|
||||||
prompt: BasePromptTemplate = PROMPT
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
return self.prompt.input_variables
|
|
||||||
|
|
||||||
def generate_tokens_and_log_probs(
|
|
||||||
self,
|
|
||||||
_input: Dict[str, Any],
|
|
||||||
*,
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
|
||||||
llm_result = self.generate([_input], run_manager=run_manager)
|
|
||||||
return self._extract_tokens_and_log_probs(llm_result.generations[0])
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _extract_tokens_and_log_probs(
|
|
||||||
self, generations: List[Generation]
|
|
||||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
|
||||||
"""Extract tokens and log probs from response."""
|
|
||||||
|
|
||||||
|
|
||||||
class _OpenAIResponseChain(_ResponseChain):
|
|
||||||
"""Chain that generates responses from user input and context."""
|
|
||||||
|
|
||||||
llm: BaseLanguageModel
|
|
||||||
|
|
||||||
def _extract_tokens_and_log_probs(
|
|
||||||
self, generations: List[Generation]
|
|
||||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
|
||||||
tokens = []
|
tokens = []
|
||||||
log_probs = []
|
log_probs = []
|
||||||
for gen in generations:
|
for token in response.response_metadata["logprobs"]["content"]:
|
||||||
if gen.generation_info is None:
|
tokens.append(token["token"])
|
||||||
raise ValueError
|
log_probs.append(token["logprob"])
|
||||||
tokens.extend(gen.generation_info["logprobs"]["tokens"])
|
|
||||||
log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
|
|
||||||
return tokens, log_probs
|
return tokens, log_probs
|
||||||
|
|
||||||
|
|
||||||
@ -111,9 +75,9 @@ class FlareChain(Chain):
|
|||||||
"""Chain that combines a retriever, a question generator,
|
"""Chain that combines a retriever, a question generator,
|
||||||
and a response generator."""
|
and a response generator."""
|
||||||
|
|
||||||
question_generator_chain: QuestionGeneratorChain
|
question_generator_chain: Runnable
|
||||||
"""Chain that generates questions from uncertain spans."""
|
"""Chain that generates questions from uncertain spans."""
|
||||||
response_chain: _ResponseChain
|
response_chain: Runnable
|
||||||
"""Chain that generates responses from user input and context."""
|
"""Chain that generates responses from user input and context."""
|
||||||
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
||||||
"""Parser that determines whether the chain is finished."""
|
"""Parser that determines whether the chain is finished."""
|
||||||
@ -152,12 +116,16 @@ class FlareChain(Chain):
|
|||||||
for question in questions:
|
for question in questions:
|
||||||
docs.extend(self.retriever.invoke(question))
|
docs.extend(self.retriever.invoke(question))
|
||||||
context = "\n\n".join(d.page_content for d in docs)
|
context = "\n\n".join(d.page_content for d in docs)
|
||||||
result = self.response_chain.predict(
|
result = self.response_chain.invoke(
|
||||||
user_input=user_input,
|
{
|
||||||
context=context,
|
"user_input": user_input,
|
||||||
response=response,
|
"context": context,
|
||||||
callbacks=callbacks,
|
"response": response,
|
||||||
|
},
|
||||||
|
{"callbacks": callbacks},
|
||||||
)
|
)
|
||||||
|
if isinstance(result, AIMessage):
|
||||||
|
result = result.content
|
||||||
marginal, finished = self.output_parser.parse(result)
|
marginal, finished = self.output_parser.parse(result)
|
||||||
return marginal, finished
|
return marginal, finished
|
||||||
|
|
||||||
@ -178,6 +146,7 @@ class FlareChain(Chain):
|
|||||||
for span in low_confidence_spans
|
for span in low_confidence_spans
|
||||||
]
|
]
|
||||||
callbacks = _run_manager.get_child()
|
callbacks = _run_manager.get_child()
|
||||||
|
if isinstance(self.question_generator_chain, LLMChain):
|
||||||
question_gen_outputs = self.question_generator_chain.apply(
|
question_gen_outputs = self.question_generator_chain.apply(
|
||||||
question_gen_inputs, callbacks=callbacks
|
question_gen_inputs, callbacks=callbacks
|
||||||
)
|
)
|
||||||
@ -185,6 +154,10 @@ class FlareChain(Chain):
|
|||||||
output[self.question_generator_chain.output_keys[0]]
|
output[self.question_generator_chain.output_keys[0]]
|
||||||
for output in question_gen_outputs
|
for output in question_gen_outputs
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
questions = self.question_generator_chain.batch(
|
||||||
|
question_gen_inputs, config={"callbacks": callbacks}
|
||||||
|
)
|
||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
f"Generated Questions: {questions}", color="yellow", end="\n"
|
f"Generated Questions: {questions}", color="yellow", end="\n"
|
||||||
)
|
)
|
||||||
@ -206,8 +179,10 @@ class FlareChain(Chain):
|
|||||||
f"Current Response: {response}", color="blue", end="\n"
|
f"Current Response: {response}", color="blue", end="\n"
|
||||||
)
|
)
|
||||||
_input = {"user_input": user_input, "context": "", "response": response}
|
_input = {"user_input": user_input, "context": "", "response": response}
|
||||||
tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
|
tokens, log_probs = _extract_tokens_and_log_probs(
|
||||||
_input, run_manager=_run_manager
|
self.response_chain.invoke(
|
||||||
|
_input, {"callbacks": _run_manager.get_child()}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
low_confidence_spans = _low_confidence_spans(
|
low_confidence_spans = _low_confidence_spans(
|
||||||
tokens,
|
tokens,
|
||||||
@ -251,18 +226,16 @@ class FlareChain(Chain):
|
|||||||
FlareChain class with the given language model.
|
FlareChain class with the given language model.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from langchain_openai import OpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"OpenAI is required for FlareChain. "
|
"OpenAI is required for FlareChain. "
|
||||||
"Please install langchain-openai."
|
"Please install langchain-openai."
|
||||||
"pip install langchain-openai"
|
"pip install langchain-openai"
|
||||||
)
|
)
|
||||||
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
|
||||||
response_llm = OpenAI(
|
response_chain = PROMPT | llm
|
||||||
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
|
||||||
)
|
|
||||||
response_chain = _OpenAIResponseChain(llm=response_llm)
|
|
||||||
return cls(
|
return cls(
|
||||||
question_generator_chain=question_gen_chain,
|
question_generator_chain=question_gen_chain,
|
||||||
response_chain=response_chain,
|
response_chain=response_chain,
|
||||||
|
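The FlareChain changes above replace completion-style `generation_info` log-probs with the chat model's response metadata. A standalone sketch of that extraction, mirroring the new `_extract_tokens_and_log_probs` helper (assumes `langchain-openai` and an OpenAI key):

```python
from langchain_openai import ChatOpenAI

# logprobs=True asks the API to return per-token log-probabilities,
# which land in response_metadata["logprobs"]["content"].
llm = ChatOpenAI(logprobs=True, temperature=0)
response = llm.invoke("Say hello in one short sentence.")

tokens, log_probs = [], []
for token in response.response_metadata["logprobs"]["content"]:
    tokens.append(token["token"])
    log_probs.append(token["logprob"])

print(list(zip(tokens, log_probs)))
```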
@ -11,7 +11,9 @@ import numpy as np
|
|||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||||
@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
base_embeddings: Embeddings
|
base_embeddings: Embeddings
|
||||||
llm_chain: LLMChain
|
llm_chain: Runnable
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Input keys for Hyde's LLM chain."""
|
"""Input keys for Hyde's LLM chain."""
|
||||||
return self.llm_chain.input_keys
|
return self.llm_chain.input_schema.schema()["required"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Output keys for Hyde's LLM chain."""
|
"""Output keys for Hyde's LLM chain."""
|
||||||
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
return self.llm_chain.output_keys
|
return self.llm_chain.output_keys
|
||||||
|
else:
|
||||||
|
return ["text"]
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Call the base embeddings."""
|
"""Call the base embeddings."""
|
||||||
@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Generate a hypothetical document and embedded it."""
|
"""Generate a hypothetical document and embedded it."""
|
||||||
var_name = self.llm_chain.input_keys[0]
|
var_name = self.input_keys[0]
|
||||||
result = self.llm_chain.generate([{var_name: text}])
|
result = self.llm_chain.invoke({var_name: text})
|
||||||
documents = [generation.text for generation in result.generations[0]]
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
|
documents = [result[self.output_keys[0]]]
|
||||||
|
else:
|
||||||
|
documents = [result]
|
||||||
embeddings = self.embed_documents(documents)
|
embeddings = self.embed_documents(documents)
|
||||||
return self.combine_embeddings(embeddings)
|
return self.combine_embeddings(embeddings)
|
||||||
|
|
||||||
@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Call the internal llm chain."""
|
"""Call the internal llm chain."""
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
return self.llm_chain.invoke(
|
||||||
|
inputs, config={"callbacks": _run_manager.get_child()}
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
f"of {list(PROMPT_MAP.keys())}."
|
f"of {list(PROMPT_MAP.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
llm_chain = prompt | llm | StrOutputParser()
|
||||||
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
|
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
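From a caller's perspective, `HypotheticalDocumentEmbedder.from_llm` is unchanged; internally it now builds `prompt | llm | StrOutputParser()`. A usage sketch (assumes `langchain-openai`; `prompt_key="web_search"` selects one of the prompts in `PROMPT_MAP`, and the query is illustrative):

```python
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain_openai import OpenAI, OpenAIEmbeddings

base_embeddings = OpenAIEmbeddings()
llm = OpenAI()

# Internally this now composes prompt | llm | StrOutputParser().
hyde = HypotheticalDocumentEmbedder.from_llm(
    llm, base_embeddings, prompt_key="web_search"
)

# Generates a hypothetical document for the query and embeds it.
vector = hyde.embed_query("What items can I bring on a plane?")
print(len(vector))
```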
@ -7,6 +7,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.chains.llm_math.prompt import PROMPT
|
from langchain.chains.llm_math.prompt import PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
since="0.2.13",
|
||||||
|
message=(
|
||||||
|
"This class is deprecated and will be removed in langchain 1.0. "
|
||||||
|
"See API reference for replacement: "
|
||||||
|
"https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501
|
||||||
|
),
|
||||||
|
removal="1.0",
|
||||||
|
)
|
||||||
class LLMMathChain(Chain):
|
class LLMMathChain(Chain):
|
||||||
"""Chain that interprets a prompt and executes python code to do math.
|
"""Chain that interprets a prompt and executes python code to do math.
|
||||||
|
|
||||||
|
Note: this class is deprecated. See below for a replacement implementation
|
||||||
|
using LangGraph. The benefits of this implementation are:
|
||||||
|
|
||||||
|
- Uses LLM tool calling features;
|
||||||
|
- Support for both token-by-token and step-by-step streaming;
|
||||||
|
- Support for checkpointing and memory of chat history;
|
||||||
|
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
|
||||||
|
|
||||||
|
Install LangGraph with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install -U langgraph
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.graph import END, StateGraph
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langgraph.prebuilt.tool_node import ToolNode
|
||||||
|
import numexpr
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def calculator(expression: str) -> str:
|
||||||
|
\"\"\"Calculate expression using Python's numexpr library.
|
||||||
|
|
||||||
|
Expression should be a single line mathematical expression
|
||||||
|
that solves the problem.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
"37593 * 67" for "37593 times 67"
|
||||||
|
"37593**(1/5)" for "37593^(1/5)"
|
||||||
|
\"\"\"
|
||||||
|
local_dict = {"pi": math.pi, "e": math.e}
|
||||||
|
return str(
|
||||||
|
numexpr.evaluate(
|
||||||
|
expression.strip(),
|
||||||
|
global_dict={}, # restrict access to globals
|
||||||
|
local_dict=local_dict, # add common mathematical functions
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||||
|
tools = [calculator]
|
||||||
|
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
|
||||||
|
|
||||||
|
class ChainState(TypedDict):
|
||||||
|
\"\"\"LangGraph state.\"\"\"
|
||||||
|
|
||||||
|
messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||||
|
|
||||||
|
async def acall_chain(state: ChainState, config: RunnableConfig):
|
||||||
|
last_message = state["messages"][-1]
|
||||||
|
response = await llm_with_tools.ainvoke(state["messages"], config)
|
||||||
|
return {"messages": [response]}
|
||||||
|
|
||||||
|
async def acall_model(state: ChainState, config: RunnableConfig):
|
||||||
|
response = await llm.ainvoke(state["messages"], config)
|
||||||
|
return {"messages": [response]}
|
||||||
|
|
||||||
|
graph_builder = StateGraph(ChainState)
|
||||||
|
graph_builder.add_node("call_tool", acall_chain)
|
||||||
|
graph_builder.add_node("execute_tool", ToolNode(tools))
|
||||||
|
graph_builder.add_node("call_model", acall_model)
|
||||||
|
graph_builder.set_entry_point("call_tool")
|
||||||
|
graph_builder.add_edge("call_tool", "execute_tool")
|
||||||
|
graph_builder.add_edge("execute_tool", "call_model")
|
||||||
|
graph_builder.add_edge("call_model", END)
|
||||||
|
chain = graph_builder.compile()
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
example_query = "What is 551368 divided by 82"
|
||||||
|
|
||||||
|
events = chain.astream(
|
||||||
|
{"messages": [("user", example_query)]},
|
||||||
|
stream_mode="values",
|
||||||
|
)
|
||||||
|
async for event in events:
|
||||||
|
event["messages"][-1].pretty_print()
|
||||||
|
|
||||||
|
.. code-block:: none
|
||||||
|
|
||||||
|
================================ Human Message =================================
|
||||||
|
|
||||||
|
What is 551368 divided by 82
|
||||||
|
================================== Ai Message ==================================
|
||||||
|
Tool Calls:
|
||||||
|
calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
|
||||||
|
Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
|
||||||
|
Args:
|
||||||
|
expression: 551368 / 82
|
||||||
|
================================= Tool Message =================================
|
||||||
|
Name: calculator
|
||||||
|
|
||||||
|
6724.0
|
||||||
|
================================== Ai Message ==================================
|
||||||
|
|
||||||
|
551368 divided by 82 equals 6724.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.chains import LLMMathChain
|
from langchain.chains import LLMMathChain
|
||||||
from langchain_community.llms import OpenAI
|
from langchain_community.llms import OpenAI
|
||||||
llm_math = LLMMathChain.from_llm(OpenAI())
|
llm_math = LLMMathChain.from_llm(OpenAI())
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
llm: Optional[BaseLanguageModel] = None
|
llm: Optional[BaseLanguageModel] = None
|
||||||
|
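As a small follow-up to the streaming example in the docstring above, the compiled graph can also be awaited once for the final answer (a sketch, reusing the `chain` built there; the top-level `await` assumes an async context such as a notebook):

```python
# Sketch: single-shot invocation of the graph from the docstring example.
result = await chain.ainvoke({"messages": [("user", "What is 551368 divided by 82")]})
print(result["messages"][-1].content)  # e.g. "551368 divided by 82 equals 6724."
```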
@ -5,15 +5,27 @@ from __future__ import annotations
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.pydantic_v1 import root_validator
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.chains.natbot.prompt import PROMPT
|
from langchain.chains.natbot.prompt import PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
since="0.2.13",
|
||||||
|
message=(
|
||||||
|
"Importing NatBotChain from langchain is deprecated and will be removed in "
|
||||||
|
"langchain 1.0. Please import from langchain_community instead: "
|
||||||
|
"from langchain_community.chains.natbot import NatBotChain. "
|
||||||
|
"You may need to pip install -U langchain-community."
|
||||||
|
),
|
||||||
|
removal="1.0",
|
||||||
|
)
|
||||||
class NatBotChain(Chain):
|
class NatBotChain(Chain):
|
||||||
"""Implement an LLM driven browser.
|
"""Implement an LLM driven browser.
|
||||||
|
|
||||||
@ -37,7 +49,7 @@ class NatBotChain(Chain):
|
|||||||
natbot = NatBotChain.from_default("Buy me a new hat.")
|
natbot = NatBotChain.from_default("Buy me a new hat.")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: Runnable
|
||||||
objective: str
|
objective: str
|
||||||
"""Objective that NatBot is tasked with completing."""
|
"""Objective that NatBot is tasked with completing."""
|
||||||
llm: Optional[BaseLanguageModel] = None
|
llm: Optional[BaseLanguageModel] = None
|
||||||
@ -60,7 +72,7 @@ class NatBotChain(Chain):
|
|||||||
"class method."
|
"class method."
|
||||||
)
|
)
|
||||||
if "llm_chain" not in values and values["llm"] is not None:
|
if "llm_chain" not in values and values["llm"] is not None:
|
||||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
|
values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -77,7 +89,7 @@ class NatBotChain(Chain):
|
|||||||
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
|
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
|
||||||
) -> NatBotChain:
|
) -> NatBotChain:
|
||||||
"""Load from LLM."""
|
"""Load from LLM."""
|
||||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
llm_chain = PROMPT | llm | StrOutputParser()
|
||||||
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
|
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -104,12 +116,14 @@ class NatBotChain(Chain):
|
|||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
url = inputs[self.input_url_key]
|
url = inputs[self.input_url_key]
|
||||||
browser_content = inputs[self.input_browser_content_key]
|
browser_content = inputs[self.input_browser_content_key]
|
||||||
llm_cmd = self.llm_chain.predict(
|
llm_cmd = self.llm_chain.invoke(
|
||||||
objective=self.objective,
|
{
|
||||||
url=url[:100],
|
"objective": self.objective,
|
||||||
previous_command=self.previous_command,
|
"url": url[:100],
|
||||||
browser_content=browser_content[:4500],
|
"previous_command": self.previous_command,
|
||||||
callbacks=_run_manager.get_child(),
|
"browser_content": browser_content[:4500],
|
||||||
|
},
|
||||||
|
config={"callbacks": _run_manager.get_child()},
|
||||||
)
|
)
|
||||||
llm_cmd = llm_cmd.strip()
|
llm_cmd = llm_cmd.strip()
|
||||||
self.previous_command = llm_cmd
|
self.previous_command = llm_cmd
|
||||||
|
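The NatBotChain refactor above swaps keyword-argument `predict(...)` for a single input dict passed to `invoke`. A sketch of calling the new-style chain directly (assumes `langchain-openai`; the URL and browser content are made-up placeholders):

```python
from langchain.chains.natbot.prompt import PROMPT
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import OpenAI

llm_chain = PROMPT | OpenAI() | StrOutputParser()

# The runnable takes one dict of inputs instead of keyword arguments.
llm_cmd = llm_chain.invoke(
    {
        "objective": "Buy me a new hat.",
        "url": "https://example.com"[:100],
        "previous_command": "",
        "browser_content": "<button>Add to cart</button>"[:4500],
    }
)
print(llm_cmd.strip())
```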
@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
|
|||||||
from langchain_core.callbacks.manager import Callbacks
|
from langchain_core.callbacks.manager import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||||
@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
"""Document compressor that uses an LLM chain to extract
|
"""Document compressor that uses an LLM chain to extract
|
||||||
the relevant parts of documents."""
|
the relevant parts of documents."""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: Runnable
|
||||||
"""LLM wrapper to use for compressing documents."""
|
"""LLM wrapper to use for compressing documents."""
|
||||||
|
|
||||||
get_input: Callable[[str, Document], dict] = default_get_input
|
get_input: Callable[[str, Document], dict] = default_get_input
|
||||||
"""Callable for constructing the chain input from the query and a Document."""
|
"""Callable for constructing the chain input from the query and a Document."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def compress_documents(
|
def compress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
compressed_docs = []
|
compressed_docs = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
_input = self.get_input(query, doc)
|
_input = self.get_input(query, doc)
|
||||||
output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
|
output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
|
||||||
output = output_dict[self.llm_chain.output_key]
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
|
output = output_[self.llm_chain.output_key]
|
||||||
if self.llm_chain.prompt.output_parser is not None:
|
if self.llm_chain.prompt.output_parser is not None:
|
||||||
output = self.llm_chain.prompt.output_parser.parse(output)
|
output = self.llm_chain.prompt.output_parser.parse(output)
|
||||||
|
else:
|
||||||
|
output = output_
|
||||||
if len(output) == 0:
|
if len(output) == 0:
|
||||||
continue
|
continue
|
||||||
compressed_docs.append(
|
compressed_docs.append(
|
||||||
@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
"""Compress page content of raw documents asynchronously."""
|
"""Compress page content of raw documents asynchronously."""
|
||||||
outputs = await asyncio.gather(
|
outputs = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self.llm_chain.apredict_and_parse(
|
self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
|
||||||
**self.get_input(query, doc), callbacks=callbacks
|
|
||||||
)
|
|
||||||
for doc in documents
|
for doc in documents
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||||
_get_input = get_input if get_input is not None else default_get_input
|
_get_input = get_input if get_input is not None else default_get_input
|
||||||
llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
|
if _prompt.output_parser is not None:
|
||||||
|
parser = _prompt.output_parser
|
||||||
|
else:
|
||||||
|
parser = StrOutputParser()
|
||||||
|
llm_chain = _prompt | llm | parser
|
||||||
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
|
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
|
||||||
|
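A usage sketch of LLMChainExtractor after the change (assumes `langchain-openai`; the document and query are made up, echoing the updated unit test):

```python
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

compressor = LLMChainExtractor.from_llm(ChatOpenAI())

docs = [Document(page_content="The Roman Empire followed the Roman Republic.")]
compressed = compressor.compress_documents(docs, query="Who preceded the Roman Empire?")
for doc in compressed:
    print(doc.page_content)
```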
@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
|
|||||||
from langchain_core.callbacks.manager import Callbacks
|
from langchain_core.callbacks.manager import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
|
|||||||
class LLMChainFilter(BaseDocumentCompressor):
|
class LLMChainFilter(BaseDocumentCompressor):
|
||||||
"""Filter that drops documents that aren't relevant to the query."""
|
"""Filter that drops documents that aren't relevant to the query."""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: Runnable
|
||||||
"""LLM wrapper to use for filtering documents.
|
"""LLM wrapper to use for filtering documents.
|
||||||
The chain prompt is expected to have a BooleanOutputParser."""
|
The chain prompt is expected to have a BooleanOutputParser."""
|
||||||
|
|
||||||
get_input: Callable[[str, Document], dict] = default_get_input
|
get_input: Callable[[str, Document], dict] = default_get_input
|
||||||
"""Callable for constructing the chain input from the query and a Document."""
|
"""Callable for constructing the chain input from the query and a Document."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def compress_documents(
|
def compress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
|
|||||||
documents,
|
documents,
|
||||||
)
|
)
|
||||||
|
|
||||||
for output_dict, doc in outputs:
|
for output_, doc in outputs:
|
||||||
include_doc = None
|
include_doc = None
|
||||||
output = output_dict[self.llm_chain.output_key]
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
|
output = output_[self.llm_chain.output_key]
|
||||||
if self.llm_chain.prompt.output_parser is not None:
|
if self.llm_chain.prompt.output_parser is not None:
|
||||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||||
|
else:
|
||||||
|
if isinstance(output_, bool):
|
||||||
|
include_doc = output_
|
||||||
if include_doc:
|
if include_doc:
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
|
|||||||
),
|
),
|
||||||
documents,
|
documents,
|
||||||
)
|
)
|
||||||
for output_dict, doc in outputs:
|
for output_, doc in outputs:
|
||||||
include_doc = None
|
include_doc = None
|
||||||
output = output_dict[self.llm_chain.output_key]
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
|
output = output_[self.llm_chain.output_key]
|
||||||
if self.llm_chain.prompt.output_parser is not None:
|
if self.llm_chain.prompt.output_parser is not None:
|
||||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||||
|
else:
|
||||||
|
if isinstance(output_, bool):
|
||||||
|
include_doc = output_
|
||||||
if include_doc:
|
if include_doc:
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
|
|||||||
A LLMChainFilter that uses the given language model.
|
A LLMChainFilter that uses the given language model.
|
||||||
"""
|
"""
|
||||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
if _prompt.output_parser is not None:
|
||||||
|
parser = _prompt.output_parser
|
||||||
|
else:
|
||||||
|
parser = StrOutputParser()
|
||||||
|
llm_chain = _prompt | llm | parser
|
||||||
return cls(llm_chain=llm_chain, **kwargs)
|
return cls(llm_chain=llm_chain, **kwargs)
|
||||||
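As with the extractor, `LLMChainFilter.from_llm` now composes `prompt | llm | parser`; the default prompt carries a `BooleanOutputParser`, so the chain returns a `bool` and only documents with a truthy result are kept. A minimal sketch under the same assumptions (illustrative model name, made-up documents):

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

from langchain.retrievers.document_compressors import LLMChainFilter

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # illustrative model
doc_filter = LLMChainFilter.from_llm(llm)

docs = [
    Document(page_content="Candlepin bowling is popular in New England."),
    Document(page_content="The moon is round."),
]
kept = doc_filter.compress_documents(docs, "Tell me about Candlepin bowling.")
# Documents judged irrelevant to the query are dropped; relevant ones pass through unchanged.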
@@ -7,11 +7,11 @@ from langchain_core.callbacks import (
 )
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLLM
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
-
-from langchain.chains.llm import LLMChain
+from langchain_core.runnables import Runnable
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever):
     Then, retrieve docs for the re-phrased query."""
 
     retriever: BaseRetriever
-    llm_chain: LLMChain
+    llm_chain: Runnable
 
     @classmethod
     def from_llm(
@@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever):
         Returns:
             RePhraseQueryRetriever
         """
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
-
+        llm_chain = prompt | llm | StrOutputParser()
         return cls(
             retriever=retriever,
             llm_chain=llm_chain,
@@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever):
         Returns:
             Relevant documents for re-phrased question
         """
-        response = self.llm_chain(query, callbacks=run_manager.get_child())
-        re_phrased_question = response["text"]
+        re_phrased_question = self.llm_chain.invoke(
+            query, {"callbacks": run_manager.get_child()}
+        )
         logger.info(f"Re-phrased question: {re_phrased_question}")
         docs = self.retriever.invoke(
             re_phrased_question, config={"callbacks": run_manager.get_child()}
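`RePhraseQueryRetriever.from_llm` likewise now builds `prompt | llm | StrOutputParser()`, so `llm_chain.invoke` returns the re-phrased query as a plain string rather than a dict keyed by "text". A minimal sketch, assuming `langchain-openai` is installed and using `InMemoryVectorStore` purely as an illustrative base retriever:

from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langchain.retrievers import RePhraseQueryRetriever

# Illustrative base retriever; any BaseRetriever works here.
vectorstore = InMemoryVectorStore(OpenAIEmbeddings())
base_retriever = vectorstore.as_retriever()

retriever = RePhraseQueryRetriever.from_llm(
    retriever=base_retriever, llm=ChatOpenAI(temperature=0)
)
docs = retriever.invoke("Tell me about candlepin bowling, please and thank you")
# The default prompt asks the LLM to strip conversational filler from the user
# query; the re-phrased question is then sent to the base retriever.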
@@ -0,0 +1,84 @@
+from langchain_core.documents import Document
+from langchain_core.language_models import FakeListChatModel
+
+from langchain.retrievers.document_compressors import LLMChainExtractor
+
+
+def test_llm_chain_extractor() -> None:
+    documents = [
+        Document(
+            page_content=(
+                "The sky is blue. Candlepin bowling is popular in New England."
+            ),
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content=(
+                "Mercury is the closest planet to the Sun. "
+                "Candlepin bowling balls are smaller."
+            ),
+            metadata={"b": 2},
+        ),
+        Document(page_content="The moon is round.", metadata={"c": 3}),
+    ]
+    llm = FakeListChatModel(
+        responses=[
+            "Candlepin bowling is popular in New England.",
+            "Candlepin bowling balls are smaller.",
+            "NO_OUTPUT",
+        ]
+    )
+    doc_compressor = LLMChainExtractor.from_llm(llm)
+    output = doc_compressor.compress_documents(
+        documents, "Tell me about Candlepin bowling."
+    )
+    expected = documents = [
+        Document(
+            page_content="Candlepin bowling is popular in New England.",
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
+        ),
+    ]
+    assert output == expected
+
+
+async def test_llm_chain_extractor_async() -> None:
+    documents = [
+        Document(
+            page_content=(
+                "The sky is blue. Candlepin bowling is popular in New England."
+            ),
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content=(
+                "Mercury is the closest planet to the Sun. "
+                "Candlepin bowling balls are smaller."
+            ),
+            metadata={"b": 2},
+        ),
+        Document(page_content="The moon is round.", metadata={"c": 3}),
+    ]
+    llm = FakeListChatModel(
+        responses=[
+            "Candlepin bowling is popular in New England.",
+            "Candlepin bowling balls are smaller.",
+            "NO_OUTPUT",
+        ]
+    )
+    doc_compressor = LLMChainExtractor.from_llm(llm)
+    output = await doc_compressor.acompress_documents(
+        documents, "Tell me about Candlepin bowling."
+    )
+    expected = documents = [
+        Document(
+            page_content="Candlepin bowling is popular in New England.",
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
+        ),
+    ]
+    assert output == expected
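The `"NO_OUTPUT"` response in the test above exercises the empty-output path: the default extraction prompt tells the model to answer `NO_OUTPUT` when nothing in the document is relevant, the prompt's output parser maps that sentinel to an empty string, and `compress_documents` then skips the document via its `if len(output) == 0` check. A sketch of that parsing step, assumed to mirror the default parser shipped with the extractor:

from langchain_core.output_parsers import BaseOutputParser


class NoOutputParser(BaseOutputParser[str]):
    """Map the NO_OUTPUT sentinel to an empty string (assumed default behavior)."""

    no_output_str: str = "NO_OUTPUT"

    def parse(self, text: str) -> str:
        cleaned_text = text.strip()
        if cleaned_text == self.no_output_str:
            return ""
        return cleaned_text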
@@ -0,0 +1,46 @@
+from langchain_core.documents import Document
+from langchain_core.language_models import FakeListChatModel
+
+from langchain.retrievers.document_compressors import LLMChainFilter
+
+
+def test_llm_chain_filter() -> None:
+    documents = [
+        Document(
+            page_content="Candlepin bowling is popular in New England.",
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content="Candlepin bowling balls are smaller.",
+            metadata={"b": 2},
+        ),
+        Document(page_content="The moon is round.", metadata={"c": 3}),
+    ]
+    llm = FakeListChatModel(responses=["YES", "YES", "NO"])
+    doc_compressor = LLMChainFilter.from_llm(llm)
+    output = doc_compressor.compress_documents(
+        documents, "Tell me about Candlepin bowling."
+    )
+    expected = documents[:2]
+    assert output == expected
+
+
+async def test_llm_chain_extractor_async() -> None:
+    documents = [
+        Document(
+            page_content="Candlepin bowling is popular in New England.",
+            metadata={"a": 1},
+        ),
+        Document(
+            page_content="Candlepin bowling balls are smaller.",
+            metadata={"b": 2},
+        ),
+        Document(page_content="The moon is round.", metadata={"c": 3}),
+    ]
+    llm = FakeListChatModel(responses=["YES", "YES", "NO"])
+    doc_compressor = LLMChainFilter.from_llm(llm)
+    output = await doc_compressor.acompress_documents(
+        documents, "Tell me about Candlepin bowling."
+    )
+    expected = documents[:2]
+    assert output == expected