groq[patch]: support tool_choice=any/required (#27000)

https://console.groq.com/docs/api-reference#chat-create
This commit is contained in:
Bagatur 2024-09-30 11:28:35 -07:00 committed by GitHub
parent db8845a62a
commit c7120d87dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -739,31 +739,11 @@ class ChatGroq(BaseChatModel):
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if tool_choice == "any":
if len(tools) > 1:
raise ValueError(
f"Groq does not currently support {tool_choice=}. Should "
f"be one of 'auto', 'none', or the name of the tool to call."
)
else:
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
tool_choice = "required"
if isinstance(tool_choice, str) and (
tool_choice not in ("auto", "any", "none")
tool_choice not in ("auto", "none", "required")
):
tool_choice = {"type": "function", "function": {"name": tool_choice}}
# TODO: Remove this update once 'any' is supported.
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, dict) and (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
if isinstance(tool_choice, bool):
if len(tools) > 1:
raise ValueError(