Compare commits

...

5 Commits

Author SHA1 Message Date
Bagatur
3d26dc8569 fmt 2024-05-02 18:35:55 -04:00
Bagatur
2cdcc4353e fmt 2024-05-02 18:25:48 -04:00
Bagatur
e96cabc176 Update libs/partners/openai/langchain_openai/chat_models/base.py
Co-authored-by: ccurme <chester.curme@gmail.com>
2024-05-02 17:33:02 -04:00
Bagatur
e2904a2be3 poetry 2024-05-02 14:11:58 -04:00
Bagatur
ac3d16e5a4 openai[patch]: support tool_choice="required" 2024-05-02 14:08:44 -04:00
6 changed files with 269 additions and 104 deletions

View File

@@ -20,11 +20,15 @@
"\n",
"```{=mdx}\n",
":::info\n",
"We use the term tool calling interchangeably with function calling. Although\n",
"We use the term \"tool calling\" interchangeably with \"function calling\". Although\n",
"function calling is sometimes meant to refer to invocations of a single function,\n",
"we treat all models as though they can return multiple tool or function calls in \n",
"each message.\n",
":::\n",
"\n",
":::tip\n",
"See [here](/docs/integrations/chat/) for a list of all models that support tool calling.\n",
":::\n",
"```\n",
"\n",
"Tool calling allows a model to respond to a given prompt by generating output that \n",
@@ -86,12 +90,14 @@
"LangChain implements standard interfaces for defining tools, passing them to LLMs, \n",
"and representing tool calls.\n",
"\n",
"## Passing tools to LLMs\n",
"## Request: Passing tools to model\n",
"\n",
"Chat models supporting tool calling features implement a `.bind_tools` method, which \n",
"receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool) \n",
"and binds them to the chat model in its expected format. Subsequent invocations of the \n",
"chat model will include tool schemas in its calls to the LLM.\n",
"For a model to be able to invoke tools, you need to pass tool schemas to it when making a chat request.\n",
"LangChain ChatModels supporting tool calling features implement a `.bind_tools` method, which \n",
"receives a list of LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool), Pydantic classes, or JSON Schemas and binds them to the chat model in the provider-specific expected format. Subsequent invocations of the \n",
"bound chat model will include tool schemas in every call to the model API.\n",
"\n",
"### Defining tool schemas: LangChain Tool\n",
"\n",
"For example, we can define the schema for custom tools using the `@tool` decorator \n",
"on Python functions:"
@@ -99,7 +105,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 1,
"id": "841dca72-1b57-4a42-8e22-da4835c4cfe0",
"metadata": {},
"outputs": [],
@@ -109,13 +115,23 @@
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Adds a and b.\"\"\"\n",
" \"\"\"Adds a and b.\n",
"\n",
" Args:\n",
" a: first int\n",
" b: second int\n",
" \"\"\"\n",
" return a + b\n",
"\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiplies a and b.\"\"\"\n",
" \"\"\"Multiplies a and b.\n",
"\n",
" Args:\n",
" a: first int\n",
" b: second int\n",
" \"\"\"\n",
" return a * b\n",
"\n",
"\n",
@@ -127,12 +143,14 @@
"id": "48058b7d-048d-48e6-a272-3931ad7ad146",
"metadata": {},
"source": [
"Or below, we define the schema using Pydantic:\n"
"### Defining tool schemas: Pydantic class\n",
"\n",
"We can equivalently define the schema using Pydantic. Pydantic is useful when your tool inputs are more complex:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 2,
"id": "fca56328-85e4-4839-97b7-b5dc55920602",
"metadata": {},
"outputs": [],
@@ -142,21 +160,21 @@
"\n",
"# Note that the docstrings here are crucial, as they will be passed along\n",
"# to the model along with the class name.\n",
"class Add(BaseModel):\n",
"class add(BaseModel):\n",
" \"\"\"Add two integers together.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"class Multiply(BaseModel):\n",
"class multiply(BaseModel):\n",
" \"\"\"Multiply two integers together.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"tools = [Add, Multiply]"
"tools = [add, multiply]"
]
},
{
@@ -175,6 +193,8 @@
"/>\n",
"```\n",
"\n",
"### Binding tool schemas\n",
"\n",
"We can use the `bind_tools()` method to handle converting\n",
"`Multiply` to a \"tool\" and binding it to the model (i.e.,\n",
"passing it in each time the model is invoked)."
@@ -182,7 +202,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 3,
"id": "44eb8327-a03d-4c7c-945e-30f13f455346",
"metadata": {},
"outputs": [],
@@ -197,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 4,
"id": "af2a83ac-e43f-43ce-b107-9ed8376bfb75",
"metadata": {},
"outputs": [],
@@ -205,17 +225,45 @@
"llm_with_tools = llm.bind_tools(tools)"
]
},
{
"cell_type": "markdown",
"id": "3dd0e53f-d48d-4952-b53e-faf97ebf8831",
"metadata": {},
"source": [
"## Request: Forcing a tool call\n",
"\n",
"When you just use `bind_tools(tools)`, the model can choose whether to return one tool call, multiple tool calls, or no tool calls at all. Some models support a `tool_choice` parameter that gives you some ability to force the model to call a tool. For models that support this, you can pass in the name of the tool you want the model to always call `tool_choice=\"xyz_tool_name\"`. Or you can pass in `tool_choice=\"any\"` to force the model to call at least one tool, without specifying which tool specifically.\n",
"\n",
"```{=mdx}\n",
":::note\n",
"Currently `tool_choice=\"any\"` functionality is supported by OpenAI, MistralAI, FireworksAI, and Groq.\n",
"\n",
"Currently Anthropic does not support `tool_choice` at all.\n",
":::\n",
"```\n",
"\n",
"If we wanted our model to always call the multiply tool we could do:\n",
"```python\n",
"always_multiply_llm = llm.bind_tools([multiply], tool_choice=\"multiply\")\n",
"```\n",
"\n",
"And if we wanted it to always call at least one of add or multiply, we could do:\n",
"```python\n",
"always_call_tool_llm = llm.bind_tools([add, multiply], tool_choice=\"any\")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "16208230-f64f-4935-9aa1-280a91f34ba3",
"metadata": {},
"source": [
"## Tool calls\n",
"## Response: Reading tool calls from model output\n",
"\n",
"If tool calls are included in a LLM response, they are attached to the corresponding \n",
"[message](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n",
"or [message chunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
"as a list of [tool call](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n",
"[AIMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html#langchain_core.messages.ai.AIMessage) \n",
"or [AIMessageChunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) (when streaming)\n",
"as a list of [ToolCall](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCall.html#langchain_core.messages.tool.ToolCall) \n",
"objects in the `.tool_calls` attribute. A `ToolCall` is a typed dict that includes a \n",
"tool name, dict of argument values, and (optionally) an identifier. Messages with no \n",
"tool calls default to an empty list for this attribute.\n",
@@ -225,22 +273,22 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 5,
"id": "1640a4b4-c201-4b23-b257-738d854fb9fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'Multiply',\n",
"[{'name': 'multiply',\n",
" 'args': {'a': 3, 'b': 12},\n",
" 'id': 'call_1Tdp5wUXbYQzpkBoagGXqUTo'},\n",
" {'name': 'Add',\n",
" 'id': 'call_UL7E2232GfDHIQGOM4gJfEDD'},\n",
" {'name': 'add',\n",
" 'args': {'a': 11, 'b': 49},\n",
" 'id': 'call_k9v09vYioS3X0Qg35zESuUKI'}]"
" 'id': 'call_VKw8t5tpAuzvbHgdAXe9mjUx'}]"
]
},
"execution_count": 15,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -269,17 +317,17 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"id": "ca15fcad-74fe-4109-a1b1-346c3eefe238",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Multiply(a=3, b=12), Add(a=11, b=49)]"
"[multiply(a=3, b=12), add(a=11, b=49)]"
]
},
"execution_count": 16,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -287,7 +335,7 @@
"source": [
"from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n",
"\n",
"chain = llm_with_tools | PydanticToolsParser(tools=[Multiply, Add])\n",
"chain = llm_with_tools | PydanticToolsParser(tools=[multiply, add])\n",
"chain.invoke(query)"
]
},
@@ -296,7 +344,7 @@
"id": "0ba3505d-f405-43ba-93c4-7fbd84f6464b",
"metadata": {},
"source": [
"### Streaming\n",
"## Response: Streaming\n",
"\n",
"When tools are called in a streaming context, \n",
"[message chunks](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
@@ -319,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 7,
"id": "4f54a0de-74c7-4f2d-86c5-660aed23840d",
"metadata": {},
"outputs": [
@@ -328,12 +376,12 @@
"output_type": "stream",
"text": [
"[]\n",
"[{'name': 'Multiply', 'args': '', 'id': 'call_d39MsxKM5cmeGJOoYKdGBgzc', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '', 'id': 'call_5Gdgx3R2z97qIycWKixgD2OU', 'index': 0}]\n",
"[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 0}]\n",
"[{'name': None, 'args': ': 3, ', 'id': None, 'index': 0}]\n",
"[{'name': None, 'args': '\"b\": 1', 'id': None, 'index': 0}]\n",
"[{'name': None, 'args': '2}', 'id': None, 'index': 0}]\n",
"[{'name': 'Add', 'args': '', 'id': 'call_QJpdxD9AehKbdXzMHxgDMMhs', 'index': 1}]\n",
"[{'name': 'add', 'args': '', 'id': 'call_DpeKaF8pUCmLP0tkinhdmBgD', 'index': 1}]\n",
"[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 1}]\n",
"[{'name': None, 'args': ': 11,', 'id': None, 'index': 1}]\n",
"[{'name': None, 'args': ' \"b\": ', 'id': None, 'index': 1}]\n",
@@ -359,7 +407,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 8,
"id": "0a944af0-eedd-43c8-8ff3-f4301f129d9b",
"metadata": {},
"outputs": [
@@ -368,17 +416,17 @@
"output_type": "stream",
"text": [
"[]\n",
"[{'name': 'Multiply', 'args': '', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
"[{'name': 'Multiply', 'args': '{\"a\"', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, ', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\"', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11,', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n",
"[{'name': 'Multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_erKtz8z3e681cmxYKbRof0NS', 'index': 0}, {'name': 'Add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_tYHYdEV2YBvzDcSCiFCExNvw', 'index': 1}]\n"
"[{'name': 'multiply', 'args': '', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '{\"a\"', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, ', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\"', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11,', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n",
"[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_hXqj6HxzACkpiPG4hFFuIKuP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_GERgANDUbRqdtmXRbIAS9JTS', 'index': 1}]\n"
]
}
],
@@ -396,7 +444,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 9,
"id": "db4e3e3a-3553-44dc-bd31-149c0981a06a",
"metadata": {},
"outputs": [
@@ -422,7 +470,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 10,
"id": "e9402bde-d4b5-4564-a99e-f88c9b46b28a",
"metadata": {},
"outputs": [
@@ -432,16 +480,16 @@
"text": [
"[]\n",
"[]\n",
"[{'name': 'Multiply', 'args': {}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n",
"[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BXqUtt6jYCwR1DguqpS2ehP0'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_UjSHJKROSAw2BDc8cp9cSv4i'}]\n"
"[{'name': 'multiply', 'args': {}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
"[{'name': 'multiply', 'args': {'a': 3}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n",
"[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_aXQdLhKJpEpUxTNPXIS4l7Mv'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_P39VunIrq9MQOxHgF30VByuB'}]\n"
]
}
],
@@ -459,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 11,
"id": "8c2f21cc-0c6d-416a-871f-e854621c96e2",
"metadata": {},
"outputs": [
@@ -480,14 +528,14 @@
"id": "97a0c977-0c3c-4011-b49b-db98c609d0ce",
"metadata": {},
"source": [
"## Passing tool outputs to model\n",
"## Request: Passing tool outputs to model\n",
"\n",
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s."
]
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 13,
"id": "48049192-be28-42ab-9a44-d897924e67cd",
"metadata": {},
"outputs": [
@@ -495,12 +543,12 @@
"data": {
"text/plain": [
"[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'Multiply'}, 'type': 'function'}, {'id': 'call_qywVrsplg0ZMv7LHYYMjyG81', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 105, 'total_tokens': 155}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-1a0b8cdd-9221-4d94-b2ed-5701f67ce9fe-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_K5DsWEmgt6D08EI9AFu9NaL1'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_qywVrsplg0ZMv7LHYYMjyG81'}]),\n",
" ToolMessage(content='36', tool_call_id='call_K5DsWEmgt6D08EI9AFu9NaL1'),\n",
" ToolMessage(content='60', tool_call_id='call_qywVrsplg0ZMv7LHYYMjyG81')]"
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Jja7J89XsjrOLA5rAjULqTSL', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 144, 'total_tokens': 193}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_a450710239', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9db7e8e1-86d5-4015-9f43-f1d33abea64d-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Jja7J89XsjrOLA5rAjULqTSL'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_K4ArVEUjhl36EcSuxGN1nwvZ'}]),\n",
" ToolMessage(content='36', tool_call_id='call_Jja7J89XsjrOLA5rAjULqTSL'),\n",
" ToolMessage(content='60', tool_call_id='call_K4ArVEUjhl36EcSuxGN1nwvZ')]"
]
},
"execution_count": 117,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -508,29 +556,57 @@
"source": [
"from langchain_core.messages import HumanMessage, ToolMessage\n",
"\n",
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Adds a and b.\n",
"\n",
" Args:\n",
" a: first int\n",
" b: second int\n",
" \"\"\"\n",
" return a + b\n",
"\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiplies a and b.\n",
"\n",
" Args:\n",
" a: first int\n",
" b: second int\n",
" \"\"\"\n",
" return a * b\n",
"\n",
"\n",
"tools = [add, multiply]\n",
"llm_with_tools = llm.bind_tools(tools)\n",
"\n",
"messages = [HumanMessage(query)]\n",
"ai_msg = llm_with_tools.invoke(messages)\n",
"messages.append(ai_msg)\n",
"\n",
"for tool_call in ai_msg.tool_calls:\n",
" selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n",
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
"\n",
"messages"
]
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 14,
"id": "611e0f36-d736-48d1-bca1-1cec51d223f3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 171, 'total_tokens': 189}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-a6c8093c-b16a-4c92-8308-7c9ac998118c-0')"
"AIMessage(content='3 * 12 = 36\\n11 + 49 = 60', response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 209, 'total_tokens': 225}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'stop', 'logprobs': None}, id='run-a55f8cb5-6d6d-4835-9c6b-7de36b2590c7-0')"
]
},
"execution_count": 118,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -544,31 +620,36 @@
"id": "a5937498-d6fe-400a-b192-ef35c314168e",
"metadata": {},
"source": [
"## Few-shot prompting\n",
"## Request: Few-shot prompting\n",
"\n",
"For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt.\n",
"For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt. \n",
"\n",
"```{=mdx}\n",
":::note\n",
"For most models it's important that the ToolCall and ToolMessage ids line up, so that each AIMessage with ToolCalls is followed by ToolMessages with corresponding ids.\n",
"```\n",
"\n",
"For example, even with some special instructions our model can get tripped up by order of operations:"
]
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 15,
"id": "5ef2e7c3-0925-49da-ab8f-e42c4fa40f29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'Multiply',\n",
"[{'name': 'multiply',\n",
" 'args': {'a': 119, 'b': 8},\n",
" 'id': 'call_Dl3FXRVkQCFW4sUNYOe4rFr7'},\n",
" {'name': 'Add',\n",
" 'id': 'call_RofMKNQ2qbWAFaMsef4cpTS9'},\n",
" {'name': 'add',\n",
" 'args': {'a': 952, 'b': -20},\n",
" 'id': 'call_n03l4hmka7VZTCiP387Wud2C'}]"
" 'id': 'call_HjOfoF8ceMCHmO3cpwG6oB3X'}]"
]
},
"execution_count": 112,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -591,19 +672,19 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 16,
"id": "7b2e8b19-270f-4e1a-8be7-7aad704c1cf4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'Multiply',\n",
"[{'name': 'multiply',\n",
" 'args': {'a': 119, 'b': 8},\n",
" 'id': 'call_MoSgwzIhPxhclfygkYaKIsGZ'}]"
" 'id': 'call_tWwpzWqqc8dQtN13CyKZCVMe'}]"
]
},
"execution_count": 107,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -621,14 +702,14 @@
" \"\",\n",
" name=\"example_assistant\",\n",
" tool_calls=[\n",
" {\"name\": \"Multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n",
" {\"name\": \"multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n",
" ],\n",
" ),\n",
" ToolMessage(\"16505054784\", tool_call_id=\"1\"),\n",
" AIMessage(\n",
" \"\",\n",
" name=\"example_assistant\",\n",
" tool_calls=[{\"name\": \"Add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n",
" tool_calls=[{\"name\": \"add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n",
" ),\n",
" ToolMessage(\"16505054788\", tool_call_id=\"2\"),\n",
" AIMessage(\n",

