Bagatur 2025-03-11 19:46:47 -07:00
parent 9808eb2149
commit d662b095ca
2 changed files with 191 additions and 95 deletions

View File

@@ -12,6 +12,7 @@ import sys
import warnings
from functools import partial
from io import BytesIO
from json import JSONDecodeError
from math import ceil
from operator import itemgetter
from typing import (
@@ -114,6 +115,8 @@ logger = logging.getLogger(__name__)
# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
global_ssl_context = ssl.create_default_context(cafile=certifi.where())
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
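This new key is where the Responses API's function-call item ids get stashed on an AIMessage, keyed by each tool call's `call_id`, so both ids can be echoed back when history is replayed through `_construct_response_api_input` later in this diff. A minimal sketch of the stored shape; both ids below are made up:

# Hypothetical additional_kwargs on an AIMessage that issued one tool call:
additional_kwargs = {
    _FUNCTION_CALL_IDS_MAP_KEY: {
        "call_abc123": "fc_0123456789",  # tool call's call_id -> Responses API output item id (illustrative)
    }
}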
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
@@ -452,18 +455,6 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
raise
def _is_builtin_tool(tool: dict) -> bool:
return "type" in tool and tool["type"] != "function"
def _transform_payload_for_responses(payload: dict) -> dict:
updated_payload = payload.copy()
if messages := updated_payload.pop("messages"):
updated_payload["input"] = messages
return updated_payload
class _FunctionCall(TypedDict):
name: str
@@ -793,8 +784,7 @@ class BaseChatOpenAI(BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
responses_payload = _transform_payload_for_responses(payload)
context_manager = self.root_client.responses.create(**responses_payload)
context_manager = self.root_client.responses.create(**payload)
with context_manager as response:
for chunk in response:
@@ -816,10 +806,7 @@ class BaseChatOpenAI(BaseChatModel):
) -> AsyncIterator[ChatGenerationChunk]:
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
responses_payload = _transform_payload_for_responses(payload)
context_manager = await self.root_async_client.responses.create(
**responses_payload
)
context_manager = await self.root_async_client.responses.create(**payload)
async with context_manager as response:
async for chunk in response:
@@ -926,13 +913,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
else:
if "tools" in payload and any(
_is_builtin_tool(tool) for tool in payload["tools"]
):
responses_payload = _transform_payload_for_responses(payload)
response = self.root_client.responses.create(**responses_payload)
return self._create_chat_result_responses(response, generation_info)
elif _use_response_api(payload):
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_response_api(response)
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info)
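In practice this means a single `invoke` switches transports based on the tools bound to the call. A rough usage sketch, assuming `OPENAI_API_KEY` is set in the environment; the model name and prompt are arbitrary and mirror the integration test below:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")

# A built-in tool (any "type" other than "function") makes _use_response_api return True,
# so this call is served by client.responses.create rather than chat completions.
response = llm.invoke(
    "What was a positive news story from today?",
    tools=[{"type": "web_search_preview"}],
)
# With only function tools (or no tools at all), the same call still uses chat completions.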
@@ -948,11 +931,12 @@ class BaseChatOpenAI(BaseChatModel):
if stop is not None:
kwargs["stop"] = stop
return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
**kwargs,
}
payload = {**self._default_params, **kwargs}
if _use_response_api(payload):
payload["input"] = _construct_response_api_input(messages)
else:
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
return payload
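The payload now carries exactly one of `input` (Responses API) or `messages` (Chat Completions), never both. A small sketch of the two shapes; keys are abridged and values are placeholders:

# Illustrative payloads only; real ones also carry model params and stream flags.
responses_payload = {
    "model": "gpt-4o-mini",
    "tools": [{"type": "web_search_preview"}],
    "input": [{"role": "user", "content": "hi"}],  # Responses API shape
}
chat_payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],  # Chat Completions shape
}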
def _create_chat_result(
self,
@@ -1003,50 +987,6 @@ class BaseChatOpenAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
def _create_chat_result_responses(
self, response: Response, generation_info: Optional[Dict] = None
) -> ChatResult:
generations = []
if error := response.error:
raise ValueError(error)
token_usage = response.usage.model_dump() if response.usage else {}
generation_info = {}
content_blocks = []
for output in response.output:
if output.type == "message":
for content in output.content:
if content.type == "output_text":
block = {
"type": "text",
"text": content.text,
"annotations": [
annotation.model_dump()
for annotation in content.annotations
],
}
content_blocks.append(block)
usage_metadata = _create_usage_metadata_responses(token_usage)
message = AIMessage(
content=content_blocks, # type: ignore[arg-type]
id=output.id,
usage_metadata=usage_metadata,
)
if output.status:
generation_info["status"] = output.status
gen = ChatGeneration(message=message, generation_info=generation_info)
generations.append(gen)
else:
tool_output = output.model_dump()
if "tool_outputs" in generation_info:
generation_info["tool_outputs"].append(tool_output)
else:
generation_info["tool_outputs"] = [tool_output]
llm_output = {"model_name": response.model}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
@@ -1147,15 +1087,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
else:
if "tools" in payload and any(
_is_builtin_tool(tool) for tool in payload["tools"]
):
responses_payload = _transform_payload_for_responses(payload)
response = await self.root_async_client.responses.create(
**responses_payload
)
return self._create_chat_result_responses(response, generation_info)
elif _use_response_api(payload):
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_response_api(response)
else:
response = await self.async_client.create(**payload)
return await run_in_executor(
@@ -2249,9 +2183,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
if "tools" in kwargs and any(
_is_builtin_tool(tool) for tool in kwargs["tools"]
):
if _use_response_api(kwargs):
return super()._stream_responses(*args, **kwargs)
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2269,9 +2201,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
if "tools" in kwargs and any(
_is_builtin_tool(tool) for tool in kwargs["tools"]
):
if _use_response_api(kwargs):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:
@@ -2818,3 +2748,162 @@ def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
**{k: v for k, v in output_token_details.items() if v is not None}
),
)
def _is_builtin_tool(tool: dict) -> bool:
return "type" in tool and tool["type"] != "function"
def _use_response_api(payload: dict) -> bool:
return "tools" in payload and any(
_is_builtin_tool(tool) for tool in payload["tools"]
)
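For example (tool dicts abbreviated), the helper only opts into the Responses API when at least one bound tool is a built-in one:

_use_response_api({"tools": [{"type": "web_search_preview"}]})        # True: built-in tool
_use_response_api({"tools": [{"type": "function", "function": {}}]})  # False: plain function tool
_use_response_api({"model": "gpt-4o-mini"})                           # False: no tools at all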
def _construct_response_api_input(messages: list[BaseMessage]) -> list:
input_ = []
for lc_msg in messages:
msg = _convert_message_to_dict(lc_msg)
if msg["role"] == "tool":
function_call_output = {
"type": "function_call_output",
"output": msg["content"],
"call_id": msg["tool_call_id"],
}
input_.append(function_call_output)
elif msg["role"] == "assistant":
if tool_calls := msg.pop("tool_calls", None):
if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
raise ValueError(
    "Cannot reconstruct Responses API function_call items: "
    f"'{_FUNCTION_CALL_IDS_MAP_KEY}' is missing from the AIMessage's additional_kwargs."
)
function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
for tool_call in tool_calls:
function_call = {
"type": "function_call",
"name": tool_call["name"],
"arguments": tool_call["arguments"],
"call_id": tool_call["id"],
"id": function_call_ids[tool_call["id"]],
}
input_.append(function_call)
if msg.get("content"):
input_.append(msg)
elif msg["role"] == "user":
if isinstance(msg["content"], list):
for block in msg["content"]:
# chat api: {"type": "text", "text": "..."}
# response api: {"type": "input_text", "text": "..."}
if block["type"] == "text":
block["type"] = "input_text"
# chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}} # noqa: E501
# response api: {"type": "input_image", "image_url": "...", "detail": "...", "file_id": "..."} # noqa: E501
elif block["type"] == "image_url":
block["type"] = "input_image"
if isinstance(block.get("image_url"), dict):
image_url = block.pop("image_url")
block["image_url"] = image_url["url"]
if image_url.get("detail"):
block["detail"] = image_url["detail"]
else:
pass
input_.append(msg)
else:
input_.append(msg)
return input_
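As an illustration of the conversion (the ids and content below are made up), a multimodal user turn and a tool result come out roughly as follows:

from langchain_core.messages import HumanMessage, ToolMessage

messages = [
    HumanMessage(
        content=[
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "low"}},
        ]
    ),
    ToolMessage(content="72F and sunny", tool_call_id="call_abc123"),
]
# _construct_response_api_input(messages) yields roughly:
# [
#     {"role": "user", "content": [
#         {"type": "input_text", "text": "Describe this image."},
#         {"type": "input_image", "image_url": "https://example.com/cat.png", "detail": "low"},
#     ]},
#     {"type": "function_call_output", "output": "72F and sunny", "call_id": "call_abc123"},
# ]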
def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
raise ValueError(response.error)
response_metadata = {
k: v
for k, v in response.model_dump(exclude_none=True, mode="json").items()
if k
in (
"created_at",
"id",
"incomplete_details",
"metadata",
"object",
"status",
"user",
"model",
)
}
# for compatibility with chat completion calls.
response_metadata["model_name"] = response.get("model")
if response.usage:
usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
else:
usage_metadata = None
content_blocks = []
tool_calls = []
invalid_tool_calls = []
additional_kwargs: dict = {}
msg_id = None
for output in response.output:
if output.type == "message":
for content in output.content:
if content.type == "output_text":
block = {
"type": "text",
"text": content.text,
"annotations": [
annotation.model_dump()
for annotation in content.annotations
],
}
content_blocks.append(block)
if content.type == "refusal":
additional_kwargs["refusal"] = content.refusal
msg_id = output.id
elif output.type == "function_call":
try:
args = json.loads(output.arguments, strict=False)
error = None
except JSONDecodeError as e:
args = output.arguments
error = str(e)
if error is None:
tool_call = {
"type": "tool_call",
"name": output.name,
"args": args,
"id": output.call_id,
}
tool_calls.append(tool_call)
else:
tool_call = {
"type": "invalid_tool_call",
"name": output.name,
"args": args,
"id": output.call_id,
"error": error,
}
invalid_tool_calls.append(tool_call)
if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
elif output.type == "reasoning":
additional_kwargs["reasoning"] = output.model_dump(
exclude_none=True, mode="json"
)
else:
tool_output = output.model_dump(exclude_none=True, mode="json")
if "tool_outputs" in additional_kwargs:
additional_kwargs["tool_outputs"].append(tool_output)
else:
additional_kwargs["tool_outputs"] = [tool_output]
message = AIMessage(
content=content_blocks or None, # type: ignore[arg-type]
id=msg_id,
usage_metadata=usage_metadata,
response_metadata=response_metadata,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
return ChatResult(generations=[ChatGeneration(message=message)])
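Compared with the Chat Completions path, non-message output items (web search calls, file search calls, and similar) now land in `additional_kwargs["tool_outputs"]` rather than in `response_metadata`, which is what the updated integration test below asserts. A sketch of the resulting message for a web search request; every field value here is invented:

from langchain_core.messages import AIMessage

example = AIMessage(
    content=[{"type": "text", "text": "One positive story today is ...", "annotations": []}],
    id="msg_123",
    additional_kwargs={
        "tool_outputs": [
            {"id": "ws_123", "status": "completed", "type": "web_search_call"},
        ]
    },
    response_metadata={"model_name": "gpt-4o-mini", "status": "completed"},
)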

View File

@@ -1258,14 +1258,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
assert response.usage_metadata["output_tokens"] > 0
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
for tool_output in response.response_metadata["tool_outputs"]:
for tool_output in response.additional_kwargs["tool_outputs"]:
assert tool_output["id"]
assert tool_output["status"]
assert tool_output["type"]
def test_web_search() -> None:
llm = ChatOpenAI(model="gpt-4o")
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@@ -1283,9 +1283,16 @@ def test_web_search() -> None:
full = chunk if full is None else full + chunk
_check_response(full)
llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
response_id=response.response_metadata["id"]
)
_check_response(response)
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o")
llm = ChatOpenAI(model="gpt-4o-mini")
response = await llm.ainvoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@@ -1307,7 +1314,7 @@ async def test_web_search_async() -> None:
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o")
llm = ChatOpenAI(model="gpt-4o-mini")
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],