openai, anthropic, ...: with_structured_output to pass in explicit tool choice (#23645)

...community, mistralai, groq, fireworks

part of #23644
This commit is contained in:
Bagatur 2024-06-28 16:39:53 -07:00 committed by GitHub
parent c5f35a72da
commit fc8fd49328
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 22 additions and 16 deletions

View File

@ -531,15 +531,15 @@ class ChatLlamaCpp(BaseChatModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast(Type, schema)], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
if include_raw:

View File

@ -986,7 +986,9 @@ class ChatAnthropic(BaseChatModel):
# }
""" # noqa: E501
llm = self.bind_tools([schema], tool_choice="any")
tool_name = convert_to_anthropic_tool(schema)["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser = ToolsOutputParser(
first_tool_only=True, pydantic_schemas=[schema]

View File

@ -884,15 +884,15 @@ class ChatFireworks(BaseChatModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})

View File

@ -1014,15 +1014,15 @@ class ChatGroq(BaseChatModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})

View File

@ -794,6 +794,8 @@ class ChatMistralAI(BaseChatModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports
# specifying a tool.
llm = self.bind_tools([schema], tool_choice="any")
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(

View File

@ -857,15 +857,15 @@ class AzureChatOpenAI(BaseChatOpenAI):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})

View File

@ -1140,15 +1140,17 @@ class BaseChatOpenAI(BaseChatModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True, parallel_tool_calls=False)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools(
[schema], tool_choice=tool_name, parallel_tool_calls=False
)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})