View File

@@ -763,7 +763,9 @@ class BaseChatOpenAI(BaseChatModel):
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
@@ -776,40 +778,55 @@ class BaseChatOpenAI(BaseChatModel):
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any), or a dict of the form:
Options are:
name of the tool (str): calls corresponding tool;
"auto": automatically selects a tool (including no tool);
"none": does not call a tool;
"any" or "required": force at least one tool to be called;
True: forces tool call (requires `tools` be length 1);
False: no effect;
or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if tool_choice:
if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none"):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "any", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
# 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'.
if tool_choice == "any":
tool_choice = "required"
elif isinstance(tool_choice, bool):
if len(tools) > 1:
raise ValueError(
"tool_choice=True can only be used when a single tool is "
f"passed in, received {len(tools)} tools."
)
tool_choice = {
"type": "function",
"function": {"name": formatted_tools[0]["function"]["name"]},
}
elif isinstance(tool_choice, dict):
if (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
f"provided tools were {tool_names}."
)
else:
raise ValueError(

View File

@@ -385,7 +385,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.46"
version = "0.1.49"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -540,13 +540,13 @@ files = [
[[package]]
name = "openai"
version = "1.16.2"
version = "1.25.1"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
{file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"},
{file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"},
{file = "openai-1.25.1-py3-none-any.whl", hash = "sha256:aa2f381f476f5fa4df8728a34a3e454c321caa064b7b68ab6e9daa1ed082dbf9"},
{file = "openai-1.25.1.tar.gz", hash = "sha256:f561ce86f4b4008eb6c78622d641e4b7e1ab8a8cdb15d2f0b2a49942d40d21a8"},
]
[package.dependencies]
@@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "1d9cefc90178d94dee2a09afc14af160a7e35e4972ad4701d3bbbfdde14a81fa"
content-hash = "2dbfc54f73eec285047a224d9dcddd5d16d24c693f550b792d399826497bbbf8"

View File

@@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.46"
openai = "^1.10.0"
openai = "^1.24.0"
tiktoken = ">=0.5.2,<1"
[tool.poetry.group.test]

View File

@@ -479,6 +479,15 @@ class GenerateUsername(BaseModel):
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
def test_tool_use() -> None:
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
@@ -563,6 +572,21 @@ def test_manual_tool_call_msg() -> None:
llm_with_tool.invoke(msgs)
def test_bind_tools_tool_choice() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
for tool_choice in ("any", "required"):
llm_with_tools = llm.bind_tools(
tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice
)
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
assert msg.tool_calls
llm_with_tools = llm.bind_tools(tools=[GenerateUsername, MakeASandwich])
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
assert not msg.tool_calls
def test_openai_structured_output() -> None:
class MyModel(BaseModel):
"""A Person"""

View File

@@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper."""
import json
from typing import Any, List
from typing import Any, List, Type, Union
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -14,6 +14,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
@@ -321,3 +322,45 @@ def test_format_message_content() -> None:
},
]
assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
name: str
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
@pytest.mark.parametrize(
"tool_choice",
[
"any",
"none",
"auto",
"required",
"GenerateUsername",
{"type": "function", "function": {"name": "MakeASandwich"}},
False,
None,
],
)
def test_bind_tools_tool_choice(tool_choice: Any) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.with_structured_output(schema)