openai[patch]: support tool_choice="required" (#21216)

Co-authored-by: ccurme <chester.curme@gmail.com>
Bagatur authored on 2024-05-02 18:33:25 -04:00; committed by GitHub
parent aa9faa8512
commit 6ac6158a07
5 changed files with 106 additions and 22 deletions

View File

@@ -763,7 +763,9 @@ class BaseChatOpenAI(BaseChatModel):
         self,
         tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
         *,
-        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         """Bind tool-like objects to this chat model.
@@ -776,40 +778,55 @@ class BaseChatOpenAI(BaseChatModel):
                 models, callables, and BaseTools will be automatically converted to
                 their schema dictionary representation.
             tool_choice: Which tool to require the model to call.
-                Must be the name of the single provided function or
-                "auto" to automatically determine which function to call
-                (if any), or a dict of the form:
+                Options are:
+                name of the tool (str): calls corresponding tool;
+                "auto": automatically selects a tool (including no tool);
+                "none": does not call a tool;
+                "any" or "required": force at least one tool to be called;
+                True: forces tool call (requires `tools` be length 1);
+                False: no effect;
+                or a dict of the form:
                 {"type": "function", "function": {"name": <<tool_name>>}}.
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
-        if tool_choice is not None and tool_choice:
-            if len(formatted_tools) != 1:
-                raise ValueError(
-                    "When specifying `tool_choice`, you must provide exactly one "
-                    f"tool. Received {len(formatted_tools)} tools."
-                )
+        if tool_choice:
             if isinstance(tool_choice, str):
-                if tool_choice not in ("auto", "none"):
+                # tool_choice is a tool/function name
+                if tool_choice not in ("auto", "none", "any", "required"):
                     tool_choice = {
                         "type": "function",
                         "function": {"name": tool_choice},
                     }
+                # 'any' is not natively supported by OpenAI API.
+                # We support 'any' since other models use this instead of 'required'.
+                if tool_choice == "any":
+                    tool_choice = "required"
             elif isinstance(tool_choice, bool):
+                if len(tools) > 1:
+                    raise ValueError(
+                        "tool_choice=True can only be used when a single tool is "
+                        f"passed in, received {len(tools)} tools."
+                    )
                 tool_choice = {
                     "type": "function",
                     "function": {"name": formatted_tools[0]["function"]["name"]},
                 }
             elif isinstance(tool_choice, dict):
-                if (
-                    formatted_tools[0]["function"]["name"]
-                    != tool_choice["function"]["name"]
+                tool_names = [
+                    formatted_tool["function"]["name"]
+                    for formatted_tool in formatted_tools
+                ]
+                if not any(
+                    tool_name == tool_choice["function"]["name"]
+                    for tool_name in tool_names
                 ):
                     raise ValueError(
                         f"Tool choice {tool_choice} was specified, but the only "
-                        f"provided tool was {formatted_tools[0]['function']['name']}."
+                        f"provided tools were {tool_names}."
                     )
             else:
                 raise ValueError(
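
For reference, a minimal usage sketch of the new tool_choice options against this bind_tools signature. It assumes an OPENAI_API_KEY in the environment; the model name and the GetWeather/GetTime schemas are illustrative and not part of this change:

# Sketch only: requires langchain-openai with this change and a valid OPENAI_API_KEY.
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI


class GetWeather(BaseModel):
    "Get the current weather for a city."

    city: str


class GetTime(BaseModel):
    "Get the current local time for a city."

    city: str


llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# "required" (or its cross-provider alias "any") forces the model to call at
# least one of the bound tools, even for prompts that need no tool.
llm_required = llm.bind_tools([GetWeather, GetTime], tool_choice="required")
msg = llm_required.invoke("how are you?")
assert msg.tool_calls

# Without tool_choice the model is free to answer in plain text instead.
llm_optional = llm.bind_tools([GetWeather, GetTime])
print(llm_optional.invoke("how are you?").tool_calls)  # typically []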

View File

@@ -385,7 +385,7 @@ files = [
 [[package]]
 name = "langchain-core"
-version = "0.1.46"
+version = "0.1.49"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -540,13 +540,13 @@ files = [
 [[package]]
 name = "openai"
-version = "1.16.2"
+version = "1.25.1"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"},
-    {file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"},
+    {file = "openai-1.25.1-py3-none-any.whl", hash = "sha256:aa2f381f476f5fa4df8728a34a3e454c321caa064b7b68ab6e9daa1ed082dbf9"},
+    {file = "openai-1.25.1.tar.gz", hash = "sha256:f561ce86f4b4008eb6c78622d641e4b7e1ab8a8cdb15d2f0b2a49942d40d21a8"},
 ]

 [package.dependencies]
@@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "1d9cefc90178d94dee2a09afc14af160a7e35e4972ad4701d3bbbfdde14a81fa"
+content-hash = "2dbfc54f73eec285047a224d9dcddd5d16d24c693f550b792d399826497bbbf8"

View File

@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.46"
-openai = "^1.10.0"
+openai = "^1.24.0"
 tiktoken = ">=0.5.2,<1"

 [tool.poetry.group.test]
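
The openai floor moves to ^1.24.0, presumably so the SDK recognizes the new tool_choice="required" value in the Chat Completions API. A sketch of the raw request the bound kwargs are forwarded to, assuming openai>=1.24 and an OPENAI_API_KEY; the tool schema and model name are illustrative:

# Sketch: the underlying Chat Completions request that tool_choice="required" maps to.
import openai

client = openai.OpenAI()
response = client.chat.completions.create(
    model="gpt-3.5-turbo-0125",
    messages=[{"role": "user", "content": "how are you?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "GetWeather",
                "description": "Get the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="required",  # newly supported value; older SDK releases predate it
)
print(response.choices[0].message.tool_calls)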

View File

@@ -479,6 +479,15 @@ class GenerateUsername(BaseModel):
     hair_color: str


+class MakeASandwich(BaseModel):
+    "Make a sandwich given a list of ingredients."
+
+    bread_type: str
+    cheese_type: str
+    condiments: List[str]
+    vegetables: List[str]
+
+
 def test_tool_use() -> None:
     llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
     llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
@@ -563,6 +572,21 @@ def test_manual_tool_call_msg() -> None:
     llm_with_tool.invoke(msgs)


+def test_bind_tools_tool_choice() -> None:
+    """Test that tool_choice='any' / 'required' forces a tool call."""
+    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+    for tool_choice in ("any", "required"):
+        llm_with_tools = llm.bind_tools(
+            tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice
+        )
+        msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
+        assert msg.tool_calls
+
+    llm_with_tools = llm.bind_tools(tools=[GenerateUsername, MakeASandwich])
+    msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
+    assert not msg.tool_calls
+
+
 def test_openai_structured_output() -> None:
     class MyModel(BaseModel):
         """A Person"""

View File

@@ -1,7 +1,7 @@
 """Test OpenAI Chat API wrapper."""

 import json
-from typing import Any, List
+from typing import Any, List, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
@@ -14,6 +14,7 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
+from langchain_core.pydantic_v1 import BaseModel

 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@ -321,3 +322,45 @@ def test_format_message_content() -> None:
},
]
assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
name: str
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
@pytest.mark.parametrize(
"tool_choice",
[
"any",
"none",
"auto",
"required",
"GenerateUsername",
{"type": "function", "function": {"name": "MakeASandwich"}},
False,
None,
],
)
def test_bind_tools_tool_choice(tool_choice: Any) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.with_structured_output(schema)
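
These unit tests only check that bind_tools accepts each value; the normalization itself can be observed offline by inspecting the kwargs on the returned binding. This is a sketch that assumes the result is a langchain-core RunnableBinding exposing .kwargs (as it currently is) and reuses the pydantic schemas defined above; the placeholder api_key means no request is made:

# Sketch: inspect the request kwargs attached by bind_tools, without calling the API.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, api_key="placeholder")

bound = llm.bind_tools([GenerateUsername, MakeASandwich], tool_choice="any")
# "any" is normalized to OpenAI's native "required".
assert bound.kwargs["tool_choice"] == "required"

bound = llm.bind_tools([GenerateUsername], tool_choice=True)
# True is converted to a dict forcing the single bound tool.
assert bound.kwargs["tool_choice"]["function"]["name"] == "GenerateUsername"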