openai[patch]: support built-in code interpreter and remote MCP tools (#31304)

This commit is contained in:
ccurme
2025-05-22 11:47:57 -04:00
committed by GitHub
parent 1b5ffe4107
commit 053a1246da
6 changed files with 389 additions and 14 deletions

View File

@@ -775,16 +775,22 @@ class BaseChatOpenAI(BaseChatModel):
with context_manager as response:
is_first_chunk = True
has_reasoning = False
for chunk in response:
metadata = headers if is_first_chunk else {}
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk, schema=original_schema_obj, metadata=metadata
chunk,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
):
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
async def _astream_responses(
@@ -811,16 +817,22 @@ class BaseChatOpenAI(BaseChatModel):
async with context_manager as response:
is_first_chunk = True
has_reasoning = False
async for chunk in response:
metadata = headers if is_first_chunk else {}
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk, schema=original_schema_obj, metadata=metadata
chunk,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
):
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
def _should_stream_usage(
@@ -1176,12 +1188,22 @@ class BaseChatOpenAI(BaseChatModel):
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
params = {
"model": self.model_name,
**super()._get_invocation_params(stop=stop),
**self._default_params,
**kwargs,
}
# Redact headers from built-in remote MCP tool invocations
if (tools := params.get("tools")) and isinstance(tools, list):
params["tools"] = [
({**tool, "headers": "**REDACTED**"} if "headers" in tool else tool)
if isinstance(tool, dict) and tool.get("type") == "mcp"
else tool
for tool in tools
]
return params
def _get_ls_params(
self, stop: Optional[list[str]] = None, **kwargs: Any
@@ -1456,6 +1478,8 @@ class BaseChatOpenAI(BaseChatModel):
"file_search",
"web_search_preview",
"computer_use_preview",
"code_interpreter",
"mcp",
):
tool_choice = {"type": tool_choice}
# 'any' is not natively supported by OpenAI API.
@@ -3150,12 +3174,22 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
):
function_call["id"] = _id
function_calls.append(function_call)
# Computer calls
# Built-in tool calls
computer_calls = []
code_interpreter_calls = []
mcp_calls = []
tool_outputs = lc_msg.additional_kwargs.get("tool_outputs", [])
for tool_output in tool_outputs:
if tool_output.get("type") == "computer_call":
computer_calls.append(tool_output)
elif tool_output.get("type") == "code_interpreter_call":
code_interpreter_calls.append(tool_output)
elif tool_output.get("type") == "mcp_call":
mcp_calls.append(tool_output)
else:
pass
input_.extend(code_interpreter_calls)
input_.extend(mcp_calls)
msg["content"] = msg.get("content") or []
if lc_msg.additional_kwargs.get("refusal"):
if isinstance(msg["content"], str):
@@ -3196,6 +3230,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
elif msg["role"] in ("user", "system", "developer"):
if isinstance(msg["content"], list):
new_blocks = []
non_message_item_types = ("mcp_approval_response",)
for block in msg["content"]:
# chat api: {"type": "text", "text": "..."}
# responses api: {"type": "input_text", "text": "..."}
@@ -3216,10 +3251,15 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
new_blocks.append(new_block)
elif block["type"] in ("input_text", "input_image", "input_file"):
new_blocks.append(block)
elif block["type"] in non_message_item_types:
input_.append(block)
else:
pass
msg["content"] = new_blocks
input_.append(msg)
if msg["content"]:
input_.append(msg)
else:
input_.append(msg)
else:
input_.append(msg)
@@ -3366,7 +3406,10 @@ def _construct_lc_result_from_responses_api(
def _convert_responses_chunk_to_generation_chunk(
chunk: Any, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None
chunk: Any,
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
) -> Optional[ChatGenerationChunk]:
content = []
tool_call_chunks: list = []
@@ -3429,6 +3472,10 @@ def _convert_responses_chunk_to_generation_chunk(
"web_search_call",
"file_search_call",
"computer_call",
"code_interpreter_call",
"mcp_call",
"mcp_list_tools",
"mcp_approval_request",
):
additional_kwargs["tool_outputs"] = [
chunk.item.model_dump(exclude_none=True, mode="json")
@@ -3444,9 +3491,11 @@ def _convert_responses_chunk_to_generation_chunk(
elif chunk.type == "response.refusal.done":
additional_kwargs["refusal"] = chunk.refusal
elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning":
additional_kwargs["reasoning"] = chunk.item.model_dump(
exclude_none=True, mode="json"
)
if not has_reasoning:
# Hack until breaking release: store first reasoning item ID.
additional_kwargs["reasoning"] = chunk.item.model_dump(
exclude_none=True, mode="json"
)
elif chunk.type == "response.reasoning_summary_part.added":
additional_kwargs["reasoning"] = {
# langchain-core uses the `index` key to aggregate text blocks.

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from pydantic import BaseModel
from typing_extensions import TypedDict
@@ -377,3 +378,73 @@ def test_stream_reasoning_summary() -> None:
message_2 = {"role": "user", "content": "Thank you."}
response_2 = llm.invoke([message_1, response_1, message_2])
assert isinstance(response_2, AIMessage)
# TODO: VCR some of these
def test_code_interpreter() -> None:
llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": {"type": "auto"}}]
)
response = llm_with_tools.invoke(
"Write and run code to answer the question: what is 3^3?"
)
_check_response(response)
tool_outputs = response.additional_kwargs["tool_outputs"]
assert tool_outputs
assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
# Test streaming
# Use same container
tool_outputs = response.additional_kwargs["tool_outputs"]
assert len(tool_outputs) == 1
container_id = tool_outputs[0]["container_id"]
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": container_id}]
)
full: Optional[BaseMessageChunk] = None
for chunk in llm_with_tools.stream(
"Write and run code to answer the question: what is 3^3?"
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
tool_outputs = full.additional_kwargs["tool_outputs"]
assert tool_outputs
assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
def test_mcp_builtin() -> None:
pytest.skip() # TODO: set up VCR
llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
llm_with_tools = llm.bind_tools(
[
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
}
]
)
response = llm_with_tools.invoke(
"What transport protocols does the 2025-03-26 version of the MCP spec "
"(modelcontextprotocol/modelcontextprotocol) support?"
)
approval_message = HumanMessage(
[
{
"type": "mcp_approval_response",
"approve": True,
"approval_request_id": output["id"],
}
for output in response.additional_kwargs["tool_outputs"]
if output["type"] == "mcp_approval_request"
]
)
_ = llm_with_tools.invoke(
[approval_message], previous_response_id=response.response_metadata["id"]
)

View File

@@ -21,6 +21,8 @@ from langchain_core.messages import (
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
from openai.types.responses import ResponseOutputMessage
from openai.types.responses.response import IncompleteDetails, Response, ResponseUsage
from openai.types.responses.response_error import ResponseError
@@ -1849,3 +1851,77 @@ def test_service_tier() -> None:
llm = ChatOpenAI(model="o4-mini", service_tier="flex")
payload = llm._get_request_payload([HumanMessage("Hello")])
assert payload["service_tier"] == "flex"
class FakeTracer(BaseTracer):
def __init__(self) -> None:
super().__init__()
self.chat_model_start_inputs: list = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
pass
def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
self.chat_model_start_inputs.append({"args": args, "kwargs": kwargs})
return super().on_chat_model_start(*args, **kwargs)
def test_mcp_tracing() -> None:
# Test we exclude sensitive information from traces
llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
tracer = FakeTracer()
mock_client = MagicMock()
def mock_create(*args: Any, **kwargs: Any) -> Response:
return Response(
id="resp_123",
created_at=1234567890,
model="o4-mini",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text", text="Test response", annotations=[]
)
],
role="assistant",
status="completed",
)
],
)
mock_client.responses.create = mock_create
input_message = HumanMessage("Test query")
tools = [
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": "always",
"headers": {"Authorization": "Bearer PLACEHOLDER"},
}
]
with patch.object(llm, "root_client", mock_client):
llm_with_tools = llm.bind_tools(tools)
_ = llm_with_tools.invoke([input_message], config={"callbacks": [tracer]})
# Test headers are not traced
assert len(tracer.chat_model_start_inputs) == 1
invocation_params = tracer.chat_model_start_inputs[0]["kwargs"]["invocation_params"]
for tool in invocation_params["tools"]:
if "headers" in tool:
assert tool["headers"] == "**REDACTED**"
for substring in ["Authorization", "Bearer", "PLACEHOLDER"]:
assert substring not in str(tracer.chat_model_start_inputs)
# Test headers are correctly propagated to request
payload = llm_with_tools._get_request_payload([input_message], tools=tools) # type: ignore[attr-defined]
assert payload["tools"][0]["headers"]["Authorization"] == "Bearer PLACEHOLDER"

View File

@@ -462,7 +462,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.59"
version = "0.3.60"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -827,7 +827,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.68.2"
version = "1.81.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -839,9 +839,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3f/6b/6b002d5d38794645437ae3ddb42083059d556558493408d39a0fcea608bc/openai-1.68.2.tar.gz", hash = "sha256:b720f0a95a1dbe1429c0d9bb62096a0d98057bcda82516f6e8af10284bdd5b19", size = 413429 }
sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/34/cebce15f64eb4a3d609a83ac3568d43005cc9a1cba9d7fde5590fd415423/openai-1.68.2-py3-none-any.whl", hash = "sha256:24484cb5c9a33b58576fdc5acf0e5f92603024a4e39d0b99793dfa1eb14c2b36", size = 606073 },
{ url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 },
]
[[package]]