diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 80f01a5d6df..5a21815ec26 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -763,7 +763,9 @@ class BaseChatOpenAI(BaseChatModel): self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, - tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -776,40 +778,55 @@ class BaseChatOpenAI(BaseChatModel): models, callables, and BaseTools will be automatically converted to their schema dictionary representation. tool_choice: Which tool to require the model to call. - Must be the name of the single provided function or - "auto" to automatically determine which function to call - (if any), or a dict of the form: + Options are: + name of the tool (str): calls corresponding tool; + "auto": automatically selects a tool (including no tool); + "none": does not call a tool; + "any" or "required": force at least one tool to be called; + True: forces tool call (requires `tools` be length 1); + False: no effect; + + or a dict of the form: {"type": "function", "function": {"name": <>}}. **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - if tool_choice is not None and tool_choice: - if len(formatted_tools) != 1: - raise ValueError( - "When specifying `tool_choice`, you must provide exactly one " - f"tool. Received {len(formatted_tools)} tools." - ) + if tool_choice: if isinstance(tool_choice, str): - if tool_choice not in ("auto", "none"): + # tool_choice is a tool/function name + if tool_choice not in ("auto", "none", "any", "required"): tool_choice = { "type": "function", "function": {"name": tool_choice}, } + # 'any' is not natively supported by OpenAI API. + # We support 'any' since other models use this instead of 'required'. + if tool_choice == "any": + tool_choice = "required" elif isinstance(tool_choice, bool): + if len(tools) > 1: + raise ValueError( + "tool_choice=True can only be used when a single tool is " + f"passed in, received {len(tools)} tools." + ) tool_choice = { "type": "function", "function": {"name": formatted_tools[0]["function"]["name"]}, } elif isinstance(tool_choice, dict): - if ( - formatted_tools[0]["function"]["name"] - != tool_choice["function"]["name"] + tool_names = [ + formatted_tool["function"]["name"] + for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] + for tool_name in tool_names ): raise ValueError( f"Tool choice {tool_choice} was specified, but the only " - f"provided tool was {formatted_tools[0]['function']['name']}." + f"provided tools were {tool_names}." ) else: raise ValueError( diff --git a/libs/partners/openai/poetry.lock b/libs/partners/openai/poetry.lock index 19d79b0c7e2..495bec291fd 100644 --- a/libs/partners/openai/poetry.lock +++ b/libs/partners/openai/poetry.lock @@ -385,7 +385,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.46" +version = "0.1.49" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -540,13 +540,13 @@ files = [ [[package]] name = "openai" -version = "1.16.2" +version = "1.25.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"}, - {file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"}, + {file = "openai-1.25.1-py3-none-any.whl", hash = "sha256:aa2f381f476f5fa4df8728a34a3e454c321caa064b7b68ab6e9daa1ed082dbf9"}, + {file = "openai-1.25.1.tar.gz", hash = "sha256:f561ce86f4b4008eb6c78622d641e4b7e1ab8a8cdb15d2f0b2a49942d40d21a8"}, ] [package.dependencies] @@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "1d9cefc90178d94dee2a09afc14af160a7e35e4972ad4701d3bbbfdde14a81fa" +content-hash = "2dbfc54f73eec285047a224d9dcddd5d16d24c693f550b792d399826497bbbf8" diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml index 9d48f9e5aff..387cf5ea364 100644 --- a/libs/partners/openai/pyproject.toml +++ b/libs/partners/openai/pyproject.toml @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.46" -openai = "^1.10.0" +openai = "^1.24.0" tiktoken = ">=0.5.2,<1" [tool.poetry.group.test] diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index ace749d5e49..e86c457932c 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -479,6 +479,15 @@ class GenerateUsername(BaseModel): hair_color: str +class MakeASandwich(BaseModel): + "Make a sandwich given a list of ingredients." + + bread_type: str + cheese_type: str + condiments: List[str] + vegetables: List[str] + + def test_tool_use() -> None: llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True) @@ -563,6 +572,21 @@ def test_manual_tool_call_msg() -> None: llm_with_tool.invoke(msgs) +def test_bind_tools_tool_choice() -> None: + """Test passing in manually construct tool call message.""" + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + for tool_choice in ("any", "required"): + llm_with_tools = llm.bind_tools( + tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice + ) + msg = cast(AIMessage, llm_with_tools.invoke("how are you")) + assert msg.tool_calls + + llm_with_tools = llm.bind_tools(tools=[GenerateUsername, MakeASandwich]) + msg = cast(AIMessage, llm_with_tools.invoke("how are you")) + assert not msg.tool_calls + + def test_openai_structured_output() -> None: class MyModel(BaseModel): """A Person""" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 1bbaac4d6e3..2d7c4d8efc4 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1,7 +1,7 @@ """Test OpenAI Chat API wrapper.""" import json -from typing import Any, List +from typing import Any, List, Type, Union from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -14,6 +14,7 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) +from langchain_core.pydantic_v1 import BaseModel from langchain_openai import ChatOpenAI from langchain_openai.chat_models.base import ( @@ -321,3 +322,45 @@ def test_format_message_content() -> None: }, ] assert [{"type": "text", "text": "hello"}] == _format_message_content(content) + + +class GenerateUsername(BaseModel): + "Get a username based on someone's name and hair color." + + name: str + hair_color: str + + +class MakeASandwich(BaseModel): + "Make a sandwich given a list of ingredients." + + bread_type: str + cheese_type: str + condiments: List[str] + vegetables: List[str] + + +@pytest.mark.parametrize( + "tool_choice", + [ + "any", + "none", + "auto", + "required", + "GenerateUsername", + {"type": "function", "function": {"name": "MakeASandwich"}}, + False, + None, + ], +) +def test_bind_tools_tool_choice(tool_choice: Any) -> None: + """Test passing in manually construct tool call message.""" + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice) + + +@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()]) +def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None: + """Test passing in manually construct tool call message.""" + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm.with_structured_output(schema)