This commit is contained in:
Bagatur 2025-03-12 02:41:37 -07:00
parent 88d0101394
commit c2dedd45e7
2 changed files with 75 additions and 22 deletions

View File

@ -531,9 +531,19 @@ def convert_to_openai_tool(
'description' and 'parameters' keys are now optional. Only 'name' is 'description' and 'parameters' keys are now optional. Only 'name' is
required and guaranteed to be part of the output. required and guaranteed to be part of the output.
.. versionchanged:: 0.3.44
Return OpenAI Responses API-style tools unchanged. This includes
any dict with "type" in "file_search", "function", "computer_use_preview",
"web_search_preview".
""" """
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool: if isinstance(tool, dict) :
return tool if tool.get("type") in ("function", "file_search", "computer_use_preview"):
return tool
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
if (tool.get("type") or "").startswith("web_search_preview"):
return tool
oai_function = convert_to_openai_function(tool, strict=strict) oai_function = convert_to_openai_function(tool, strict=strict)
return {"type": "function", "function": oai_function} return {"type": "function", "function": oai_function}

View File

@ -934,8 +934,7 @@ class BaseChatOpenAI(BaseChatModel):
payload = {**self._default_params, **kwargs} payload = {**self._default_params, **kwargs}
if _use_response_api(payload): if _use_response_api(payload):
payload["input"] = _construct_response_api_input(messages) payload = _construct_response_api_payload(messages, payload)
else: else:
payload["messages"] = [_convert_message_to_dict(m) for m in messages] payload["messages"] = [_convert_message_to_dict(m) for m in messages]
return payload return payload
@ -1361,33 +1360,38 @@ class BaseChatOpenAI(BaseChatModel):
formatted_tools = [ formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools convert_to_openai_tool(tool, strict=strict) for tool in tools
] ]
tool_names = []
for tool in formatted_tools:
if "function" in tool:
tool_names.append(tool["function"]["name"])
elif "name" in tool:
tool_names.append(tool["name"])
else:
pass
if tool_choice: if tool_choice:
if isinstance(tool_choice, str): if isinstance(tool_choice, str):
# tool_choice is a tool/function name # tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "any", "required"): if tool_choice in tool_names:
tool_choice = { tool_choice = {
"type": "function", "type": "function",
"function": {"name": tool_choice}, "function": {"name": tool_choice},
} }
elif tool_choice in (
"file_search",
"web_search_preview",
"computer_use_preview",
):
tool_choice = {"type": tool_choice}
# 'any' is not natively supported by OpenAI API. # 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'. # We support 'any' since other models use this instead of 'required'.
if tool_choice == "any": elif tool_choice == "any":
tool_choice = "required" tool_choice = "required"
else:
pass
elif isinstance(tool_choice, bool): elif isinstance(tool_choice, bool):
tool_choice = "required" tool_choice = "required"
elif isinstance(tool_choice, dict): elif isinstance(tool_choice, dict):
tool_names = [ pass
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else: else:
raise ValueError( raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. " f"Unrecognized tool_choice type. Expected str, bool or dict. "
@ -2762,6 +2766,43 @@ def _use_response_api(payload: dict) -> bool:
) )
def _construct_response_api_payload(
messages: Sequence[BaseMessage], payload: dict
) -> dict:
payload["input"] = _construct_response_api_input(messages)
if tools := payload.pop("tools", None):
new_tools: list = []
for tool in tools:
# chat api: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}, "strict": ...}} # noqa: E501
# responses api: {"type": "function", "name": "...", "description": "...", "parameters": {...}, "strict": ...} # noqa: E501
if tool["type"] == "function" and "function" in tool:
new_tools.append({"type": "function", **tool["function"]})
else:
new_tools.append(tool)
payload["tools"] = new_tools
if tool_choice := payload.pop("tool_choice", None):
# chat api: {"type": "function", "function": {"name": "..."}}
# responses api: {"type": "function", "name": "..."}
if tool_choice["type"] == "function" and "function" in tool_choice:
payload["tool_choice"] = {"type": "function", **tool_choice["function"]}
else:
payload["tool_choice"] = tool_choice
if response_format := payload.pop("response_format", None):
if payload.get("text"):
text = payload["text"]
raise ValueError(
"Can specify at most one of 'response_format' or 'text', received both:"
f"\n{response_format=}\n{text=}"
)
# chat api: {"type": "json_schema, "json_schema": {"schema": {...}, "name": "...", "description": "...", "strict": ...}} # noqa: E501
# responses api: {"type": "json_schema, "schema": {...}, "name": "...", "description": "...", "strict": ...} # noqa: E501
if response_format["type"] == "json_schema":
payload["text"] = {"type": "json_schema", **response_format["json_schema"]}
else:
payload["text"] = response_format
return payload
def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
input_ = [] input_ = []
for lc_msg in messages: for lc_msg in messages:
@ -2811,7 +2852,7 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
new_blocks = [] new_blocks = []
for block in msg["content"]: for block in msg["content"]:
# chat api: {"type": "text", "text": "..."} # chat api: {"type": "text", "text": "..."}
# response api: {"type": "output_text", "text": "...", "annotations": [...]} # noqa: E501 # responses api: {"type": "output_text", "text": "...", "annotations": [...]} # noqa: E501
if block["type"] == "text": if block["type"] == "text":
new_blocks.append( new_blocks.append(
{ {
@ -2820,8 +2861,10 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
"annotations": block.get("annotations") or [], "annotations": block.get("annotations") or [],
} }
) )
else: elif block["type"] in ("output_text", "refusal"):
new_blocks.append(block) new_blocks.append(block)
else:
pass
msg["content"] = new_blocks msg["content"] = new_blocks
if msg["content"]: if msg["content"]:
input_.append(msg) input_.append(msg)
@ -2831,11 +2874,11 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
new_blocks = [] new_blocks = []
for block in msg["content"]: for block in msg["content"]:
# chat api: {"type": "text", "text": "..."} # chat api: {"type": "text", "text": "..."}
# response api: {"type": "input_text", "text": "..."} # responses api: {"type": "input_text", "text": "..."}
if block["type"] == "text": if block["type"] == "text":
new_blocks.append({"type": "input_text", "text": block["text"]}) new_blocks.append({"type": "input_text", "text": block["text"]})
# chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}} # noqa: E501 # chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}} # noqa: E501
# response api: {"type": "image_url", "image_url": "...", "detail": "...", "file_id": "..."} # noqa: E501 # responses api: {"type": "image_url", "image_url": "...", "detail": "...", "file_id": "..."} # noqa: E501
elif block["type"] == "image_url": elif block["type"] == "image_url":
new_block = { new_block = {
"type": "input_image", "type": "input_image",