mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-13 11:58:23 +00:00
core[patch], community[patch], openai[patch]: consolidate openai tool… (#16485)
… converters One way to convert anything to an OAI function: convert_to_openai_function One way to convert anything to an OAI tool: convert_to_openai_tool Corresponding bind functions on OAI models: bind_functions, bind_tools
This commit is contained in:
@@ -1,38 +1,6 @@
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utils.openai_functions import (
|
||||
FunctionDescription,
|
||||
ToolDescription,
|
||||
convert_pydantic_to_openai_function,
|
||||
from langchain_core.utils.function_calling import (
|
||||
format_tool_to_openai_function,
|
||||
format_tool_to_openai_tool,
|
||||
)
|
||||
|
||||
|
||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
if tool.args_schema:
|
||||
return convert_pydantic_to_openai_function(
|
||||
tool.args_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
# This is a hack to get around the fact that some tools
|
||||
# do not expose an args_schema, and expect an argument
|
||||
# which is a string.
|
||||
# And Open AI does not support an array type for the
|
||||
# parameters.
|
||||
"properties": {
|
||||
"__arg1": {"title": "__arg1", "type": "string"},
|
||||
},
|
||||
"required": ["__arg1"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
function = format_tool_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
__all__ = ["format_tool_to_openai_function", "format_tool_to_openai_tool"]
|
||||
|
||||
@@ -1,44 +1,6 @@
|
||||
"""Different methods for rendering Tools to be passed to LLMs.
|
||||
|
||||
Depending on the LLM you are using and the prompting strategy you are using,
|
||||
you may want Tools to be rendered in a different way.
|
||||
This module contains various ways to render tools.
|
||||
"""
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utils.openai_functions import (
|
||||
FunctionDescription,
|
||||
ToolDescription,
|
||||
convert_pydantic_to_openai_function,
|
||||
from langchain_core.utils.function_calling import (
|
||||
format_tool_to_openai_function,
|
||||
format_tool_to_openai_tool,
|
||||
)
|
||||
|
||||
|
||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
if tool.args_schema:
|
||||
return convert_pydantic_to_openai_function(
|
||||
tool.args_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
# This is a hack to get around the fact that some tools
|
||||
# do not expose an args_schema, and expect an argument
|
||||
# which is a string.
|
||||
# And Open AI does not support an array type for the
|
||||
# parameters.
|
||||
"properties": {
|
||||
"__arg1": {"title": "__arg1", "type": "string"},
|
||||
},
|
||||
"required": ["__arg1"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
function = format_tool_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
__all__ = ["format_tool_to_openai_function", "format_tool_to_openai_tool"]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Methods for creating function specs in the style of OpenAI Functions"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
@@ -16,12 +18,16 @@ from typing import (
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
PYTHON_TO_JSON_TYPES = {
|
||||
"str": "string",
|
||||
"int": "number",
|
||||
"int": "integer",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
}
|
||||
@@ -45,22 +51,47 @@ class ToolDescription(TypedDict):
|
||||
function: FunctionDescription
|
||||
|
||||
|
||||
def _rm_titles(kv: dict) -> dict:
|
||||
new_kv = {}
|
||||
for k, v in kv.items():
|
||||
if k == "title":
|
||||
continue
|
||||
elif isinstance(v, dict):
|
||||
new_kv[k] = _rm_titles(v)
|
||||
else:
|
||||
new_kv[k] = v
|
||||
return new_kv
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.2.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_function(
|
||||
model: Type[BaseModel],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
rm_titles: bool = True,
|
||||
) -> FunctionDescription:
|
||||
"""Converts a Pydantic model to a function description for the OpenAI API."""
|
||||
schema = dereference_refs(model.schema())
|
||||
schema.pop("definitions", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
return {
|
||||
"name": name or schema["title"],
|
||||
"description": description or schema["description"],
|
||||
"parameters": schema,
|
||||
"name": name or title,
|
||||
"description": description or default_description,
|
||||
"parameters": _rm_titles(schema) if rm_titles else schema,
|
||||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.2.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_tool(
|
||||
model: Type[BaseModel],
|
||||
*,
|
||||
@@ -132,8 +163,19 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
|
||||
# Mypy error:
|
||||
# "type" has no attribute "schema"
|
||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
||||
elif arg_type.__name__ in PYTHON_TO_JSON_TYPES:
|
||||
elif (
|
||||
hasattr(arg_type, "__name__")
|
||||
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
|
||||
):
|
||||
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
|
||||
elif (
|
||||
hasattr(arg_type, "__dict__")
|
||||
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
|
||||
):
|
||||
properties[arg] = {
|
||||
"enum": list(arg_type.__args__), # type: ignore
|
||||
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], # type: ignore
|
||||
}
|
||||
if arg in arg_descriptions:
|
||||
if arg not in properties:
|
||||
properties[arg] = {}
|
||||
@@ -153,6 +195,11 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
|
||||
return required
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.2.0",
|
||||
)
|
||||
def convert_python_function_to_openai_function(
|
||||
function: Callable,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -174,8 +221,49 @@ def convert_python_function_to_openai_function(
|
||||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.2.0",
|
||||
)
|
||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
if tool.args_schema:
|
||||
return convert_pydantic_to_openai_function(
|
||||
tool.args_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
# This is a hack to get around the fact that some tools
|
||||
# do not expose an args_schema, and expect an argument
|
||||
# which is a string.
|
||||
# And Open AI does not support an array type for the
|
||||
# parameters.
|
||||
"properties": {
|
||||
"__arg1": {"title": "__arg1", "type": "string"},
|
||||
},
|
||||
"required": ["__arg1"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.2.0",
|
||||
)
|
||||
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
function = format_tool_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: Union[Dict[str, Any], Type[BaseModel], Callable],
|
||||
function: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI function.
|
||||
|
||||
@@ -188,15 +276,38 @@ def convert_to_openai_function(
|
||||
A dict version of the passed in function which is compatible with the
|
||||
OpenAI function-calling API.
|
||||
"""
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
if isinstance(function, dict):
|
||||
return function
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
elif isinstance(function, BaseTool):
|
||||
return format_tool_to_openai_function(function)
|
||||
elif callable(function):
|
||||
return convert_python_function_to_openai_function(function)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported function type {type(function)}. Functions must be passed in"
|
||||
f" as Dict, pydantic.BaseModel, or Callable."
|
||||
)
|
||||
|
||||
|
||||
def convert_to_openai_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI tool.
|
||||
|
||||
Args:
|
||||
tool: Either a dictionary, a pydantic.BaseModel class, Python function, or
|
||||
BaseTool. If a dictionary is passed in, it is assumed to already be a valid
|
||||
OpenAI tool or OpenAI function.
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in tool which is compatible with the
|
||||
OpenAI tool-calling API.
|
||||
"""
|
||||
if isinstance(tool, dict) and "type" in tool:
|
||||
return tool
|
||||
function = convert_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
|
||||
74
libs/core/tests/unit_tests/utils/test_function_calling.py
Normal file
74
libs/core/tests/unit_tests/utils/test_function_calling.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Callable, Literal, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pydantic() -> Type[BaseModel]:
|
||||
class dummy_function(BaseModel):
|
||||
"""dummy function"""
|
||||
|
||||
arg1: int = Field(..., description="foo")
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def function() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""dummy function
|
||||
|
||||
Args:
|
||||
arg1: foo
|
||||
arg2: one of 'bar', 'baz'
|
||||
"""
|
||||
pass
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tool() -> BaseTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
|
||||
class DummyFunction(BaseTool):
|
||||
args_schema: Type[BaseModel] = Schema
|
||||
name: str = "dummy_function"
|
||||
description: str = "dummy function"
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
return DummyFunction()
|
||||
|
||||
|
||||
def test_convert_to_openai_function(
|
||||
pydantic: Type[BaseModel], function: Callable, tool: BaseTool
|
||||
) -> None:
|
||||
expected = {
|
||||
"name": "dummy_function",
|
||||
"description": "dummy function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"description": "foo", "type": "integer"},
|
||||
"arg2": {
|
||||
"description": "one of 'bar', 'baz'",
|
||||
"enum": ["bar", "baz"],
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
},
|
||||
}
|
||||
|
||||
for fn in (pydantic, function, tool, expected):
|
||||
actual = convert_to_openai_function(fn) # type: ignore
|
||||
assert actual == expected
|
||||
@@ -5,13 +5,13 @@ from json import JSONDecodeError
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import openai
|
||||
@@ -180,16 +180,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
OpenAIAssistantRunnable configured to run using the created assistant.
|
||||
"""
|
||||
client = client or _get_openai_client()
|
||||
openai_tools: List = []
|
||||
for tool in tools:
|
||||
oai_tool = (
|
||||
tool if isinstance(tool, dict) else format_tool_to_openai_tool(tool)
|
||||
)
|
||||
openai_tools.append(oai_tool)
|
||||
assistant = client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=openai_tools,
|
||||
tools=[convert_to_openai_tool(tool) for tool in tools],
|
||||
model=model,
|
||||
)
|
||||
return cls(assistant_id=assistant.id, client=client, **kwargs)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from langchain_community.tools.convert_to_openai import format_tool_to_openai_function
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackManager, Callbacks
|
||||
@@ -20,6 +19,7 @@ from langchain_core.prompts.chat import (
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent
|
||||
from langchain.agents.format_scratchpad.openai_functions import (
|
||||
@@ -71,7 +71,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
|
||||
@property
|
||||
def functions(self) -> List[dict]:
|
||||
return [dict(format_tool_to_openai_function(t)) for t in self.tools]
|
||||
return [dict(convert_to_openai_function(t)) for t in self.tools]
|
||||
|
||||
def plan(
|
||||
self,
|
||||
@@ -303,9 +303,7 @@ def create_openai_functions_agent(
|
||||
"Prompt must have input variable `agent_scratchpad`, but wasn't found. "
|
||||
f"Found {prompt.input_variables} instead."
|
||||
)
|
||||
llm_with_tools = llm.bind(
|
||||
functions=[format_tool_to_openai_function(t) for t in tools]
|
||||
)
|
||||
llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_openai_function_messages(
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain.agents.format_scratchpad.openai_tools import (
|
||||
format_to_openai_tool_messages,
|
||||
@@ -82,9 +82,7 @@ def create_openai_tools_agent(
|
||||
if missing_vars:
|
||||
raise ValueError(f"Prompt missing required variables: {missing_vars}")
|
||||
|
||||
llm_with_tools = llm.bind(
|
||||
tools=[format_tool_to_openai_tool(tool) for tool in tools]
|
||||
)
|
||||
llm_with_tools = llm.bind(tools=[convert_to_openai_tool(tool) for tool in tools])
|
||||
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain_community.tools.convert_to_openai import format_tool_to_openai_function
|
||||
from langchain_core.utils.function_calling import format_tool_to_openai_function
|
||||
|
||||
# For backwards compatibility
|
||||
__all__ = ["format_tool_to_openai_function"]
|
||||
|
||||
@@ -7,11 +7,11 @@ This module contains various ways to render tools.
|
||||
from typing import List
|
||||
|
||||
# For backwards compatibility
|
||||
from langchain_community.tools.convert_to_openai import (
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import (
|
||||
format_tool_to_openai_function,
|
||||
format_tool_to_openai_tool,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
__all__ = [
|
||||
"render_text_description",
|
||||
|
||||
@@ -15,13 +15,10 @@ def test_convert_pydantic_to_openai_function() -> None:
|
||||
"name": "Data",
|
||||
"description": "The data to return.",
|
||||
"parameters": {
|
||||
"title": "Data",
|
||||
"description": "The data to return.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {"title": "Key", "description": "API key", "type": "string"},
|
||||
"key": {"description": "API key", "type": "string"},
|
||||
"days": {
|
||||
"title": "Days",
|
||||
"description": "Number of days to forecast",
|
||||
"default": 0,
|
||||
"type": "integer",
|
||||
@@ -50,22 +47,17 @@ def test_convert_pydantic_to_openai_function_nested() -> None:
|
||||
"name": "Model",
|
||||
"description": "The model to return.",
|
||||
"parameters": {
|
||||
"title": "Model",
|
||||
"description": "The model to return.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"title": "Data",
|
||||
"description": "The data to return.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"title": "Key",
|
||||
"description": "API key",
|
||||
"type": "string",
|
||||
},
|
||||
"days": {
|
||||
"title": "Days",
|
||||
"description": "Number of days to forecast",
|
||||
"default": 0,
|
||||
"type": "integer",
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
@@ -52,11 +53,15 @@ from langchain_core.messages import (
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -626,12 +631,18 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind functions (and other objects) to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI function-calling API.
|
||||
|
||||
NOTE: Using bind_tools is recommended instead, as the `functions` and
|
||||
`function_call` request parameters are officially marked as deprecated by
|
||||
OpenAI.
|
||||
|
||||
Args:
|
||||
functions: A list of function definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, or callable. Pydantic
|
||||
@@ -663,3 +674,51 @@ class ChatOpenAI(BaseChatModel):
|
||||
functions=formatted_functions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "none"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
"auto" to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None:
|
||||
if isinstance(tool_choice, str) and tool_choice not in ("auto", "none"):
|
||||
tool_choice = {"type": "function", "function": {"name": tool_choice}}
|
||||
if isinstance(tool_choice, dict) and len(formatted_tools) != 1:
|
||||
raise ValueError(
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
f"tool. Received {len(formatted_tools)} tools."
|
||||
)
|
||||
if (
|
||||
isinstance(tool_choice, dict)
|
||||
and formatted_tools[0]["function"]["name"]
|
||||
!= tool_choice["function"]["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||
)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(
|
||||
tools=formatted_tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user