mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 03:38:06 +00:00
openai[patch]: support tool_choice="required" (#21216)
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
aa9faa8512
commit
6ac6158a07
@ -763,7 +763,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@ -776,40 +778,55 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
"auto" to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
Options are:
|
||||
name of the tool (str): calls corresponding tool;
|
||||
"auto": automatically selects a tool (including no tool);
|
||||
"none": does not call a tool;
|
||||
"any" or "required": force at least one tool to be called;
|
||||
True: forces tool call (requires `tools` be length 1);
|
||||
False: no effect;
|
||||
|
||||
or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None and tool_choice:
|
||||
if len(formatted_tools) != 1:
|
||||
raise ValueError(
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
f"tool. Received {len(formatted_tools)} tools."
|
||||
)
|
||||
if tool_choice:
|
||||
if isinstance(tool_choice, str):
|
||||
if tool_choice not in ("auto", "none"):
|
||||
# tool_choice is a tool/function name
|
||||
if tool_choice not in ("auto", "none", "any", "required"):
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
# 'any' is not natively supported by OpenAI API.
|
||||
# We support 'any' since other models use this instead of 'required'.
|
||||
if tool_choice == "any":
|
||||
tool_choice = "required"
|
||||
elif isinstance(tool_choice, bool):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
"tool_choice=True can only be used when a single tool is "
|
||||
f"passed in, received {len(tools)} tools."
|
||||
)
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": formatted_tools[0]["function"]["name"]},
|
||||
}
|
||||
elif isinstance(tool_choice, dict):
|
||||
if (
|
||||
formatted_tools[0]["function"]["name"]
|
||||
!= tool_choice["function"]["name"]
|
||||
tool_names = [
|
||||
formatted_tool["function"]["name"]
|
||||
for formatted_tool in formatted_tools
|
||||
]
|
||||
if not any(
|
||||
tool_name == tool_choice["function"]["name"]
|
||||
for tool_name in tool_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||
f"provided tools were {tool_names}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
10
libs/partners/openai/poetry.lock
generated
10
libs/partners/openai/poetry.lock
generated
@ -385,7 +385,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.46"
|
||||
version = "0.1.49"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -540,13 +540,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.16.2"
|
||||
version = "1.25.1"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"},
|
||||
{file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"},
|
||||
{file = "openai-1.25.1-py3-none-any.whl", hash = "sha256:aa2f381f476f5fa4df8728a34a3e454c321caa064b7b68ab6e9daa1ed082dbf9"},
|
||||
{file = "openai-1.25.1.tar.gz", hash = "sha256:f561ce86f4b4008eb6c78622d641e4b7e1ab8a8cdb15d2f0b2a49942d40d21a8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "1d9cefc90178d94dee2a09afc14af160a7e35e4972ad4701d3bbbfdde14a81fa"
|
||||
content-hash = "2dbfc54f73eec285047a224d9dcddd5d16d24c693f550b792d399826497bbbf8"
|
||||
|
@ -13,7 +13,7 @@ license = "MIT"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.46"
|
||||
openai = "^1.10.0"
|
||||
openai = "^1.24.0"
|
||||
tiktoken = ">=0.5.2,<1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
|
@ -479,6 +479,15 @@ class GenerateUsername(BaseModel):
|
||||
hair_color: str
|
||||
|
||||
|
||||
class MakeASandwich(BaseModel):
|
||||
"Make a sandwich given a list of ingredients."
|
||||
|
||||
bread_type: str
|
||||
cheese_type: str
|
||||
condiments: List[str]
|
||||
vegetables: List[str]
|
||||
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
|
||||
@ -563,6 +572,21 @@ def test_manual_tool_call_msg() -> None:
|
||||
llm_with_tool.invoke(msgs)
|
||||
|
||||
|
||||
def test_bind_tools_tool_choice() -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
for tool_choice in ("any", "required"):
|
||||
llm_with_tools = llm.bind_tools(
|
||||
tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice
|
||||
)
|
||||
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
|
||||
assert msg.tool_calls
|
||||
|
||||
llm_with_tools = llm.bind_tools(tools=[GenerateUsername, MakeASandwich])
|
||||
msg = cast(AIMessage, llm_with_tools.invoke("how are you"))
|
||||
assert not msg.tool_calls
|
||||
|
||||
|
||||
def test_openai_structured_output() -> None:
|
||||
class MyModel(BaseModel):
|
||||
"""A Person"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Test OpenAI Chat API wrapper."""
|
||||
|
||||
import json
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Type, Union
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -14,6 +14,7 @@ from langchain_core.messages import (
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import (
|
||||
@ -321,3 +322,45 @@ def test_format_message_content() -> None:
|
||||
},
|
||||
]
|
||||
assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
|
||||
|
||||
|
||||
class GenerateUsername(BaseModel):
|
||||
"Get a username based on someone's name and hair color."
|
||||
|
||||
name: str
|
||||
hair_color: str
|
||||
|
||||
|
||||
class MakeASandwich(BaseModel):
|
||||
"Make a sandwich given a list of ingredients."
|
||||
|
||||
bread_type: str
|
||||
cheese_type: str
|
||||
condiments: List[str]
|
||||
vegetables: List[str]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_choice",
|
||||
[
|
||||
"any",
|
||||
"none",
|
||||
"auto",
|
||||
"required",
|
||||
"GenerateUsername",
|
||||
{"type": "function", "function": {"name": "MakeASandwich"}},
|
||||
False,
|
||||
None,
|
||||
],
|
||||
)
|
||||
def test_bind_tools_tool_choice(tool_choice: Any) -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
|
||||
def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm.with_structured_output(schema)
|
||||
|
Loading…
Reference in New Issue
Block a user