openai[patch]: support tool_choice="required" (#21216)

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Bagatur 2024-05-02 18:33:25 -04:00 committed by GitHub
parent aa9faa8512
commit 6ac6158a07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 106 additions and 22 deletions

View File

@ -763,7 +763,9 @@ class BaseChatOpenAI(BaseChatModel):
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*, *,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model. """Bind tool-like objects to this chat model.
@ -776,40 +778,55 @@ class BaseChatOpenAI(BaseChatModel):
models, callables, and BaseTools will be automatically converted to models, callables, and BaseTools will be automatically converted to
their schema dictionary representation. their schema dictionary representation.
tool_choice: Which tool to require the model to call. tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or Options are:
"auto" to automatically determine which function to call name of the tool (str): calls corresponding tool;
(if any), or a dict of the form: "auto": automatically selects a tool (including no tool);
"none": does not call a tool;
"any" or "required": force at least one tool to be called;
True: forces tool call (requires `tools` be length 1);
False: no effect;
or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}. {"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the **kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor. :class:`~langchain.runnable.Runnable` constructor.
""" """
formatted_tools = [convert_to_openai_tool(tool) for tool in tools] formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice: if tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, str): if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none"): # tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "any", "required"):
tool_choice = { tool_choice = {
"type": "function", "type": "function",
"function": {"name": tool_choice}, "function": {"name": tool_choice},
} }
# 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'.
if tool_choice == "any":
tool_choice = "required"
elif isinstance(tool_choice, bool): elif isinstance(tool_choice, bool):
if len(tools) > 1:
raise ValueError(
"tool_choice=True can only be used when a single tool is "
f"passed in, received {len(tools)} tools."
)
tool_choice = { tool_choice = {
"type": "function", "type": "function",
"function": {"name": formatted_tools[0]["function"]["name"]}, "function": {"name": formatted_tools[0]["function"]["name"]},
} }
elif isinstance(tool_choice, dict): elif isinstance(tool_choice, dict):
if ( tool_names = [
formatted_tools[0]["function"]["name"] formatted_tool["function"]["name"]
!= tool_choice["function"]["name"] for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
): ):
raise ValueError( raise ValueError(
f"Tool choice {tool_choice} was specified, but the only " f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}." f"provided tools were {tool_names}."
) )
else: else:
raise ValueError( raise ValueError(

View File

@ -385,7 +385,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.1.46" version = "0.1.49"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -540,13 +540,13 @@ files = [
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.16.2" version = "1.25.1"
description = "The official Python library for the openai API" description = "The official Python library for the openai API"
optional = false optional = false
python-versions = ">=3.7.1" python-versions = ">=3.7.1"
files = [ files = [
{file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"}, {file = "openai-1.25.1-py3-none-any.whl", hash = "sha256:aa2f381f476f5fa4df8728a34a3e454c321caa064b7b68ab6e9daa1ed082dbf9"},
{file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"}, {file = "openai-1.25.1.tar.gz", hash = "sha256:f561ce86f4b4008eb6c78622d641e4b7e1ab8a8cdb15d2f0b2a49942d40d21a8"},
] ]
[package.dependencies] [package.dependencies]
@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "1d9cefc90178d94dee2a09afc14af160a7e35e4972ad4701d3bbbfdde14a81fa" content-hash = "2dbfc54f73eec285047a224d9dcddd5d16d24c693f550b792d399826497bbbf8"

View File

@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
langchain-core = "^0.1.46" langchain-core = "^0.1.46"
openai = "^1.10.0" openai = "^1.24.0"
tiktoken = ">=0.5.2,<1" tiktoken = ">=0.5.2,<1"
[tool.poetry.group.test] [tool.poetry.group.test]

View File

@ -479,6 +479,15 @@ class GenerateUsername(BaseModel):
hair_color: str hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
def test_tool_use() -> None: def test_tool_use() -> None:
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True) llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
@ -563,6 +572,21 @@ def test_manual_tool_call_msg() -> None:
llm_with_tool.invoke(msgs) llm_with_tool.invoke(msgs)
def test_bind_tools_tool_choice() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
for tool_choice in ("any", "required"):
llm_with_tools = llm.bind_tools(
tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice
)
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
assert msg.tool_calls
llm_with_tools = llm.bind_tools(tools=[GenerateUsername, MakeASandwich])
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
assert not msg.tool_calls
def test_openai_structured_output() -> None: def test_openai_structured_output() -> None:
class MyModel(BaseModel): class MyModel(BaseModel):
"""A Person""" """A Person"""

View File

@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper.""" """Test OpenAI Chat API wrapper."""
import json import json
from typing import Any, List from typing import Any, List, Type, Union
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -14,6 +14,7 @@ from langchain_core.messages import (
ToolCall, ToolCall,
ToolMessage, ToolMessage,
) )
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import ( from langchain_openai.chat_models.base import (
@ -321,3 +322,45 @@ def test_format_message_content() -> None:
}, },
] ]
assert [{"type": "text", "text": "hello"}] == _format_message_content(content) assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
name: str
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
@pytest.mark.parametrize(
"tool_choice",
[
"any",
"none",
"auto",
"required",
"GenerateUsername",
{"type": "function", "function": {"name": "MakeASandwich"}},
False,
None,
],
)
def test_bind_tools_tool_choice(tool_choice: Any) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.with_structured_output(schema)