From 70bde154807cbe434dfa7c059a43db6796b6b3e7 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 3 May 2024 03:10:22 -0400 Subject: [PATCH] docs: add tool choice to tool calling (#21229) --- .../model_io/chat/function_calling.ipynb | 245 ++++++++++++------ 1 file changed, 163 insertions(+), 82 deletions(-) diff --git a/docs/docs/modules/model_io/chat/function_calling.ipynb b/docs/docs/modules/model_io/chat/function_calling.ipynb index 92f66b429ef..7b40f158757 100644 --- a/docs/docs/modules/model_io/chat/function_calling.ipynb +++ b/docs/docs/modules/model_io/chat/function_calling.ipynb @@ -20,11 +20,15 @@ "\n", "```{=mdx}\n", ":::info\n", - "We use the term tool calling interchangeably with function calling. Although\n", + "We use the term \"tool calling\" interchangeably with \"function calling\". Although\n", "function calling is sometimes meant to refer to invocations of a single function,\n", "we treat all models as though they can return multiple tool or function calls in \n", "each message.\n", ":::\n", + "\n", + ":::tip\n", + "See [here](/docs/integrations/chat/) for a list of all models that support tool calling.\n", + ":::\n", "```\n", "\n", "Tool calling allows a model to respond to a given prompt by generating output that \n", @@ -86,12 +90,14 @@ "LangChain implements standard interfaces for defining tools, passing them to LLMs, \n", "and representing tool calls.\n", "\n", - "## Passing tools to LLMs\n", + "## Request: Passing tools to model\n", "\n", - "Chat models supporting tool calling features implement a `.bind_tools` method, which \n", - "receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool) \n", - "and binds them to the chat model in its expected format. 
Subsequent invocations of the \n", - "chat model will include tool schemas in its calls to the LLM.\n", + "For a model to be able to invoke tools, you need to pass tool schemas to it when making a chat request.\n", + "LangChain ChatModels supporting tool calling features implement a `.bind_tools` method, which \n", + "receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool), Pydantic classes, or JSON Schemas and binds them to the chat model in the provider-specific expected format. Subsequent invocations of the \n", + "bound chat model will include tool schemas in every call to the model API.\n", + "\n", + "### Defining tool schemas: LangChain Tool\n", "\n", "For example, we can define the schema for custom tools using the `@tool` decorator \n", "on Python functions:" @@ -99,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "id": "841dca72-1b57-4a42-8e22-da4835c4cfe0", "metadata": {}, "outputs": [], @@ -109,13 +115,23 @@ "\n", "@tool\n", "def add(a: int, b: int) -> int:\n", - " \"\"\"Adds a and b.\"\"\"\n", + " \"\"\"Adds a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", " return a + b\n", "\n", "\n", "@tool\n", "def multiply(a: int, b: int) -> int:\n", - " \"\"\"Multiplies a and b.\"\"\"\n", + " \"\"\"Multiplies a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", " return a * b\n", "\n", "\n", @@ -127,12 +143,14 @@ "id": "48058b7d-048d-48e6-a272-3931ad7ad146", "metadata": {}, "source": [ - "Or below, we define the schema using Pydantic:\n" + "### Defining tool schemas: Pydantic class\n", + "\n", + "We can equivalently define the schema using Pydantic. 
Pydantic is useful when your tool inputs are more complex:" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "id": "fca56328-85e4-4839-97b7-b5dc55920602", "metadata": {}, "outputs": [], @@ -142,21 +160,21 @@ "\n", "# Note that the docstrings here are crucial, as they will be passed along\n", "# to the model along with the class name.\n", - "class Add(BaseModel):\n", + "class add(BaseModel):\n", " \"\"\"Add two integers together.\"\"\"\n", "\n", " a: int = Field(..., description=\"First integer\")\n", " b: int = Field(..., description=\"Second integer\")\n", "\n", "\n", - "class Multiply(BaseModel):\n", + "class multiply(BaseModel):\n", " \"\"\"Multiply two integers together.\"\"\"\n", "\n", " a: int = Field(..., description=\"First integer\")\n", " b: int = Field(..., description=\"Second integer\")\n", "\n", "\n", - "tools = [Add, Multiply]" + "tools = [add, multiply]" ] }, { @@ -175,6 +193,8 @@ "/>\n", "```\n", "\n", + "### Binding tool schemas\n", + "\n", "We can use the `bind_tools()` method to handle converting\n", - "`Multiply` to a \"tool\" and binding it to the model (i.e.,\n", + "`multiply` to a \"tool\" and binding it to the model (i.e.,\n", "passing it in each time the model is invoked)." @@ -182,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 3, "id": "44eb8327-a03d-4c7c-945e-30f13f455346", "metadata": {}, "outputs": [], @@ -197,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 4, "id": "af2a83ac-e43f-43ce-b107-9ed8376bfb75", "metadata": {}, "outputs": [], @@ -205,17 +225,45 @@ "llm_with_tools = llm.bind_tools(tools)" ] }, + { + "cell_type": "markdown", + "id": "3dd0e53f-d48d-4952-b53e-faf97ebf8831", + "metadata": {}, + "source": [ + "## Request: Forcing a tool call\n", + "\n", + "When you just use `bind_tools(tools)`, the model can choose whether to return one tool call, multiple tool calls, or no tool calls at all. Some models support a `tool_choice` parameter that gives you some ability to force the model to call a tool. 
For models that support this, you can pass in the name of the tool you want the model to always call, e.g. `tool_choice=\"xyz_tool_name\"`. Or you can pass in `tool_choice=\"any\"` to force the model to call at least one tool, without specifying which one.\n", + "\n", + "```{=mdx}\n", + ":::note\n", + "Currently `tool_choice=\"any\"` functionality is supported by OpenAI, MistralAI, FireworksAI, and Groq.\n", + "\n", + "Currently Anthropic does not support `tool_choice` at all.\n", + ":::\n", + "```\n", + "\n", + "If we wanted our model to always call the multiply tool we could do:\n", + "```python\n", + "always_multiply_llm = llm.bind_tools([multiply], tool_choice=\"multiply\")\n", + "```\n", + "\n", + "And if we wanted it to always call at least one of add or multiply, we could do:\n", + "```python\n", + "always_call_tool_llm = llm.bind_tools([add, multiply], tool_choice=\"any\")\n", + "```" + ] + }, { "cell_type": "markdown", "id": "16208230-f64f-4935-9aa1-280a91f34ba3", "metadata": {}, "source": [ - "## Tool calls\n", + "## Response: Reading tool calls from model output\n", "\n", "If tool calls are included in a LLM response, they are attached to the corresponding \n", - "[message](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n", - "or [message chunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n", - "as a list of [tool call](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n", + "[AIMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n", + "or 
[AIMessageChunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) (when streaming)\n", + "as a list of [ToolCall](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n", "objects in the `.tool_calls` attribute. A `ToolCall` is a typed dict that includes a \n", "tool name, dict of argument values, and (optionally) an identifier. Messages with no \n", "tool calls default to an empty list for this attribute.\n", @@ -225,22 +273,22 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "id": "1640a4b4-c201-4b23-b257-738d854fb9fd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'name': 'Multiply',\n", + "[{'name': 'multiply',\n", " 'args': {'a': 3, 'b': 12},\n", - " 'id': 'call_1Tdp5wUXbYQzpkBoagGXqUTo'},\n", - " {'name': 'Add',\n", + " 'id': 'call_UL7E2232GfDHIQGOM4gJfEDD'},\n", + " {'name': 'add',\n", " 'args': {'a': 11, 'b': 49},\n", - " 'id': 'call_k9v09vYioS3X0Qg35zESuUKI'}]" + " 'id': 'call_VKw8t5tpAuzvbHgdAXe9mjUx'}]" ] }, - "execution_count": 15, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -269,17 +317,17 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "id": "ca15fcad-74fe-4109-a1b1-346c3eefe238", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[Multiply(a=3, b=12), Add(a=11, b=49)]" + "[multiply(a=3, b=12), add(a=11, b=49)]" ] }, - "execution_count": 16, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -287,7 +335,7 @@ "source": [ "from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n", "\n", - "chain = llm_with_tools | PydanticToolsParser(tools=[Multiply, Add])\n", + "chain = llm_with_tools | PydanticToolsParser(tools=[multiply, add])\n", "chain.invoke(query)" ] }, @@ -296,7 +344,7 @@ "id": 
"0ba3505d-f405-43ba-93c4-7fbd84f6464b", "metadata": {}, "source": [ - "### Streaming\n", + "## Response: Streaming\n", "\n", "When tools are called in a streaming context, \n", "[message chunks](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n", @@ -319,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "id": "4f54a0de-74c7-4f2d-86c5-660aed23840d", "metadata": {}, "outputs": [ @@ -328,12 +376,12 @@ "output_type": "stream", "text": [ "[]\n", - "[{'name': 'Multiply', 'args': '', 'id': 'call_d39MsxKM5cmeGJOoYKdGBgzc', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '', 'id': 'call_5Gdgx3R2z97qIycWKixgD2OU', 'index': 0}]\n", "[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 0}]\n", "[{'name': None, 'args': ': 3, ', 'id': None, 'index': 0}]\n", "[{'name': None, 'args': '\"b\": 1', 'id': None, 'index': 0}]\n", "[{'name': None, 'args': '2}', 'id': None, 'index': 0}]\n", - "[{'name': 'Add', 'args': '', 'id': 'call_QJpdxD9AehKbdXzMHxgDMMhs', 'index': 1}]\n", + "[{'name': 'add', 'args': '', 'id': 'call_DpeKaF8pUCmLP0tkinhdmBgD', 'index': 1}]\n", "[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 1}]\n", "[{'name': None, 'args': ': 11,', 'id': None, 'index': 1}]\n", "[{'name': None, 'args': ' \"b\": ', 'id': None, 'index': 1}]\n", @@ -359,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 8, "id": "0a944af0-eedd-43c8-8ff3-f4301f129d9b", "metadata": {}, "outputs": [ @@ -368,17 +416,17 @@ "output_type": "stream", "text": [ "[]\n", - "[{'name': 'Multiply', 'args': '', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n", - "[{'name': 'Multiply', 'args': '{\"a\"', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, ', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 
'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\"', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11,', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n", - "[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n" + "[{'name': 'multiply', 'args': '', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '{\"a\"', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, ', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 
'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\"', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11,', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n", + "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n" ] } ], @@ -396,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "id": "db4e3e3a-3553-44dc-bd31-149c0981a06a", "metadata": {}, "outputs": [ @@ -422,7 +470,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 10, "id": "e9402bde-d4b5-4564-a99e-f88c9b46b28a", "metadata": {}, "outputs": [ @@ -432,16 +480,16 @@ "text": [ "[]\n", "[]\n", - "[{'name': 'Multiply', 'args': {}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 
'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n", - "[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n" + "[{'name': 'multiply', 'args': {}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n", + "[{'name': 'multiply', 'args': {'a': 3}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 
'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n", + "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n" ] } ], @@ -459,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "id": "8c2f21cc-0c6d-416a-871f-e854621c96e2", "metadata": {}, "outputs": [ @@ -480,14 +528,14 @@ "id": "97a0c977-0c3c-4011-b49b-db98c609d0ce", "metadata": {}, "source": [ - "## Passing tool outputs to model\n", + "## Request: Passing tool outputs to model\n", "\n", "If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s." ] }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 13, "id": "48049192-be28-42ab-9a44-d897924e67cd", "metadata": {}, "outputs": [ @@ -495,12 +543,12 @@ "data": { "text/plain": [ "[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n", - " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'Multiply'}, 'type': 'function'}, {'id': 'call_qywVrsplg0ZMv7LHYYMjyG81', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 105, 'total_tokens': 155}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-1a0b8cdd-9221-4d94-b2ed-5701f67ce9fe-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_qywVrsplg0ZMv7LHYYMjyG81'}]),\n", - " ToolMessage(content='36', tool_call_id='call_K5DsWEmgt6D08EI9AFu9NaL1'),\n", - " ToolMessage(content='60', tool_call_id='call_qywVrsplg0ZMv7LHYYMjyG81')]" + " 
AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Jja7J89XsjrOLA5rAjULqTSL', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 144, 'total_tokens': 193}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_a450710239', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9db7e8e1-86d5-4015-9f43-f1d33abea64d-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Jja7J89XsjrOLA5rAjULqTSL'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ'}]),\n", + " ToolMessage(content='36', tool_call_id='call_Jja7J89XsjrOLA5rAjULqTSL'),\n", + " ToolMessage(content='60', tool_call_id='call_K4ArVEUjhl36EcSuxGN1nwvZ')]" ] }, - "execution_count": 117, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -508,29 +556,57 @@ "source": [ "from langchain_core.messages import HumanMessage, ToolMessage\n", "\n", + "\n", + "@tool\n", + "def add(a: int, b: int) -> int:\n", + " \"\"\"Adds a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", + " return a + b\n", + "\n", + "\n", + "@tool\n", + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiplies a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", + " return a * b\n", + "\n", + "\n", + "tools = [add, multiply]\n", + "llm_with_tools = llm.bind_tools(tools)\n", + "\n", "messages = [HumanMessage(query)]\n", "ai_msg = llm_with_tools.invoke(messages)\n", "messages.append(ai_msg)\n", + "\n", "for tool_call in ai_msg.tool_calls:\n", " selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n", " tool_output = selected_tool.invoke(tool_call[\"args\"])\n", " 
messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n", + "\n", "messages" ] }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 14, "id": "611e0f36-d736-48d1-bca1-1cec51d223f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 171, 'total_tokens': 189}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-a6c8093c-b16a-4c92-8308-7c9ac998118c-0')" + "AIMessage(content='3 * 12 = 36\\n11 + 49 = 60', response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 209, 'total_tokens': 225}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'stop', 'logprobs': None}, id='run-a55f8cb5-6d6d-4835-9c6b-7de36b2590c7-0')" ] }, - "execution_count": 118, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -544,31 +620,37 @@ "id": "a5937498-d6fe-400a-b192-ef35c314168e", "metadata": {}, "source": [ - "## Few-shot prompting\n", + "## Request: Few-shot prompting\n", "\n", - "For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt.\n", + "For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt. 
\n", + "\n", + "```{=mdx}\n", + ":::note\n", + "For most models it's important that the ToolCall and ToolMessage ids line up, so that each AIMessage with ToolCalls is followed by ToolMessages with corresponding ids.\n", + ":::\n", + "```\n", + "\n", "For example, even with some special instructions our model can get tripped up by order of operations:" ] }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 15, "id": "5ef2e7c3-0925-49da-ab8f-e42c4fa40f29", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'name': 'Multiply',\n", + "[{'name': 'multiply',\n", " 'args': {'a': 119, 'b': 8},\n", - " 'id': 'call_Dl3FXRVkQCFW4sUNYOe4rFr7'},\n", - " {'name': 'Add',\n", + " 'id': 'call_RofMKNQ2qbWAFaMsef4cpTS9'},\n", + " {'name': 'add',\n", " 'args': {'a': 952, 'b': -20},\n", - " 'id': 'call_n03l4hmka7VZTCiP387Wud2C'}]" + " 'id': 'call_HjOfoF8ceMCHmO3cpwG6oB3X'}]" ] }, - "execution_count": 112, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -591,19 +672,19 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 16, "id": "7b2e8b19-270f-4e1a-8be7-7aad704c1cf4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'name': 'Multiply',\n", + "[{'name': 'multiply',\n", " 'args': {'a': 119, 'b': 8},\n", - " 'id': 'call_MoSgwzIhPxhclfygkYaKIsGZ'}]" + " 'id': 'call_tWwpzWqqc8dQtN13CyKZCVMe'}]" ] }, - "execution_count": 107, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -621,14 +702,14 @@ " \"\",\n", " name=\"example_assistant\",\n", " tool_calls=[\n", - " {\"name\": \"Multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n", + " {\"name\": \"multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n", " ],\n", " ),\n", " ToolMessage(\"16505054784\", tool_call_id=\"1\"),\n", " AIMessage(\n", " \"\",\n", " name=\"example_assistant\",\n", - " tool_calls=[{\"name\": \"Add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n", + " 
tool_calls=[{\"name\": \"add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n", " ),\n", " ToolMessage(\"16505054788\", tool_call_id=\"2\"),\n", " AIMessage(\n",