diff --git a/docs/docs/versions/migrating_chains/constitutional_chain.ipynb b/docs/docs/versions/migrating_chains/constitutional_chain.ipynb new file mode 100644 index 00000000000..c3729b67c5a --- /dev/null +++ b/docs/docs/versions/migrating_chains/constitutional_chain.ipynb @@ -0,0 +1,332 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b57124cc-60a0-4c18-b7ce-3e483d1024a2", + "metadata": {}, + "source": [ + "---\n", + "title: Migrating from ConstitutionalChain\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ce8457ed-c0b1-4a74-abbd-9d3d2211270f", + "metadata": {}, + "source": [ + "[ConstitutionalChain](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html) allowed an LLM to critique and revise generations based on [principles](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.models.ConstitutionalPrinciple.html), structured as combinations of critique and revision requests. For example, a principle might include a request to identify harmful content, and a request to rewrite the content.\n", + "\n", + "In `ConstitutionalChain`, this structure of critique requests and associated revisions was formatted into an LLM prompt and parsed out of string responses. This is more naturally achieved via [structured output](/docs/how_to/structured_output/) features of chat models. We can construct a simple chain in [LangGraph](https://langchain-ai.github.io/langgraph/) for this purpose. 
Some advantages of this approach include:\n", + "\n", + "- Leverage tool-calling capabilities of chat models that have been fine-tuned for this purpose;\n", + "- Reduce parsing errors from extracting critiques and revisions from a string LLM response;\n", + "- Delegate instructions to [message roles](/docs/concepts/#messages) (e.g., chat models can understand what a `ToolMessage` represents without the need for additional prompting);\n", + "- Support for streaming, both of individual tokens and chain steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b99b47ec", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-openai" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "717c8673", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass()" + ] + }, + { + "cell_type": "markdown", + "id": "e3621b62-a037-42b8-8faa-59575608bb8b", + "metadata": {}, + "source": [ + "## Legacy\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f91c9809-8ee7-4e38-881d-0ace4f6ea883", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import ConstitutionalChain, LLMChain\n", + "from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n", + "from langchain_core.prompts import PromptTemplate\n", + "from langchain_openai import OpenAI\n", + "\n", + "llm = OpenAI()\n", + "\n", + "qa_prompt = PromptTemplate(\n", + " template=\"Q: {question} A:\",\n", + " input_variables=[\"question\"],\n", + ")\n", + "qa_chain = LLMChain(llm=llm, prompt=qa_prompt)\n", + "\n", + "constitutional_chain = ConstitutionalChain.from_llm(\n", + " llm=llm,\n", + " chain=qa_chain,\n", + " constitutional_principles=[\n", + " ConstitutionalPrinciple(\n", + " critique_request=\"Tell if this answer is good.\",\n", + " revision_request=\"Give a better answer.\",\n", + " )\n", + " ],\n", + " return_intermediate_steps=True,\n", + ")\n", + "\n", + "result = constitutional_chain.invoke(\"What is the meaning of life?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fa3d11a1-ac1f-4a9a-9ab3-b7b244daa506", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'question': 'What is the meaning of life?',\n", + " 'output': 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.',\n", + " 'initial_output': ' The meaning of life is a subjective concept that can vary from person to person. Some may believe that the purpose of life is to find happiness and fulfillment, while others may see it as a journey of self-discovery and personal growth. 
Ultimately, the meaning of life is something that each individual must determine for themselves.',\n", + " 'critiques_and_revisions': [('This answer is good in that it recognizes and acknowledges the subjective nature of the question and provides a valid and thoughtful response. However, it could have also mentioned that the meaning of life is a complex and deeply personal concept that can also change and evolve over time for each individual. Critique Needed.',\n", + " 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.')]}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + }, + { + "cell_type": "markdown", + "id": "374ae108-f1a0-4723-9237-5259c8123c04", + "metadata": {}, + "source": [ + "Above, we've returned intermediate steps showing:\n", + "\n", + "- The original question;\n", + "- The initial output;\n", + "- Critiques and revisions;\n", + "- The final output (matching a revision)." + ] + }, + { + "cell_type": "markdown", + "id": "cdc3b527-c09e-4c77-9711-c3cc4506cd95", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## LangGraph\n", + "\n", + "
\n", + "\n", + "Below, we use the [.with_structured_output](/docs/how_to/structured_output/) method to simultaneously generate (1) a judgment of whether a critique is needed, and (2) the critique. We surface all prompts involved for clarity and ease of customizability.\n", + "\n", + "Note that we are also able to stream intermediate steps with this implementation, so we can monitor and if needed intervene during its execution." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "917fdb73-2411-4fcc-9add-c32dc5c745da", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Optional, Tuple\n", + "\n", + "from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n", + "from langchain.chains.constitutional_ai.prompts import (\n", + " CRITIQUE_PROMPT,\n", + " REVISION_PROMPT,\n", + ")\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_openai import ChatOpenAI\n", + "from langgraph.graph import END, START, StateGraph\n", + "from typing_extensions import Annotated, TypedDict\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", + "\n", + "\n", + "class Critique(TypedDict):\n", + " \"\"\"Generate a critique, if needed.\"\"\"\n", + "\n", + " critique_needed: Annotated[bool, ..., \"Whether or not a critique is needed.\"]\n", + " critique: Annotated[str, ..., \"If needed, the critique.\"]\n", + "\n", + "\n", + "critique_prompt = ChatPromptTemplate.from_template(\n", + " \"Critique this response according to the critique request. 
\"\n", + " \"If no critique is needed, specify that.\\n\\n\"\n", + " \"Query: {query}\\n\\n\"\n", + " \"Response: {response}\\n\\n\"\n", + " \"Critique request: {critique_request}\"\n", + ")\n", + "\n", + "revision_prompt = ChatPromptTemplate.from_template(\n", + " \"Revise this response according to the critique and revision request.\\n\\n\"\n", + " \"Query: {query}\\n\\n\"\n", + " \"Response: {response}\\n\\n\"\n", + " \"Critique request: {critique_request}\\n\\n\"\n", + " \"Critique: {critique}\\n\\n\"\n", + " \"If the critique does not identify anything worth changing, ignore the \"\n", + " \"revision request and return 'No revisions needed'. If the critique \"\n", + " \"does identify something worth changing, revise the response based on \"\n", + " \"the revision request.\\n\\n\"\n", + " \"Revision Request: {revision_request}\"\n", + ")\n", + "\n", + "chain = llm | StrOutputParser()\n", + "critique_chain = critique_prompt | llm.with_structured_output(Critique)\n", + "revision_chain = revision_prompt | llm | StrOutputParser()\n", + "\n", + "\n", + "class State(TypedDict):\n", + " query: str\n", + " constitutional_principles: List[ConstitutionalPrinciple]\n", + " initial_response: str\n", + " critiques_and_revisions: List[Tuple[str, str]]\n", + " response: str\n", + "\n", + "\n", + "async def generate_response(state: State):\n", + " \"\"\"Generate initial response.\"\"\"\n", + " response = await chain.ainvoke(state[\"query\"])\n", + " return {\"response\": response, \"initial_response\": response}\n", + "\n", + "\n", + "async def critique_and_revise(state: State):\n", + " \"\"\"Critique and revise response according to principles.\"\"\"\n", + " critiques_and_revisions = []\n", + " response = state[\"initial_response\"]\n", + " for principle in state[\"constitutional_principles\"]:\n", + " critique = await critique_chain.ainvoke(\n", + " {\n", + " \"query\": state[\"query\"],\n", + " \"response\": response,\n", + " \"critique_request\": 
principle.critique_request,\n", + " }\n", + " )\n", + " if critique[\"critique_needed\"]:\n", + " revision = await revision_chain.ainvoke(\n", + " {\n", + " \"query\": state[\"query\"],\n", + " \"response\": response,\n", + " \"critique_request\": principle.critique_request,\n", + " \"critique\": critique[\"critique\"],\n", + " \"revision_request\": principle.revision_request,\n", + " }\n", + " )\n", + " response = revision\n", + " critiques_and_revisions.append((critique[\"critique\"], revision))\n", + " else:\n", + " critiques_and_revisions.append((critique[\"critique\"], \"\"))\n", + " return {\n", + " \"critiques_and_revisions\": critiques_and_revisions,\n", + " \"response\": response,\n", + " }\n", + "\n", + "\n", + "graph = StateGraph(State)\n", + "graph.add_node(\"generate_response\", generate_response)\n", + "graph.add_node(\"critique_and_revise\", critique_and_revise)\n", + "\n", + "graph.add_edge(START, \"generate_response\")\n", + "graph.add_edge(\"generate_response\", \"critique_and_revise\")\n", + "graph.add_edge(\"critique_and_revise\", END)\n", + "app = graph.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "01aac88d-464e-431f-b92e-746dcb743e1b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{}\n", + "{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'response': 'Finding purpose, connection, and joy in our experiences and relationships.'}\n", + "{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'critiques_and_revisions': [(\"The response exceeds the 10-word limit, providing a more elaborate answer than requested. 
A concise response, such as 'To seek purpose and joy in life,' would better align with the query.\", 'To seek purpose and joy in life.')], 'response': 'To seek purpose and joy in life.'}\n" + ] + } + ], + "source": [ + "constitutional_principles = [\n", + " ConstitutionalPrinciple(\n", + " critique_request=\"Tell if this answer is good.\",\n", + " revision_request=\"Give a better answer.\",\n", + " )\n", + "]\n", + "\n", + "query = \"What is the meaning of life? Answer in 10 words or fewer.\"\n", + "\n", + "async for step in app.astream(\n", + " {\"query\": query, \"constitutional_principles\": constitutional_principles},\n", + " stream_mode=\"values\",\n", + "):\n", + " subset = [\"initial_response\", \"critiques_and_revisions\", \"response\"]\n", + " print({k: v for k, v in step.items() if k in subset})" + ] + }, + { + "cell_type": "markdown", + "id": "b2717810", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Next steps\n", + "\n", + "See guides for generating structured output [here](/docs/how_to/structured_output/).\n", + "\n", + "Check out the [LangGraph documentation](https://langchain-ai.github.io/langgraph/) for detail on building with LangGraph." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/versions/migrating_chains/index.mdx b/docs/docs/versions/migrating_chains/index.mdx index 4f809972e0a..69f6e24c8ef 100644 --- a/docs/docs/versions/migrating_chains/index.mdx +++ b/docs/docs/versions/migrating_chains/index.mdx @@ -45,5 +45,7 @@ The below pages assist with migration from various specific chains to LCEL and L - [RefineDocumentsChain](/docs/versions/migrating_chains/refine_docs_chain) - [LLMRouterChain](/docs/versions/migrating_chains/llm_router_chain) - [MultiPromptChain](/docs/versions/migrating_chains/multi_prompt_chain) +- [LLMMathChain](/docs/versions/migrating_chains/llm_math_chain) +- [ConstitutionalChain](/docs/versions/migrating_chains/constitutional_chain) Check out the [LCEL conceptual docs](/docs/concepts/#langchain-expression-language-lcel) and [LangGraph docs](https://langchain-ai.github.io/langgraph/) for more background information. 
\ No newline at end of file diff --git a/docs/docs/versions/migrating_chains/llm_math_chain.ipynb b/docs/docs/versions/migrating_chains/llm_math_chain.ipynb new file mode 100644 index 00000000000..87f2511085e --- /dev/null +++ b/docs/docs/versions/migrating_chains/llm_math_chain.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b57124cc-60a0-4c18-b7ce-3e483d1024a2", + "metadata": {}, + "source": [ + "---\n", + "title: Migrating from LLMMathChain\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ce8457ed-c0b1-4a74-abbd-9d3d2211270f", + "metadata": {}, + "source": [ + "[`LLMMathChain`](https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html) enabled the evaluation of mathematical expressions generated by an LLM. Instructions for generating the expressions were formatted into the prompt, and the expressions were parsed out of the string response before evaluation using the [numexpr](https://numexpr.readthedocs.io/en/latest/user_guide.html) library.\n", + "\n", + "This is more naturally achieved via [tool calling](/docs/concepts/#functiontool-calling). We can equip a chat model with a simple calculator tool leveraging `numexpr` and construct a simple chain around it using [LangGraph](https://langchain-ai.github.io/langgraph/). Some advantages of this approach include:\n", + "\n", + "- Leverage tool-calling capabilities of chat models that have been fine-tuned for this purpose;\n", + "- Reduce parsing errors from extracting expressions from a string LLM response;\n", + "- Delegation of instructions to [message roles](/docs/concepts/#messages) (e.g., chat models can understand what a `ToolMessage` represents without the need for additional prompting);\n", + "- Support for streaming, both of individual tokens and chain steps." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b99b47ec", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet numexpr" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "717c8673", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass()" + ] + }, + { + "cell_type": "markdown", + "id": "e3621b62-a037-42b8-8faa-59575608bb8b", + "metadata": {}, + "source": [ + "## Legacy\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f91c9809-8ee7-4e38-881d-0ace4f6ea883", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'question': 'What is 551368 divided by 82?', 'answer': 'Answer: 6724.0'}" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.chains import LLMMathChain\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", + "\n", + "chain = LLMMathChain.from_llm(llm)\n", + "\n", + "chain.invoke(\"What is 551368 divided by 82?\")" + ] + }, + { + "cell_type": "markdown", + "id": "cdc3b527-c09e-4c77-9711-c3cc4506cd95", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## LangGraph\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f0903025-9aa8-4a53-8336-074341c00e59", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from typing import Annotated, Sequence\n", + "\n", + "import numexpr\n", + "from langchain_core.messages import BaseMessage\n", + "from langchain_core.runnables import RunnableConfig\n", + "from langchain_core.tools import tool\n", + "from langchain_openai import ChatOpenAI\n", + "from langgraph.graph import END, StateGraph\n", + "from langgraph.graph.message import add_messages\n", + "from langgraph.prebuilt.tool_node import ToolNode\n", + "from typing_extensions import TypedDict\n", + "\n", + "\n", + "@tool\n", + "def calculator(expression: str) -> str:\n", + " \"\"\"Calculate expression using Python's numexpr library.\n", + "\n", + " Expression should be a single line mathematical expression\n", + " that solves the problem.\n", + "\n", + " Examples:\n", + " \"37593 * 67\" for \"37593 times 67\"\n", + " \"37593**(1/5)\" for \"37593^(1/5)\"\n", + " \"\"\"\n", + " local_dict = {\"pi\": math.pi, \"e\": math.e}\n", + " return str(\n", + " numexpr.evaluate(\n", + " expression.strip(),\n", + " global_dict={}, # restrict access to globals\n", + " local_dict=local_dict, # add common mathematical functions\n", + " )\n", + " )\n", + "\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n", + "tools = [calculator]\n", + "llm_with_tools = llm.bind_tools(tools, tool_choice=\"any\")\n", + "\n", + "\n", + "class ChainState(TypedDict):\n", + " \"\"\"LangGraph state.\"\"\"\n", + "\n", + " messages: Annotated[Sequence[BaseMessage], add_messages]\n", + "\n", + "\n", + "async def acall_chain(state: ChainState, config: RunnableConfig):\n", + " last_message = state[\"messages\"][-1]\n", + " response = await llm_with_tools.ainvoke(state[\"messages\"], config)\n", + " return {\"messages\": [response]}\n", + "\n", + "\n", + "async def acall_model(state: ChainState, config: 
RunnableConfig):\n", + " response = await llm.ainvoke(state[\"messages\"], config)\n", + " return {\"messages\": [response]}\n", + "\n", + "\n", + "graph_builder = StateGraph(ChainState)\n", + "graph_builder.add_node(\"call_tool\", acall_chain)\n", + "graph_builder.add_node(\"execute_tool\", ToolNode(tools))\n", + "graph_builder.add_node(\"call_model\", acall_model)\n", + "graph_builder.set_entry_point(\"call_tool\")\n", + "graph_builder.add_edge(\"call_tool\", \"execute_tool\")\n", + "graph_builder.add_edge(\"execute_tool\", \"call_model\")\n", + "graph_builder.add_edge(\"call_model\", END)\n", + "chain = graph_builder.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d0a8a81a-328b-497d-956b-4d16b2efea0e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/jpeg": "", + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Visualize chain:\n", + "\n", + "from IPython.display import Image\n", + "\n", + "Image(chain.get_graph().draw_mermaid_png())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3ea1d71f-e31d-4722-be39-9a2b16d72f5f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "What is 551368 divided by 82\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "Tool Calls:\n", + " calculator (call_1ic3gjuII0Aq9vxlSYiwvjSb)\n", + " Call ID: call_1ic3gjuII0Aq9vxlSYiwvjSb\n", + " Args:\n", + " expression: 551368 / 82\n", + "=================================\u001b[1m Tool Message \u001b[0m=================================\n", + "Name: calculator\n", + "\n", + "6724.0\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "\n", + "551368 divided by 82 equals 
6724.\n" + ] + } + ], + "source": [ + "# Stream chain steps:\n", + "\n", + "example_query = \"What is 551368 divided by 82\"\n", + "\n", + "events = chain.astream(\n", + " {\"messages\": [(\"user\", example_query)]},\n", + " stream_mode=\"values\",\n", + ")\n", + "async for event in events:\n", + " event[\"messages\"][-1].pretty_print()" + ] + }, + { + "cell_type": "markdown", + "id": "b2717810", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Next steps\n", + "\n", + "See guides for building and working with tools [here](/docs/how_to/#tools).\n", + "\n", + "Check out the [LangGraph documentation](https://langchain-ai.github.io/langgraph/) for detail on building with LangGraph." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/chains/natbot/__init__.py b/libs/community/langchain_community/chains/natbot/__init__.py new file mode 100644 index 00000000000..aeec86f8bf2 --- /dev/null +++ b/libs/community/langchain_community/chains/natbot/__init__.py @@ -0,0 +1,8 @@ +"""Implement a GPT-3 driven browser. 
+ +Heavily influenced from https://github.com/nat/natbot +""" + +from langchain_community.chains.natbot.base import NatBotChain + +__all__ = ["NatBotChain"] diff --git a/libs/community/langchain_community/chains/natbot/base.py b/libs/community/langchain_community/chains/natbot/base.py new file mode 100644 index 00000000000..7cb575e90e4 --- /dev/null +++ b/libs/community/langchain_community/chains/natbot/base.py @@ -0,0 +1,3 @@ +from langchain.chains import NatBotChain + +__all__ = ["NatBotChain"] diff --git a/libs/community/langchain_community/chains/natbot/crawler.py b/libs/community/langchain_community/chains/natbot/crawler.py new file mode 100644 index 00000000000..5c2c7657127 --- /dev/null +++ b/libs/community/langchain_community/chains/natbot/crawler.py @@ -0,0 +1,7 @@ +from langchain.chains.natbot.crawler import ( + Crawler, + ElementInViewPort, + black_listed_elements, +) + +__all__ = ["ElementInViewPort", "Crawler", "black_listed_elements"] diff --git a/libs/community/langchain_community/chains/natbot/prompt.py b/libs/community/langchain_community/chains/natbot/prompt.py new file mode 100644 index 00000000000..0ea63d5bbe2 --- /dev/null +++ b/libs/community/langchain_community/chains/natbot/prompt.py @@ -0,0 +1,3 @@ +from langchain.chains.natbot.prompt import PROMPT + +__all__ = ["PROMPT"] diff --git a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py index ded7e5149be..aa167487172 100644 --- a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py @@ -6,14 +6,6 @@ from langchain_core.documents import Document from langchain_community.chat_models import ChatOpenAI -def test_llm_construction_with_kwargs() -> None: - llm_chain_kwargs = {"verbose": True} - compressor = 
LLMChainExtractor.from_llm( - ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs - ) - assert compressor.llm_chain.verbose is True - - def test_llm_chain_extractor() -> None: texts = [ "The Roman Empire followed the Roman Republic.", diff --git a/libs/langchain/tests/unit_tests/chains/test_natbot.py b/libs/community/tests/unit_tests/chains/test_natbot.py similarity index 99% rename from libs/langchain/tests/unit_tests/chains/test_natbot.py rename to libs/community/tests/unit_tests/chains/test_natbot.py index 3f1f79da2e8..2b85ebbc209 100644 --- a/libs/langchain/tests/unit_tests/chains/test_natbot.py +++ b/libs/community/tests/unit_tests/chains/test_natbot.py @@ -2,11 +2,10 @@ from typing import Any, Dict, List, Optional +from langchain.chains.natbot.base import NatBotChain from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain.chains.natbot.base import NatBotChain - class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" diff --git a/libs/langchain/langchain/chains/constitutional_ai/base.py b/libs/langchain/langchain/chains/constitutional_ai/base.py index bd86b57ed27..a095bf4047d 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/base.py +++ b/libs/langchain/langchain/chains/constitutional_ai/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional +from langchain_core._api import deprecated from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate @@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION from langchain.chains.llm import LLMChain +@deprecated( + since="0.2.13", + message=( + "This class is deprecated and will be removed in langchain 1.0. 
" + "See API reference for replacement: " + "https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501 + ), + removal="1.0", +) class ConstitutionalChain(Chain): """Chain for applying constitutional principles. + Note: this class is deprecated. See below for a replacement implementation + using LangGraph. The benefits of this implementation are: + + - Uses LLM tool calling features instead of parsing string responses; + - Support for both token-by-token and step-by-step streaming; + - Support for checkpointing and memory of chat history; + - Easier to modify or extend (e.g., with additional tools, structured responses, etc.) + + Install LangGraph with: + + .. code-block:: bash + + pip install -U langgraph + + .. code-block:: python + + from typing import List, Optional, Tuple + + from langchain.chains.constitutional_ai.prompts import ( + CRITIQUE_PROMPT, + REVISION_PROMPT, + ) + from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple + from langchain_core.output_parsers import StrOutputParser + from langchain_core.prompts import ChatPromptTemplate + from langchain_openai import ChatOpenAI + from langgraph.graph import END, START, StateGraph + from typing_extensions import Annotated, TypedDict + + llm = ChatOpenAI(model="gpt-4o-mini") + + class Critique(TypedDict): + \"\"\"Generate a critique, if needed.\"\"\" + critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."] + critique: Annotated[str, ..., "If needed, the critique."] + + critique_prompt = ChatPromptTemplate.from_template( + "Critique this response according to the critique request. 
" + "If no critique is needed, specify that.\\n\\n" + "Query: {query}\\n\\n" + "Response: {response}\\n\\n" + "Critique request: {critique_request}" + ) + + revision_prompt = ChatPromptTemplate.from_template( + "Revise this response according to the critique and revision request.\\n\\n" + "Query: {query}\\n\\n" + "Response: {response}\\n\\n" + "Critique request: {critique_request}\\n\\n" + "Critique: {critique}\\n\\n" + "If the critique does not identify anything worth changing, ignore the " + "revision request and return 'No revisions needed'. If the critique " + "does identify something worth changing, revise the response based on " + "the revision request.\\n\\n" + "Revision Request: {revision_request}" + ) + + chain = llm | StrOutputParser() + critique_chain = critique_prompt | llm.with_structured_output(Critique) + revision_chain = revision_prompt | llm | StrOutputParser() + + + class State(TypedDict): + query: str + constitutional_principles: List[ConstitutionalPrinciple] + initial_response: str + critiques_and_revisions: List[Tuple[str, str]] + response: str + + + async def generate_response(state: State): + \"\"\"Generate initial response.\"\"\" + response = await chain.ainvoke(state["query"]) + return {"response": response, "initial_response": response} + + async def critique_and_revise(state: State): + \"\"\"Critique and revise response according to principles.\"\"\" + critiques_and_revisions = [] + response = state["initial_response"] + for principle in state["constitutional_principles"]: + critique = await critique_chain.ainvoke( + { + "query": state["query"], + "response": response, + "critique_request": principle.critique_request, + } + ) + if critique["critique_needed"]: + revision = await revision_chain.ainvoke( + { + "query": state["query"], + "response": response, + "critique_request": principle.critique_request, + "critique": critique["critique"], + "revision_request": principle.revision_request, + } + ) + response = revision + 
critiques_and_revisions.append((critique["critique"], revision)) + else: + critiques_and_revisions.append((critique["critique"], "")) + return { + "critiques_and_revisions": critiques_and_revisions, + "response": response, + } + + graph = StateGraph(State) + graph.add_node("generate_response", generate_response) + graph.add_node("critique_and_revise", critique_and_revise) + + graph.add_edge(START, "generate_response") + graph.add_edge("generate_response", "critique_and_revise") + graph.add_edge("critique_and_revise", END) + app = graph.compile() + + .. code-block:: python + + constitutional_principles=[ + ConstitutionalPrinciple( + critique_request="Tell if this answer is good.", + revision_request="Give a better answer.", + ) + ] + + query = "What is the meaning of life? Answer in 10 words or fewer." + + async for step in app.astream( + {"query": query, "constitutional_principles": constitutional_principles}, + stream_mode="values", + ): + subset = ["initial_response", "critiques_and_revisions", "response"] + print({k: v for k, v in step.items() if k in subset}) + Example: .. 
code-block:: python @@ -44,7 +187,7 @@ class ConstitutionalChain(Chain): ) constitutional_chain.run(question="What is the meaning of life?") - """ + """ # noqa: E501 chain: LLMChain constitutional_principles: List[ConstitutionalPrinciple] diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 8a100ed0595..1d55bed468b 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -1,7 +1,6 @@ from __future__ import annotations import re -from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np @@ -9,10 +8,12 @@ from langchain_core.callbacks import ( CallbackManagerForChainRun, ) from langchain_core.language_models import BaseLanguageModel -from langchain_core.outputs import Generation +from langchain_core.messages import AIMessage +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables import Runnable from langchain.chains.base import Chain from langchain.chains.flare.prompts import ( @@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import ( from langchain.chains.llm import LLMChain -class _ResponseChain(LLMChain): - """Base class for chains that generate responses.""" - - prompt: BasePromptTemplate = PROMPT - - @classmethod - def is_lc_serializable(cls) -> bool: - return False - - @property - def input_keys(self) -> List[str]: - return self.prompt.input_variables - - def generate_tokens_and_log_probs( - self, - _input: Dict[str, Any], - *, - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Tuple[Sequence[str], Sequence[float]]: - llm_result = self.generate([_input], run_manager=run_manager) - return self._extract_tokens_and_log_probs(llm_result.generations[0]) - - @abstractmethod - def 
_extract_tokens_and_log_probs( - self, generations: List[Generation] - ) -> Tuple[Sequence[str], Sequence[float]]: - """Extract tokens and log probs from response.""" - - -class _OpenAIResponseChain(_ResponseChain): - """Chain that generates responses from user input and context.""" - - llm: BaseLanguageModel - - def _extract_tokens_and_log_probs( - self, generations: List[Generation] - ) -> Tuple[Sequence[str], Sequence[float]]: - tokens = [] - log_probs = [] - for gen in generations: - if gen.generation_info is None: - raise ValueError - tokens.extend(gen.generation_info["logprobs"]["tokens"]) - log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"]) - return tokens, log_probs +def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]: + """Extract tokens and log probabilities from chat model response.""" + tokens = [] + log_probs = [] + for token in response.response_metadata["logprobs"]["content"]: + tokens.append(token["token"]) + log_probs.append(token["logprob"]) + return tokens, log_probs class QuestionGeneratorChain(LLMChain): @@ -111,9 +75,9 @@ class FlareChain(Chain): """Chain that combines a retriever, a question generator, and a response generator.""" - question_generator_chain: QuestionGeneratorChain + question_generator_chain: Runnable """Chain that generates questions from uncertain spans.""" - response_chain: _ResponseChain + response_chain: Runnable """Chain that generates responses from user input and context.""" output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser) """Parser that determines whether the chain is finished.""" @@ -152,12 +116,16 @@ class FlareChain(Chain): for question in questions: docs.extend(self.retriever.invoke(question)) context = "\n\n".join(d.page_content for d in docs) - result = self.response_chain.predict( - user_input=user_input, - context=context, - response=response, - callbacks=callbacks, + result = self.response_chain.invoke( + { + "user_input": 
user_input, + "context": context, + "response": response, + }, + {"callbacks": callbacks}, ) + if isinstance(result, AIMessage): + result = result.content marginal, finished = self.output_parser.parse(result) return marginal, finished @@ -178,13 +146,18 @@ class FlareChain(Chain): for span in low_confidence_spans ] callbacks = _run_manager.get_child() - question_gen_outputs = self.question_generator_chain.apply( - question_gen_inputs, callbacks=callbacks - ) - questions = [ - output[self.question_generator_chain.output_keys[0]] - for output in question_gen_outputs - ] + if isinstance(self.question_generator_chain, LLMChain): + question_gen_outputs = self.question_generator_chain.apply( + question_gen_inputs, callbacks=callbacks + ) + questions = [ + output[self.question_generator_chain.output_keys[0]] + for output in question_gen_outputs + ] + else: + questions = self.question_generator_chain.batch( + question_gen_inputs, config={"callbacks": callbacks} + ) _run_manager.on_text( f"Generated Questions: {questions}", color="yellow", end="\n" ) @@ -206,8 +179,10 @@ class FlareChain(Chain): f"Current Response: {response}", color="blue", end="\n" ) _input = {"user_input": user_input, "context": "", "response": response} - tokens, log_probs = self.response_chain.generate_tokens_and_log_probs( - _input, run_manager=_run_manager + tokens, log_probs = _extract_tokens_and_log_probs( + self.response_chain.invoke( + _input, {"callbacks": _run_manager.get_child()} + ) ) low_confidence_spans = _low_confidence_spans( tokens, @@ -251,18 +226,16 @@ class FlareChain(Chain): FlareChain class with the given language model. """ try: - from langchain_openai import OpenAI + from langchain_openai import ChatOpenAI except ImportError: raise ImportError( "OpenAI is required for FlareChain. " "Please install langchain-openai." 
"pip install langchain-openai" ) - question_gen_chain = QuestionGeneratorChain(llm=llm) - response_llm = OpenAI( - max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0 - ) - response_chain = _OpenAIResponseChain(llm=response_llm) + llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0) + response_chain = PROMPT | llm + question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser() return cls( question_generator_chain=question_gen_chain, response_chain=response_chain, diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index 851e76c1599..833999127b6 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -11,7 +11,9 @@ import numpy as np from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate +from langchain_core.runnables import Runnable from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP @@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): """ base_embeddings: Embeddings - llm_chain: LLMChain + llm_chain: Runnable class Config: arbitrary_types_allowed = True @@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): @property def input_keys(self) -> List[str]: """Input keys for Hyde's LLM chain.""" - return self.llm_chain.input_keys + return self.llm_chain.input_schema.schema()["required"] @property def output_keys(self) -> List[str]: """Output keys for Hyde's LLM chain.""" - return self.llm_chain.output_keys + if isinstance(self.llm_chain, LLMChain): + return self.llm_chain.output_keys + else: + return ["text"] def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call the base 
embeddings.""" @@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): def embed_query(self, text: str) -> List[float]: """Generate a hypothetical document and embedded it.""" - var_name = self.llm_chain.input_keys[0] - result = self.llm_chain.generate([{var_name: text}]) - documents = [generation.text for generation in result.generations[0]] + var_name = self.input_keys[0] + result = self.llm_chain.invoke({var_name: text}) + if isinstance(self.llm_chain, LLMChain): + documents = [result[self.output_keys[0]]] + else: + documents = [result] embeddings = self.embed_documents(documents) return self.combine_embeddings(embeddings) @@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): ) -> Dict[str, str]: """Call the internal llm chain.""" _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - return self.llm_chain(inputs, callbacks=_run_manager.get_child()) + return self.llm_chain.invoke( + inputs, config={"callbacks": _run_manager.get_child()} + ) @classmethod def from_llm( @@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): f"of {list(PROMPT_MAP.keys())}." 
) - llm_chain = LLMChain(llm=llm, prompt=prompt) + llm_chain = prompt | llm | StrOutputParser() return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs) @property diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index 0733b0079b3..e7fd89dcd54 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -7,6 +7,7 @@ import re import warnings from typing import Any, Dict, List, Optional +from langchain_core._api import deprecated from langchain_core.callbacks import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT +@deprecated( + since="0.2.13", + message=( + "This class is deprecated and will be removed in langchain 1.0. " + "See API reference for replacement: " + "https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501 + ), + removal="1.0", +) class LLMMathChain(Chain): """Chain that interprets a prompt and executes python code to do math. + Note: this class is deprecated. See below for a replacement implementation + using LangGraph. The benefits of this implementation are: + + - Uses LLM tool calling features; + - Support for both token-by-token and step-by-step streaming; + - Support for checkpointing and memory of chat history; + - Easier to modify or extend (e.g., with additional tools, structured responses, etc.) + + Install LangGraph with: + + .. code-block:: bash + + pip install -U langgraph + + .. 
code-block:: python + + import math + from typing import Annotated, Sequence + + from langchain_core.messages import BaseMessage + from langchain_core.runnables import RunnableConfig + from langchain_core.tools import tool + from langchain_openai import ChatOpenAI + from langgraph.graph import END, StateGraph + from langgraph.graph.message import add_messages + from langgraph.prebuilt.tool_node import ToolNode + import numexpr + from typing_extensions import TypedDict + + @tool + def calculator(expression: str) -> str: + \"\"\"Calculate expression using Python's numexpr library. + + Expression should be a single line mathematical expression + that solves the problem. + + Examples: + "37593 * 67" for "37593 times 67" + "37593**(1/5)" for "37593^(1/5)" + \"\"\" + local_dict = {"pi": math.pi, "e": math.e} + return str( + numexpr.evaluate( + expression.strip(), + global_dict={}, # restrict access to globals + local_dict=local_dict, # add common mathematical functions + ) + ) + + llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) + tools = [calculator] + llm_with_tools = llm.bind_tools(tools, tool_choice="any") + + class ChainState(TypedDict): + \"\"\"LangGraph state.\"\"\" + + messages: Annotated[Sequence[BaseMessage], add_messages] + + async def acall_chain(state: ChainState, config: RunnableConfig): + last_message = state["messages"][-1] + response = await llm_with_tools.ainvoke(state["messages"], config) + return {"messages": [response]} + + async def acall_model(state: ChainState, config: RunnableConfig): + response = await llm.ainvoke(state["messages"], config) + return {"messages": [response]} + + graph_builder = StateGraph(ChainState) + graph_builder.add_node("call_tool", acall_chain) + graph_builder.add_node("execute_tool", ToolNode(tools)) + graph_builder.add_node("call_model", acall_model) + graph_builder.set_entry_point("call_tool") + graph_builder.add_edge("call_tool", "execute_tool") + graph_builder.add_edge("execute_tool", "call_model") + 
graph_builder.add_edge("call_model", END) + chain = graph_builder.compile() + + .. code-block:: python + + example_query = "What is 551368 divided by 82" + + events = chain.astream( + {"messages": [("user", example_query)]}, + stream_mode="values", + ) + async for event in events: + event["messages"][-1].pretty_print() + + .. code-block:: none + + ================================ Human Message ================================= + + What is 551368 divided by 82 + ================================== Ai Message ================================== + Tool Calls: + calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS) + Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS + Args: + expression: 551368 / 82 + ================================= Tool Message ================================= + Name: calculator + + 6724.0 + ================================== Ai Message ================================== + + 551368 divided by 82 equals 6724. + Example: .. code-block:: python from langchain.chains import LLMMathChain from langchain_community.llms import OpenAI llm_math = LLMMathChain.from_llm(OpenAI()) - """ + """ # noqa: E501 llm_chain: LLMChain llm: Optional[BaseLanguageModel] = None diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index 910e03f7d4f..e92131ff35c 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -5,15 +5,27 @@ from __future__ import annotations import warnings from typing import Any, Dict, List, Optional +from langchain_core._api import deprecated from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import StrOutputParser from langchain_core.pydantic_v1 import root_validator +from langchain_core.runnables import Runnable from langchain.chains.base import Chain -from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT 
+@deprecated( + since="0.2.13", + message=( + "Importing NatBotChain from langchain is deprecated and will be removed in " + "langchain 1.0. Please import from langchain_community instead: " + "from langchain_community.chains.natbot import NatBotChain. " + "You may need to pip install -U langchain-community." + ), + removal="1.0", +) class NatBotChain(Chain): """Implement an LLM driven browser. @@ -37,7 +49,7 @@ class NatBotChain(Chain): natbot = NatBotChain.from_default("Buy me a new hat.") """ - llm_chain: LLMChain + llm_chain: Runnable objective: str """Objective that NatBot is tasked with completing.""" llm: Optional[BaseLanguageModel] = None @@ -60,7 +72,7 @@ class NatBotChain(Chain): "class method." ) if "llm_chain" not in values and values["llm"] is not None: - values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT) + values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser() return values @classmethod @@ -77,7 +89,7 @@ class NatBotChain(Chain): cls, llm: BaseLanguageModel, objective: str, **kwargs: Any ) -> NatBotChain: """Load from LLM.""" - llm_chain = LLMChain(llm=llm, prompt=PROMPT) + llm_chain = PROMPT | llm | StrOutputParser() return cls(llm_chain=llm_chain, objective=objective, **kwargs) @property @@ -104,12 +116,14 @@ class NatBotChain(Chain): _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() url = inputs[self.input_url_key] browser_content = inputs[self.input_browser_content_key] - llm_cmd = self.llm_chain.predict( - objective=self.objective, - url=url[:100], - previous_command=self.previous_command, - browser_content=browser_content[:4500], - callbacks=_run_manager.get_child(), + llm_cmd = self.llm_chain.invoke( + { + "objective": self.objective, + "url": url[:100], + "previous_command": self.previous_command, + "browser_content": browser_content[:4500], + }, + config={"callbacks": _run_manager.get_child()}, ) llm_cmd = llm_cmd.strip() self.previous_command = llm_cmd diff --git 
a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index 95a56677cc4..cc86f2be49b 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from langchain_core.output_parsers import BaseOutputParser +from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import Runnable from langchain.chains.llm import LLMChain from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor): """Document compressor that uses an LLM chain to extract the relevant parts of documents.""" - llm_chain: LLMChain + llm_chain: Runnable """LLM wrapper to use for compressing documents.""" get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" + class Config: + arbitrary_types_allowed = True + def compress_documents( self, documents: Sequence[Document], @@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor): compressed_docs = [] for doc in documents: _input = self.get_input(query, doc) - output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks}) - output = output_dict[self.llm_chain.output_key] - if self.llm_chain.prompt.output_parser is not None: - output = self.llm_chain.prompt.output_parser.parse(output) + output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks}) + if isinstance(self.llm_chain, LLMChain): + output = 
output_[self.llm_chain.output_key] + if self.llm_chain.prompt.output_parser is not None: + output = self.llm_chain.prompt.output_parser.parse(output) + else: + output = output_ if len(output) == 0: continue compressed_docs.append( @@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor): """Compress page content of raw documents asynchronously.""" outputs = await asyncio.gather( *[ - self.llm_chain.apredict_and_parse( - **self.get_input(query, doc), callbacks=callbacks - ) + self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks) for doc in documents ] ) @@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor): """Initialize from LLM.""" _prompt = prompt if prompt is not None else _get_default_chain_prompt() _get_input = get_input if get_input is not None else default_get_input - llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {})) + if _prompt.output_parser is not None: + parser = _prompt.output_parser + else: + parser = StrOutputParser() + llm_chain = _prompt | llm | parser return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type] diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 1efaef7abf0..2db6f5be3a7 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.runnables import Runnable from langchain_core.runnables.config import RunnableConfig from langchain.chains import LLMChain @@ 
-32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]: class LLMChainFilter(BaseDocumentCompressor): """Filter that drops documents that aren't relevant to the query.""" - llm_chain: LLMChain + llm_chain: Runnable """LLM wrapper to use for filtering documents. The chain prompt is expected to have a BooleanOutputParser.""" get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" + class Config: + arbitrary_types_allowed = True + def compress_documents( self, documents: Sequence[Document], @@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor): documents, ) - for output_dict, doc in outputs: + for output_, doc in outputs: include_doc = None - output = output_dict[self.llm_chain.output_key] - if self.llm_chain.prompt.output_parser is not None: - include_doc = self.llm_chain.prompt.output_parser.parse(output) + if isinstance(self.llm_chain, LLMChain): + output = output_[self.llm_chain.output_key] + if self.llm_chain.prompt.output_parser is not None: + include_doc = self.llm_chain.prompt.output_parser.parse(output) + else: + if isinstance(output_, bool): + include_doc = output_ if include_doc: filtered_docs.append(doc) @@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor): ), documents, ) - for output_dict, doc in outputs: + for output_, doc in outputs: include_doc = None - output = output_dict[self.llm_chain.output_key] - if self.llm_chain.prompt.output_parser is not None: - include_doc = self.llm_chain.prompt.output_parser.parse(output) + if isinstance(self.llm_chain, LLMChain): + output = output_[self.llm_chain.output_key] + if self.llm_chain.prompt.output_parser is not None: + include_doc = self.llm_chain.prompt.output_parser.parse(output) + else: + if isinstance(output_, bool): + include_doc = output_ if include_doc: filtered_docs.append(doc) @@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor): A LLMChainFilter that 
uses the given language model. """ _prompt = prompt if prompt is not None else _get_default_chain_prompt() - llm_chain = LLMChain(llm=llm, prompt=_prompt) + if _prompt.output_parser is not None: + parser = _prompt.output_parser + else: + parser = StrOutputParser() + llm_chain = _prompt | llm | parser return cls(llm_chain=llm_chain, **kwargs) diff --git a/libs/langchain/langchain/retrievers/re_phraser.py b/libs/langchain/langchain/retrievers/re_phraser.py index 5fdc47d10f2..55cb054e997 100644 --- a/libs/langchain/langchain/retrievers/re_phraser.py +++ b/libs/langchain/langchain/retrievers/re_phraser.py @@ -7,11 +7,11 @@ from langchain_core.callbacks import ( ) from langchain_core.documents import Document from langchain_core.language_models import BaseLLM +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.retrievers import BaseRetriever - -from langchain.chains.llm import LLMChain +from langchain_core.runnables import Runnable logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever): Then, retrieve docs for the re-phrased query.""" retriever: BaseRetriever - llm_chain: LLMChain + llm_chain: Runnable @classmethod def from_llm( @@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever): Returns: RePhraseQueryRetriever """ - - llm_chain = LLMChain(llm=llm, prompt=prompt) + llm_chain = prompt | llm | StrOutputParser() return cls( retriever=retriever, llm_chain=llm_chain, @@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever): Returns: Relevant documents for re-phrased question """ - response = self.llm_chain(query, callbacks=run_manager.get_child()) - re_phrased_question = response["text"] + re_phrased_question = self.llm_chain.invoke( + query, {"callbacks": run_manager.get_child()} + ) logger.info(f"Re-phrased question: {re_phrased_question}") docs = self.retriever.invoke( 
re_phrased_question, config={"callbacks": run_manager.get_child()} diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py new file mode 100644 index 00000000000..1e4afed1eec --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py @@ -0,0 +1,84 @@ +from langchain_core.documents import Document +from langchain_core.language_models import FakeListChatModel + +from langchain.retrievers.document_compressors import LLMChainExtractor + + +def test_llm_chain_extractor() -> None: + documents = [ + Document( + page_content=( + "The sky is blue. Candlepin bowling is popular in New England." + ), + metadata={"a": 1}, + ), + Document( + page_content=( + "Mercury is the closest planet to the Sun. " + "Candlepin bowling balls are smaller." + ), + metadata={"b": 2}, + ), + Document(page_content="The moon is round.", metadata={"c": 3}), + ] + llm = FakeListChatModel( + responses=[ + "Candlepin bowling is popular in New England.", + "Candlepin bowling balls are smaller.", + "NO_OUTPUT", + ] + ) + doc_compressor = LLMChainExtractor.from_llm(llm) + output = doc_compressor.compress_documents( + documents, "Tell me about Candlepin bowling." + ) + expected = documents = [ + Document( + page_content="Candlepin bowling is popular in New England.", + metadata={"a": 1}, + ), + Document( + page_content="Candlepin bowling balls are smaller.", metadata={"b": 2} + ), + ] + assert output == expected + + +async def test_llm_chain_extractor_async() -> None: + documents = [ + Document( + page_content=( + "The sky is blue. Candlepin bowling is popular in New England." + ), + metadata={"a": 1}, + ), + Document( + page_content=( + "Mercury is the closest planet to the Sun. " + "Candlepin bowling balls are smaller." 
+ ), + metadata={"b": 2}, + ), + Document(page_content="The moon is round.", metadata={"c": 3}), + ] + llm = FakeListChatModel( + responses=[ + "Candlepin bowling is popular in New England.", + "Candlepin bowling balls are smaller.", + "NO_OUTPUT", + ] + ) + doc_compressor = LLMChainExtractor.from_llm(llm) + output = await doc_compressor.acompress_documents( + documents, "Tell me about Candlepin bowling." + ) + expected = documents = [ + Document( + page_content="Candlepin bowling is popular in New England.", + metadata={"a": 1}, + ), + Document( + page_content="Candlepin bowling balls are smaller.", metadata={"b": 2} + ), + ] + assert output == expected diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py new file mode 100644 index 00000000000..4020694afa6 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py @@ -0,0 +1,46 @@ +from langchain_core.documents import Document +from langchain_core.language_models import FakeListChatModel + +from langchain.retrievers.document_compressors import LLMChainFilter + + +def test_llm_chain_filter() -> None: + documents = [ + Document( + page_content="Candlepin bowling is popular in New England.", + metadata={"a": 1}, + ), + Document( + page_content="Candlepin bowling balls are smaller.", + metadata={"b": 2}, + ), + Document(page_content="The moon is round.", metadata={"c": 3}), + ] + llm = FakeListChatModel(responses=["YES", "YES", "NO"]) + doc_compressor = LLMChainFilter.from_llm(llm) + output = doc_compressor.compress_documents( + documents, "Tell me about Candlepin bowling." 
+ ) + expected = documents[:2] + assert output == expected + + +async def test_llm_chain_extractor_async() -> None: + documents = [ + Document( + page_content="Candlepin bowling is popular in New England.", + metadata={"a": 1}, + ), + Document( + page_content="Candlepin bowling balls are smaller.", + metadata={"b": 2}, + ), + Document(page_content="The moon is round.", metadata={"c": 3}), + ] + llm = FakeListChatModel(responses=["YES", "YES", "NO"]) + doc_compressor = LLMChainFilter.from_llm(llm) + output = await doc_compressor.acompress_documents( + documents, "Tell me about Candlepin bowling." + ) + expected = documents[:2] + assert output == expected