openai[patch]: support built-in code interpreter and remote MCP tools (#31304)

ccurme authored on 2025-05-22 11:47:57 -04:00 (committed by GitHub)
parent 1b5ffe4107 · commit 053a1246da
6 changed files with 389 additions and 14 deletions

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -775,16 +775,22 @@ class BaseChatOpenAI(BaseChatModel):
         with context_manager as response:
             is_first_chunk = True
+            has_reasoning = False
             for chunk in response:
                 metadata = headers if is_first_chunk else {}
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk, schema=original_schema_obj, metadata=metadata
+                    chunk,
+                    schema=original_schema_obj,
+                    metadata=metadata,
+                    has_reasoning=has_reasoning,
                 ):
                     if run_manager:
                         run_manager.on_llm_new_token(
                             generation_chunk.text, chunk=generation_chunk
                         )
                     is_first_chunk = False
+                    if "reasoning" in generation_chunk.message.additional_kwargs:
+                        has_reasoning = True
                     yield generation_chunk

     async def _astream_responses(
@@ -811,16 +817,22 @@ class BaseChatOpenAI(BaseChatModel):
         async with context_manager as response:
             is_first_chunk = True
+            has_reasoning = False
             async for chunk in response:
                 metadata = headers if is_first_chunk else {}
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk, schema=original_schema_obj, metadata=metadata
+                    chunk,
+                    schema=original_schema_obj,
+                    metadata=metadata,
+                    has_reasoning=has_reasoning,
                 ):
                     if run_manager:
                         await run_manager.on_llm_new_token(
                             generation_chunk.text, chunk=generation_chunk
                         )
                     is_first_chunk = False
+                    if "reasoning" in generation_chunk.message.additional_kwargs:
+                        has_reasoning = True
                     yield generation_chunk

     def _should_stream_usage(
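Both the sync and async streaming paths now thread a has_reasoning flag into chunk conversion so that only the first reasoning item is recorded when chunks are aggregated; without it, merging chunks would duplicate reasoning item IDs on replay. A minimal usage sketch, assuming a reasoning-capable model (the model name and prompt are illustrative, not from this commit):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o4-mini", use_responses_api=True)

full = None
for chunk in llm.stream("What is 3^3?"):
    full = chunk if full is None else full + chunk  # AIMessageChunk supports +

# At most one "reasoning" entry survives aggregation, so replaying `full`
# in a follow-up request sends a single reasoning item ID.
print(full.additional_kwargs.get("reasoning"))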
@@ -1176,12 +1188,22 @@ class BaseChatOpenAI(BaseChatModel):
         self, stop: Optional[list[str]] = None, **kwargs: Any
     ) -> dict[str, Any]:
         """Get the parameters used to invoke the model."""
-        return {
+        params = {
             "model": self.model_name,
             **super()._get_invocation_params(stop=stop),
             **self._default_params,
             **kwargs,
         }
+        # Redact headers from built-in remote MCP tool invocations
+        if (tools := params.get("tools")) and isinstance(tools, list):
+            params["tools"] = [
+                ({**tool, "headers": "**REDACTED**"} if "headers" in tool else tool)
+                if isinstance(tool, dict) and tool.get("type") == "mcp"
+                else tool
+                for tool in tools
+            ]
+        return params

     def _get_ls_params(
         self, stop: Optional[list[str]] = None, **kwargs: Any
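The redaction applies only to the invocation parameters reported to callbacks and tracers (_get_invocation_params); the real headers still go to the API. A standalone sketch of the comprehension above, using a hypothetical MCP tool spec:

tools = [
    {"type": "function", "function": {"name": "my_fn", "parameters": {}}},
    {
        "type": "mcp",
        "server_label": "deepwiki",
        "server_url": "https://mcp.deepwiki.com/mcp",
        "headers": {"Authorization": "Bearer secret-token"},
    },
]

# Same comprehension as the patch: redact headers on dict tools of type "mcp".
redacted = [
    ({**tool, "headers": "**REDACTED**"} if "headers" in tool else tool)
    if isinstance(tool, dict) and tool.get("type") == "mcp"
    else tool
    for tool in tools
]

assert redacted[1]["headers"] == "**REDACTED**"
assert tools[1]["headers"]["Authorization"] == "Bearer secret-token"  # input untouched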
@@ -1456,6 +1478,8 @@ class BaseChatOpenAI(BaseChatModel):
                 "file_search",
                 "web_search_preview",
                 "computer_use_preview",
+                "code_interpreter",
+                "mcp",
             ):
                 tool_choice = {"type": tool_choice}
             # 'any' is not natively supported by OpenAI API.
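With "code_interpreter" and "mcp" added to the allowed strings, forcing a built-in tool works the same way as for web_search_preview. A hedged sketch of binding the new tools (tool spec shapes follow the OpenAI Responses API; the model and server are illustrative):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o4-mini", use_responses_api=True)

# Built-in code interpreter, forced via the string tool_choice handled above.
llm_with_ci = llm.bind_tools(
    [{"type": "code_interpreter", "container": {"type": "auto"}}],
    tool_choice="code_interpreter",
)

# Remote MCP server; "always" makes calls require explicit approval.
llm_with_mcp = llm.bind_tools(
    [
        {
            "type": "mcp",
            "server_label": "deepwiki",
            "server_url": "https://mcp.deepwiki.com/mcp",
            "require_approval": "always",
        }
    ]
)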
@@ -3150,12 +3174,22 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                 ):
                     function_call["id"] = _id
                 function_calls.append(function_call)
-            # Computer calls
+            # Built-in tool calls
             computer_calls = []
+            code_interpreter_calls = []
+            mcp_calls = []
             tool_outputs = lc_msg.additional_kwargs.get("tool_outputs", [])
             for tool_output in tool_outputs:
                 if tool_output.get("type") == "computer_call":
                     computer_calls.append(tool_output)
+                elif tool_output.get("type") == "code_interpreter_call":
+                    code_interpreter_calls.append(tool_output)
+                elif tool_output.get("type") == "mcp_call":
+                    mcp_calls.append(tool_output)
                 else:
                     pass
+            input_.extend(code_interpreter_calls)
+            input_.extend(mcp_calls)
             msg["content"] = msg.get("content") or []
             if lc_msg.additional_kwargs.get("refusal"):
                 if isinstance(msg["content"], str):
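Replayed code_interpreter_call and mcp_call items are appended to the top-level input_ list, which is what lets a follow-up turn reuse the earlier sandbox or server results. A sketch of the round trip, continuing the llm_with_ci binding above:

from langchain_core.messages import HumanMessage

messages = [HumanMessage("Use your sandbox to compute 3^11.")]
ai_msg = llm_with_ci.invoke(messages)

# The executed call is surfaced on the AIMessage by the result converters,
# e.g. as a "code_interpreter_call" item:
calls = [
    item
    for item in ai_msg.additional_kwargs.get("tool_outputs", [])
    if item["type"] == "code_interpreter_call"
]

# Feeding ai_msg back through _construct_responses_api_input replays the item.
messages.extend([ai_msg, HumanMessage("Now 3^12, reusing the same container.")])
followup = llm_with_ci.invoke(messages)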
@@ -3196,6 +3230,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
         elif msg["role"] in ("user", "system", "developer"):
             if isinstance(msg["content"], list):
                 new_blocks = []
+                non_message_item_types = ("mcp_approval_response",)
                 for block in msg["content"]:
                     # chat api: {"type": "text", "text": "..."}
                     # responses api: {"type": "input_text", "text": "..."}
@@ -3216,10 +3251,15 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                         new_blocks.append(new_block)
                     elif block["type"] in ("input_text", "input_image", "input_file"):
                         new_blocks.append(block)
+                    elif block["type"] in non_message_item_types:
+                        input_.append(block)
                     else:
                         pass
                 msg["content"] = new_blocks
-                input_.append(msg)
+                if msg["content"]:
+                    input_.append(msg)
             else:
                 input_.append(msg)
         else:
             input_.append(msg)
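The non_message_item_types routing lets a human message carry an "mcp_approval_response" block: the block is lifted out of message content into the top-level input list, and a message left with no content is dropped rather than sent empty. A hedged sketch of approving an MCP call (the request ID is a placeholder copied from a prior "mcp_approval_request" tool output):

from langchain_core.messages import HumanMessage

approval = HumanMessage(
    [
        {
            "type": "mcp_approval_response",
            "approve": True,
            "approval_request_id": "mcpr_abc123",  # hypothetical ID
        }
    ]
)

history = [...]  # prior turn ending in an AIMessage with an mcp_approval_request
response = llm_with_mcp.invoke([*history, approval])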
@@ -3366,7 +3406,10 @@ def _construct_lc_result_from_responses_api(
 def _convert_responses_chunk_to_generation_chunk(
-    chunk: Any, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None
+    chunk: Any,
+    schema: Optional[type[_BM]] = None,
+    metadata: Optional[dict] = None,
+    has_reasoning: bool = False,
 ) -> Optional[ChatGenerationChunk]:
     content = []
     tool_call_chunks: list = []
@@ -3429,6 +3472,10 @@ def _convert_responses_chunk_to_generation_chunk(
         "web_search_call",
         "file_search_call",
         "computer_call",
+        "code_interpreter_call",
+        "mcp_call",
+        "mcp_list_tools",
+        "mcp_approval_request",
     ):
         additional_kwargs["tool_outputs"] = [
             chunk.item.model_dump(exclude_none=True, mode="json")
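The four new item types surface through additional_kwargs["tool_outputs"] exactly like computer_call. A sketch of inspecting them on an aggregated stream, reusing the llm_with_mcp binding from the earlier sketch (prompt is illustrative):

full = None
for chunk in llm_with_mcp.stream("What transport protocols does the MCP spec support?"):
    full = chunk if full is None else full + chunk

for item in full.additional_kwargs.get("tool_outputs", []):
    if item["type"] == "mcp_list_tools":
        pass  # tools advertised by the remote server
    elif item["type"] == "mcp_call":
        pass  # a completed remote tool invocation
    elif item["type"] == "mcp_approval_request":
        pass  # answer with an "mcp_approval_response" block (see sketch above)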
@@ -3444,9 +3491,11 @@ def _convert_responses_chunk_to_generation_chunk(
     elif chunk.type == "response.refusal.done":
         additional_kwargs["refusal"] = chunk.refusal
     elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning":
-        additional_kwargs["reasoning"] = chunk.item.model_dump(
-            exclude_none=True, mode="json"
-        )
+        if not has_reasoning:
+            # Hack until breaking release: store first reasoning item ID.
+            additional_kwargs["reasoning"] = chunk.item.model_dump(
+                exclude_none=True, mode="json"
+            )
     elif chunk.type == "response.reasoning_summary_part.added":
         additional_kwargs["reasoning"] = {
             # langchain-core uses the `index` key to aggregate text blocks.