mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 10:13:29 +00:00
openai[patch]: accept function_call dict in bind_functions (#16483)
Confusing that you can't pass in a dict
This commit is contained in:
parent
db80832e4f
commit
31790d15ec
@ -18,6 +18,7 @@ from typing import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk(
|
|||||||
return default_class(content=content) # type: ignore
|
return default_class(content=content) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class _FunctionCall(TypedDict):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class ChatOpenAI(BaseChatModel):
|
class ChatOpenAI(BaseChatModel):
|
||||||
"""`OpenAI` Chat large language models API.
|
"""`OpenAI` Chat large language models API.
|
||||||
|
|
||||||
@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
def bind_functions(
|
def bind_functions(
|
||||||
self,
|
self,
|
||||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||||
function_call: Optional[str] = None,
|
function_call: Optional[
|
||||||
|
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||||
|
] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
"""Bind functions (and other objects) to this chat model.
|
"""Bind functions (and other objects) to this chat model.
|
||||||
@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
|
|
||||||
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||||
if function_call is not None:
|
if function_call is not None:
|
||||||
if len(formatted_functions) != 1:
|
function_call = (
|
||||||
|
{"name": function_call}
|
||||||
|
if isinstance(function_call, str)
|
||||||
|
and function_call not in ("auto", "none")
|
||||||
|
else function_call
|
||||||
|
)
|
||||||
|
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"When specifying `function_call`, you must provide exactly one "
|
"When specifying `function_call`, you must provide exactly one "
|
||||||
"function."
|
"function."
|
||||||
)
|
)
|
||||||
if formatted_functions[0]["name"] != function_call:
|
if (
|
||||||
|
isinstance(function_call, dict)
|
||||||
|
and formatted_functions[0]["name"] != function_call["name"]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Function call {function_call} was specified, but the only "
|
f"Function call {function_call} was specified, but the only "
|
||||||
f"provided function was {formatted_functions[0]['name']}."
|
f"provided function was {formatted_functions[0]['name']}."
|
||||||
)
|
)
|
||||||
function_call_ = {"name": function_call}
|
kwargs = {**kwargs, "function_call": function_call}
|
||||||
kwargs = {**kwargs, "function_call": function_call_}
|
|
||||||
return super().bind(
|
return super().bind(
|
||||||
functions=formatted_functions,
|
functions=formatted_functions,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
Loading…
Reference in New Issue
Block a user