mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-29 07:19:59 +00:00
232 lines
7.9 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to pass tool outputs to chat models\n",
"\n",
":::info Prerequisites\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [LangChain Tools](/docs/concepts/tools)\n",
"- [Function/tool calling](/docs/concepts/tool_calling)\n",
"- [Using chat models to call tools](/docs/how_to/tool_calling)\n",
"- [Defining custom tools](/docs/how_to/custom_tools/)\n",
"\n",
":::\n",
"\n",
"Some models are capable of [**tool calling**](/docs/concepts/tool_calling): generating arguments that conform to a specific user-provided schema. This guide demonstrates how to use those tool calls to invoke a function and pass the results back to the model.\n",
"\n",
"First, let's define our tools and our model:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
"<ChatModelTabs\n",
" customVarName=\"llm\"\n",
" overrideParams={{fireworks: {model: \"accounts/fireworks/models/firefunction-v1\", kwargs: \"temperature=0\"}}}\n",
"/>\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"if \"OPENAI_API_KEY\" not in os.environ:\n",
" os.environ[\"OPENAI_API_KEY\"] = getpass()\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Adds a and b.\"\"\"\n",
" return a + b\n",
"\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiplies a and b.\"\"\"\n",
" return a * b\n",
"\n",
"\n",
"tools = [add, multiply]\n",
"\n",
"llm_with_tools = llm.bind_tools(tools)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's get the model to call a tool. We'll append its response to a list of messages that we'll treat as conversation history:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_GPGPE943GORirhIAYnWv00rK', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_dm8o64ZrY3WFZHAvCh1bEJ6i', 'type': 'tool_call'}]\n"
]
}
],
"source": [
"from langchain_core.messages import HumanMessage\n",
"\n",
"query = \"What is 3 * 12? Also, what is 11 + 49?\"\n",
"\n",
"messages = [HumanMessage(query)]\n",
"\n",
"ai_msg = llm_with_tools.invoke(messages)\n",
"\n",
"print(ai_msg.tool_calls)\n",
"\n",
"messages.append(ai_msg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's invoke the tool functions using the args the model populated.\n",
"\n",
"Conveniently, if we invoke a LangChain `Tool` with a `ToolCall`, we'll automatically get back a `ToolMessage` that can be fed back to the model:\n",
"\n",
":::caution Compatibility\n",
"\n",
"This functionality was added in `langchain-core` 0.2.19. Make sure your package is up to date.\n",
"\n",
"If you are on earlier versions of `langchain-core`, you will need to extract the `args` field from the tool call and construct a `ToolMessage` manually.\n",
"\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_loT2pliJwJe3p7nkgXYF48A1', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_bG9tYZCXOeYDZf3W46TceoV4', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 87, 'total_tokens': 137}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_661538dc1f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-e3db3c46-bf9e-478e-abc1-dc9a264f4afe-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_loT2pliJwJe3p7nkgXYF48A1', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_bG9tYZCXOeYDZf3W46TceoV4', 'type': 'tool_call'}], usage_metadata={'input_tokens': 87, 'output_tokens': 50, 'total_tokens': 137}),\n",
" ToolMessage(content='36', name='multiply', tool_call_id='call_loT2pliJwJe3p7nkgXYF48A1'),\n",
" ToolMessage(content='60', name='add', tool_call_id='call_bG9tYZCXOeYDZf3W46TceoV4')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for tool_call in ai_msg.tool_calls:\n",
" selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n",
" tool_msg = selected_tool.invoke(tool_call)\n",
" messages.append(tool_msg)\n",
"\n",
"messages"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And finally, we'll invoke the model with the tool results. The model will use this information to generate a final answer to our original query:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The result of \\\\(3 \\\\times 12\\\\) is 36, and the result of \\\\(11 + 49\\\\) is 60.', response_metadata={'token_usage': {'completion_tokens': 31, 'prompt_tokens': 153, 'total_tokens': 184}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_661538dc1f', 'finish_reason': 'stop', 'logprobs': None}, id='run-87d1ef0a-1223-4bb3-9310-7b591789323d-0', usage_metadata={'input_tokens': 153, 'output_tokens': 31, 'total_tokens': 184})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_with_tools.invoke(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that each `ToolMessage` must include a `tool_call_id` that matches an `id` in the original tool calls that the model generates. This helps the model match tool responses with tool calls.\n",
"\n",
"Tool calling agents, like those in [LangGraph](https://langchain-ai.github.io/langgraph/tutorials/introduction/), use this basic flow to answer queries and solve tasks.\n",
"\n",
"## Related\n",
"\n",
"- [LangGraph quickstart](https://langchain-ai.github.io/langgraph/tutorials/introduction/)\n",
"- Few shot prompting [with tools](/docs/how_to/tools_few_shot/)\n",
"- Stream [tool calls](/docs/how_to/tool_streaming/)\n",
"- Pass [runtime values to tools](/docs/how_to/tool_runtime)\n",
"- Getting [structured outputs](/docs/how_to/structured_output/) from models"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
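The Compatibility note in the notebook above says that on older `langchain-core` versions you extract `args` from each tool call and build the tool message yourself. Here is a minimal sketch of that manual path using only the standard library. Plain functions stand in for the `@tool`-decorated tools and plain dicts stand in for `ToolMessage` objects. `TOOL_REGISTRY`, `run_tool_calls`, and the `call_1`/`call_2` ids are hypothetical names chosen for illustration, not LangChain APIs.

```python
from typing import Any, Callable


def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


# Hypothetical registry mapping tool names to plain Python callables.
TOOL_REGISTRY: dict[str, Callable[..., Any]] = {"add": add, "multiply": multiply}


def run_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run each tool call and return one result dict per call."""
    results = []
    for call in tool_calls:
        func = TOOL_REGISTRY[call["name"].lower()]
        # Pass only the model-populated arguments to the function.
        output = func(**call["args"])
        results.append(
            {
                "name": call["name"],
                "content": str(output),
                # Echo the id so the result can be matched to its call.
                "tool_call_id": call["id"],
            }
        )
    return results


# Tool calls shaped like the printed `ai_msg.tool_calls`, with placeholder ids.
tool_calls = [
    {"name": "multiply", "args": {"a": 3, "b": 12}, "id": "call_1", "type": "tool_call"},
    {"name": "add", "args": {"a": 11, "b": 49}, "id": "call_2", "type": "tool_call"},
]

tool_results = run_tool_calls(tool_calls)
for result in tool_results:
    print(result)
```

With `langchain-core` installed, each dict would instead be a `ToolMessage` carrying the same `content`, `name`, and `tool_call_id` fields that appear in the notebook output.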
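The closing note in the notebook says each `ToolMessage` must carry a `tool_call_id` that matches an `id` from the model's tool calls. As a rough illustration, we could check this before sending the history back to the model. The sketch below works on plain dicts shaped like the printed tool calls. `unanswered_tool_call_ids` and the placeholder ids are hypothetical and not part of LangChain.

```python
from typing import Any


def unanswered_tool_call_ids(
    tool_calls: list[dict[str, Any]], tool_results: list[dict[str, Any]]
) -> list[str]:
    """Return ids of tool calls that have no result with a matching tool_call_id."""
    answered = {result["tool_call_id"] for result in tool_results}
    return [call["id"] for call in tool_calls if call["id"] not in answered]


tool_calls = [
    {"name": "multiply", "args": {"a": 3, "b": 12}, "id": "call_1", "type": "tool_call"},
    {"name": "add", "args": {"a": 11, "b": 49}, "id": "call_2", "type": "tool_call"},
]
# Only the multiply call has been answered so far.
tool_results = [{"name": "multiply", "content": "36", "tool_call_id": "call_1"}]

missing = unanswered_tool_call_ids(tool_calls, tool_results)
print(missing)  # ['call_2']
```

An empty list means every tool call has a matching result, so the message history is ready to pass back to the model.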