mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
docs: add tool choice to tool calling (#21229)
This commit is contained in:
parent
67a5cc34c6
commit
70bde15480
@ -20,11 +20,15 @@
|
||||
"\n",
|
||||
"```{=mdx}\n",
|
||||
":::info\n",
|
||||
"We use the term tool calling interchangeably with function calling. Although\n",
|
||||
"We use the term \"tool calling\" interchangeably with \"function calling\". Although\n",
|
||||
"function calling is sometimes meant to refer to invocations of a single function,\n",
|
||||
"we treat all models as though they can return multiple tool or function calls in \n",
|
||||
"each message.\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
":::tip\n",
|
||||
"See [here](/docs/integrations/chat/) for a list of all models that support tool calling.\n",
|
||||
":::\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Tool calling allows a model to respond to a given prompt by generating output that \n",
|
||||
@ -86,12 +90,14 @@
|
||||
"LangChain implements standard interfaces for defining tools, passing them to LLMs, \n",
|
||||
"and representing tool calls.\n",
|
||||
"\n",
|
||||
"## Passing tools to LLMs\n",
|
||||
"## Request: Passing tools to model\n",
|
||||
"\n",
|
||||
"Chat models supporting tool calling features implement a `.bind_tools` method, which \n",
|
||||
"receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool) \n",
|
||||
"and binds them to the chat model in its expected format. Subsequent invocations of the \n",
|
||||
"chat model will include tool schemas in its calls to the LLM.\n",
|
||||
"For a model to be able to invoke tools, you need to pass tool schemas to it when making a chat request.\n",
|
||||
"LangChain ChatModels supporting tool calling features implement a `.bind_tools` method, which \n",
|
||||
"receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool), Pydantic classes, or JSON Schemas and binds them to the chat model in the provider-specific expected format. Subsequent invocations of the \n",
|
||||
"bound chat model will include tool schemas in every call to the model API.\n",
|
||||
"\n",
|
||||
"### Defining tool schemas: LangChain Tool\n",
|
||||
"\n",
|
||||
"For example, we can define the schema for custom tools using the `@tool` decorator \n",
|
||||
"on Python functions:"
|
||||
@ -99,7 +105,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 1,
|
||||
"id": "841dca72-1b57-4a42-8e22-da4835c4cfe0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -109,13 +115,23 @@
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def add(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Adds a and b.\"\"\"\n",
|
||||
" \"\"\"Adds a and b.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: first int\n",
|
||||
" b: second int\n",
|
||||
" \"\"\"\n",
|
||||
" return a + b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def multiply(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Multiplies a and b.\"\"\"\n",
|
||||
" \"\"\"Multiplies a and b.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: first int\n",
|
||||
" b: second int\n",
|
||||
" \"\"\"\n",
|
||||
" return a * b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@ -127,12 +143,14 @@
|
||||
"id": "48058b7d-048d-48e6-a272-3931ad7ad146",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or below, we define the schema using Pydantic:\n"
|
||||
"### Defining tool schemas: Pydantic class\n",
|
||||
"\n",
|
||||
"We can equivalently define the schema using Pydantic. Pydantic is useful when your tool inputs are more complex:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 2,
|
||||
"id": "fca56328-85e4-4839-97b7-b5dc55920602",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -142,21 +160,21 @@
|
||||
"\n",
|
||||
"# Note that the docstrings here are crucial, as they will be passed along\n",
|
||||
"# to the model along with the class name.\n",
|
||||
"class Add(BaseModel):\n",
|
||||
"class add(BaseModel):\n",
|
||||
" \"\"\"Add two integers together.\"\"\"\n",
|
||||
"\n",
|
||||
" a: int = Field(..., description=\"First integer\")\n",
|
||||
" b: int = Field(..., description=\"Second integer\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Multiply(BaseModel):\n",
|
||||
"class multiply(BaseModel):\n",
|
||||
" \"\"\"Multiply two integers together.\"\"\"\n",
|
||||
"\n",
|
||||
" a: int = Field(..., description=\"First integer\")\n",
|
||||
" b: int = Field(..., description=\"Second integer\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [Add, Multiply]"
|
||||
"tools = [add, multiply]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -175,6 +193,8 @@
|
||||
"/>\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Binding tool schemas\n",
|
||||
"\n",
|
||||
"We can use the `bind_tools()` method to handle converting\n",
|
||||
"`Multiply` to a \"tool\" and binding it to the model (i.e.,\n",
|
||||
"passing it in each time the model is invoked)."
|
||||
@ -182,7 +202,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"execution_count": 3,
|
||||
"id": "44eb8327-a03d-4c7c-945e-30f13f455346",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -197,7 +217,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"execution_count": 4,
|
||||
"id": "af2a83ac-e43f-43ce-b107-9ed8376bfb75",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -205,17 +225,45 @@
|
||||
"llm_with_tools = llm.bind_tools(tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3dd0e53f-d48d-4952-b53e-faf97ebf8831",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Request: Forcing a tool call\n",
|
||||
"\n",
|
||||
"When you just use `bind_tools(tools)`, the model can choose whether to return one tool call, multiple tool calls, or no tool calls at all. Some models support a `tool_choice` parameter that gives you some ability to force the model to call a tool. For models that support this, you can pass in the name of the tool you want the model to always call `tool_choice=\"xyz_tool_name\"`. Or you can pass in `tool_choice=\"any\"` to force the model to call at least one tool, without specifying which tool specifically.\n",
|
||||
"\n",
|
||||
"```{=mdx}\n",
|
||||
":::note\n",
|
||||
"Currently `tool_choice=\"any\"` functionality is supported by OpenAI, MistralAI, FireworksAI, and Groq.\n",
|
||||
"\n",
|
||||
"Currently Anthropic does not support `tool_choice` at all.\n",
|
||||
":::\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"If we wanted our model to always call the multiply tool we could do:\n",
|
||||
"```python\n",
|
||||
"always_multiply_llm = llm.bind_tools([multiply], tool_choice=\"multiply\")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"And if we wanted it to always call at least one of add or multiply, we could do:\n",
|
||||
"```python\n",
|
||||
"always_call_tool_llm = llm.bind_tools([add, multiply], tool_choice=\"any\")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "16208230-f64f-4935-9aa1-280a91f34ba3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tool calls\n",
|
||||
"## Response: Reading tool calls from model output\n",
|
||||
"\n",
|
||||
"If tool calls are included in a LLM response, they are attached to the corresponding \n",
|
||||
"[message](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n",
|
||||
"or [message chunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
|
||||
"as a list of [tool call](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n",
|
||||
"[AIMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n",
|
||||
"or [AIMessageChunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) (when streaming)\n",
|
||||
"as a list of [ToolCall](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n",
|
||||
"objects in the `.tool_calls` attribute. A `ToolCall` is a typed dict that includes a \n",
|
||||
"tool name, dict of argument values, and (optionally) an identifier. Messages with no \n",
|
||||
"tool calls default to an empty list for this attribute.\n",
|
||||
@ -225,22 +273,22 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 5,
|
||||
"id": "1640a4b4-c201-4b23-b257-738d854fb9fd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'Multiply',\n",
|
||||
"[{'name': 'multiply',\n",
|
||||
" 'args': {'a': 3, 'b': 12},\n",
|
||||
" 'id': 'call_1Tdp5wUXbYQzpkBoagGXqUTo'},\n",
|
||||
" {'name': 'Add',\n",
|
||||
" 'id': 'call_UL7E2232GfDHIQGOM4gJfEDD'},\n",
|
||||
" {'name': 'add',\n",
|
||||
" 'args': {'a': 11, 'b': 49},\n",
|
||||
" 'id': 'call_k9v09vYioS3X0Qg35zESuUKI'}]"
|
||||
" 'id': 'call_VKw8t5tpAuzvbHgdAXe9mjUx'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -269,17 +317,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 6,
|
||||
"id": "ca15fcad-74fe-4109-a1b1-346c3eefe238",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Multiply(a=3, b=12), Add(a=11, b=49)]"
|
||||
"[multiply(a=3, b=12), add(a=11, b=49)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -287,7 +335,7 @@
|
||||
"source": [
|
||||
"from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n",
|
||||
"\n",
|
||||
"chain = llm_with_tools | PydanticToolsParser(tools=[Multiply, Add])\n",
|
||||
"chain = llm_with_tools | PydanticToolsParser(tools=[multiply, add])\n",
|
||||
"chain.invoke(query)"
|
||||
]
|
||||
},
|
||||
@ -296,7 +344,7 @@
|
||||
"id": "0ba3505d-f405-43ba-93c4-7fbd84f6464b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming\n",
|
||||
"## Response: Streaming\n",
|
||||
"\n",
|
||||
"When tools are called in a streaming context, \n",
|
||||
"[message chunks](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
|
||||
@ -319,7 +367,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 7,
|
||||
"id": "4f54a0de-74c7-4f2d-86c5-660aed23840d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -328,12 +376,12 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[]\n",
|
||||
"[{'name': 'Multiply', 'args': '', 'id': 'call_d39MsxKM5cmeGJOoYKdGBgzc', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '', 'id': 'call_5Gdgx3R2z97qIycWKixgD2OU', 'index': 0}]\n",
|
||||
"[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 0}]\n",
|
||||
"[{'name': None, 'args': ': 3, ', 'id': None, 'index': 0}]\n",
|
||||
"[{'name': None, 'args': '\"b\": 1', 'id': None, 'index': 0}]\n",
|
||||
"[{'name': None, 'args': '2}', 'id': None, 'index': 0}]\n",
|
||||
"[{'name': 'Add', 'args': '', 'id': 'call_QJpdxD9AehKbdXzMHxgDMMhs', 'index': 1}]\n",
|
||||
"[{'name': 'add', 'args': '', 'id': 'call_DpeKaF8pUCmLP0tkinhdmBgD', 'index': 1}]\n",
|
||||
"[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 1}]\n",
|
||||
"[{'name': None, 'args': ': 11,', 'id': None, 'index': 1}]\n",
|
||||
"[{'name': None, 'args': ' \"b\": ', 'id': None, 'index': 1}]\n",
|
||||
@ -359,7 +407,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 8,
|
||||
"id": "0a944af0-eedd-43c8-8ff3-f4301f129d9b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -368,17 +416,17 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[]\n",
|
||||
"[{'name': 'Multiply', 'args': '', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\"', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, ', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\"', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11,', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
|
||||
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n"
|
||||
"[{'name': 'multiply', 'args': '', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\"', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, ', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\"', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11,', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
|
||||
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -396,7 +444,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 9,
|
||||
"id": "db4e3e3a-3553-44dc-bd31-149c0981a06a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -422,7 +470,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 10,
|
||||
"id": "e9402bde-d4b5-4564-a99e-f88c9b46b28a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -432,16 +480,16 @@
|
||||
"text": [
|
||||
"[]\n",
|
||||
"[]\n",
|
||||
"[{'name': 'Multiply', 'args': {}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
|
||||
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n"
|
||||
"[{'name': 'multiply', 'args': {}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
|
||||
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -459,7 +507,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 11,
|
||||
"id": "8c2f21cc-0c6d-416a-871f-e854621c96e2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -480,14 +528,14 @@
|
||||
"id": "97a0c977-0c3c-4011-b49b-db98c609d0ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Passing tool outputs to model\n",
|
||||
"## Request: Passing tool outputs to model\n",
|
||||
"\n",
|
||||
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 117,
|
||||
"execution_count": 13,
|
||||
"id": "48049192-be28-42ab-9a44-d897924e67cd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -495,12 +543,12 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n",
|
||||
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'Multiply'}, 'type': 'function'}, {'id': 'call_qywVrsplg0ZMv7LHYYMjyG81', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 105, 'total_tokens': 155}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-1a0b8cdd-9221-4d94-b2ed-5701f67ce9fe-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_qywVrsplg0ZMv7LHYYMjyG81'}]),\n",
|
||||
" ToolMessage(content='36', tool_call_id='call_K5DsWEmgt6D08EI9AFu9NaL1'),\n",
|
||||
" ToolMessage(content='60', tool_call_id='call_qywVrsplg0ZMv7LHYYMjyG81')]"
|
||||
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Jja7J89XsjrOLA5rAjULqTSL', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 144, 'total_tokens': 193}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_a450710239', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9db7e8e1-86d5-4015-9f43-f1d33abea64d-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Jja7J89XsjrOLA5rAjULqTSL'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ'}]),\n",
|
||||
" ToolMessage(content='36', tool_call_id='call_Jja7J89XsjrOLA5rAjULqTSL'),\n",
|
||||
" ToolMessage(content='60', tool_call_id='call_K4ArVEUjhl36EcSuxGN1nwvZ')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 117,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -508,29 +556,57 @@
|
||||
"source": [
|
||||
"from langchain_core.messages import HumanMessage, ToolMessage\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def add(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Adds a and b.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: first int\n",
|
||||
" b: second int\n",
|
||||
" \"\"\"\n",
|
||||
" return a + b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def multiply(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Multiplies a and b.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: first int\n",
|
||||
" b: second int\n",
|
||||
" \"\"\"\n",
|
||||
" return a * b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [add, multiply]\n",
|
||||
"llm_with_tools = llm.bind_tools(tools)\n",
|
||||
"\n",
|
||||
"messages = [HumanMessage(query)]\n",
|
||||
"ai_msg = llm_with_tools.invoke(messages)\n",
|
||||
"messages.append(ai_msg)\n",
|
||||
"\n",
|
||||
"for tool_call in ai_msg.tool_calls:\n",
|
||||
" selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n",
|
||||
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
|
||||
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
|
||||
"\n",
|
||||
"messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 118,
|
||||
"execution_count": 14,
|
||||
"id": "611e0f36-d736-48d1-bca1-1cec51d223f3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 171, 'total_tokens': 189}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-a6c8093c-b16a-4c92-8308-7c9ac998118c-0')"
|
||||
"AIMessage(content='3 * 12 = 36\\n11 + 49 = 60', response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 209, 'total_tokens': 225}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'stop', 'logprobs': None}, id='run-a55f8cb5-6d6d-4835-9c6b-7de36b2590c7-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 118,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -544,31 +620,36 @@
|
||||
"id": "a5937498-d6fe-400a-b192-ef35c314168e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Few-shot prompting\n",
|
||||
"## Request: Few-shot prompting\n",
|
||||
"\n",
|
||||
"For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt.\n",
|
||||
"For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt. \n",
|
||||
"\n",
|
||||
"```{=mdx}\n",
|
||||
":::note\n",
|
||||
"For most models it's important that the ToolCall and ToolMessage ids line up, so that each AIMessage with ToolCalls is followed by ToolMessages with corresponding ids.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"For example, even with some special instructions our model can get tripped up by order of operations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 112,
|
||||
"execution_count": 15,
|
||||
"id": "5ef2e7c3-0925-49da-ab8f-e42c4fa40f29",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'Multiply',\n",
|
||||
"[{'name': 'multiply',\n",
|
||||
" 'args': {'a': 119, 'b': 8},\n",
|
||||
" 'id': 'call_Dl3FXRVkQCFW4sUNYOe4rFr7'},\n",
|
||||
" {'name': 'Add',\n",
|
||||
" 'id': 'call_RofMKNQ2qbWAFaMsef4cpTS9'},\n",
|
||||
" {'name': 'add',\n",
|
||||
" 'args': {'a': 952, 'b': -20},\n",
|
||||
" 'id': 'call_n03l4hmka7VZTCiP387Wud2C'}]"
|
||||
" 'id': 'call_HjOfoF8ceMCHmO3cpwG6oB3X'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 112,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -591,19 +672,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 107,
|
||||
"execution_count": 16,
|
||||
"id": "7b2e8b19-270f-4e1a-8be7-7aad704c1cf4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'Multiply',\n",
|
||||
"[{'name': 'multiply',\n",
|
||||
" 'args': {'a': 119, 'b': 8},\n",
|
||||
" 'id': 'call_MoSgwzIhPxhclfygkYaKIsGZ'}]"
|
||||
" 'id': 'call_tWwpzWqqc8dQtN13CyKZCVMe'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 107,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -621,14 +702,14 @@
|
||||
" \"\",\n",
|
||||
" name=\"example_assistant\",\n",
|
||||
" tool_calls=[\n",
|
||||
" {\"name\": \"Multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n",
|
||||
" {\"name\": \"multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n",
|
||||
" ],\n",
|
||||
" ),\n",
|
||||
" ToolMessage(\"16505054784\", tool_call_id=\"1\"),\n",
|
||||
" AIMessage(\n",
|
||||
" \"\",\n",
|
||||
" name=\"example_assistant\",\n",
|
||||
" tool_calls=[{\"name\": \"Add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n",
|
||||
" tool_calls=[{\"name\": \"add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n",
|
||||
" ),\n",
|
||||
" ToolMessage(\"16505054788\", tool_call_id=\"2\"),\n",
|
||||
" AIMessage(\n",
|
||||
|
Loading…
Reference in New Issue
Block a user