Bagatur 2025-03-11 19:46:47 -07:00
parent 9808eb2149
commit d662b095ca
2 changed files with 191 additions and 95 deletions

Changed file 1 of 2: chat model implementation

@@ -12,6 +12,7 @@ import sys
 import warnings
 from functools import partial
 from io import BytesIO
+from json import JSONDecodeError
 from math import ceil
 from operator import itemgetter
 from typing import (
@@ -114,6 +115,8 @@ logger = logging.getLogger(__name__)
 # https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
 global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 
+_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
+
 
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
@@ -452,18 +455,6 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
     raise
 
-
-def _is_builtin_tool(tool: dict) -> bool:
-    return "type" in tool and tool["type"] != "function"
-
-
-def _transform_payload_for_responses(payload: dict) -> dict:
-    updated_payload = payload.copy()
-    if messages := updated_payload.pop("messages"):
-        updated_payload["input"] = messages
-    return updated_payload
-
-
 class _FunctionCall(TypedDict):
     name: str
@@ -793,8 +784,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = self.root_client.responses.create(**responses_payload)
+        context_manager = self.root_client.responses.create(**payload)
         with context_manager as response:
             for chunk in response:
@@ -816,10 +806,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = await self.root_async_client.responses.create(
-            **responses_payload
-        )
+        context_manager = await self.root_async_client.responses.create(**payload)
         async with context_manager as response:
             async for chunk in response:
@@ -926,15 +913,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = self.root_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = self.root_client.responses.create(**responses_payload)
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = self.client.create(**payload)
+            response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
@@ -948,11 +931,12 @@ class BaseChatOpenAI(BaseChatModel):
         if stop is not None:
             kwargs["stop"] = stop
 
-        return {
-            "messages": [_convert_message_to_dict(m) for m in messages],
-            **self._default_params,
-            **kwargs,
-        }
+        payload = {**self._default_params, **kwargs}
+        if _use_response_api(payload):
+            payload["input"] = _construct_response_api_input(messages)
+        else:
+            payload["messages"] = [_convert_message_to_dict(m) for m in messages]
+        return payload
 
     def _create_chat_result(
         self,
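To make the new branch concrete, here is roughly what the two request payloads look like (a minimal sketch with hypothetical field values; only the `messages` vs. `input` split is taken from the hunk above):

```python
# Default path: Chat Completions payload, keyed on "messages".
chat_payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],
}

# Built-in tool present: Responses API payload, keyed on "input" instead,
# built by _construct_response_api_input (added later in this diff).
responses_payload = {
    "model": "gpt-4o-mini",
    "tools": [{"type": "web_search_preview"}],
    "input": [{"role": "user", "content": "hi"}],
}
```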
@@ -1003,50 +987,6 @@ class BaseChatOpenAI(BaseChatModel):
         return ChatResult(generations=generations, llm_output=llm_output)
 
-    def _create_chat_result_responses(
-        self, response: Response, generation_info: Optional[Dict] = None
-    ) -> ChatResult:
-        generations = []
-        if error := response.error:
-            raise ValueError(error)
-        token_usage = response.usage.model_dump() if response.usage else {}
-        generation_info = {}
-        content_blocks = []
-        for output in response.output:
-            if output.type == "message":
-                for content in output.content:
-                    if content.type == "output_text":
-                        block = {
-                            "type": "text",
-                            "text": content.text,
-                            "annotations": [
-                                annotation.model_dump()
-                                for annotation in content.annotations
-                            ],
-                        }
-                        content_blocks.append(block)
-                usage_metadata = _create_usage_metadata_responses(token_usage)
-                message = AIMessage(
-                    content=content_blocks,  # type: ignore[arg-type]
-                    id=output.id,
-                    usage_metadata=usage_metadata,
-                )
-                if output.status:
-                    generation_info["status"] = output.status
-                gen = ChatGeneration(message=message, generation_info=generation_info)
-                generations.append(gen)
-            else:
-                tool_output = output.model_dump()
-                if "tool_outputs" in generation_info:
-                    generation_info["tool_outputs"].append(tool_output)
-                else:
-                    generation_info["tool_outputs"] = [tool_output]
-        llm_output = {"model_name": response.model}
-        return ChatResult(generations=generations, llm_output=llm_output)
-
     async def _astream(
         self,
         messages: List[BaseMessage],
@@ -1147,17 +1087,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = await self.root_async_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = await self.root_async_client.responses.create(
-                    **responses_payload
-                )
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = await self.async_client.create(**payload)
+            response = await self.async_client.create(**payload)
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
@@ -2249,9 +2183,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             return super()._stream_responses(*args, **kwargs)
         else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2269,9 +2201,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
@@ -2818,3 +2748,162 @@ def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
             **{k: v for k, v in output_token_details.items() if v is not None}
         ),
     )
+
+
+def _is_builtin_tool(tool: dict) -> bool:
+    return "type" in tool and tool["type"] != "function"
+
+
+def _use_response_api(payload: dict) -> bool:
+    return "tools" in payload and any(
+        _is_builtin_tool(tool) for tool in payload["tools"]
+    )
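A quick check of the routing predicate's behavior (a minimal sketch; the two helpers are copied verbatim from the hunk above, the sample payloads are made up):

```python
def _is_builtin_tool(tool: dict) -> bool:
    return "type" in tool and tool["type"] != "function"


def _use_response_api(payload: dict) -> bool:
    return "tools" in payload and any(
        _is_builtin_tool(tool) for tool in payload["tools"]
    )


# Any non-"function" tool type counts as built-in and flips the request
# onto the Responses API:
assert _use_response_api({"tools": [{"type": "web_search_preview"}]})
# Plain function tools, or no tools at all, stay on Chat Completions:
assert not _use_response_api(
    {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}
)
assert not _use_response_api({"model": "gpt-4o-mini"})
```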
+
+
+def _construct_response_api_input(messages: list[BaseMessage]) -> list:
+    input_ = []
+    for lc_msg in messages:
+        msg = _convert_message_to_dict(lc_msg)
+        if msg["role"] == "tool":
+            function_call_output = {
+                "type": "function_call_output",
+                "output": msg["content"],
+                "call_id": msg["tool_call_id"],
+            }
+            input_.append(function_call_output)
+        elif msg["role"] == "assistant":
+            if tool_calls := msg.pop("tool_calls", None):
+                if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
+                    raise ValueError(
+                        "Cannot construct Responses API input: AIMessage has "
+                        f"tool_calls but no {_FUNCTION_CALL_IDS_MAP_KEY} in "
+                        "additional_kwargs."
+                    )
+                function_call_ids = lc_msg.additional_kwargs[
+                    _FUNCTION_CALL_IDS_MAP_KEY
+                ]
+                for tool_call in tool_calls:
+                    # _convert_message_to_dict emits Chat Completions-style
+                    # tool calls, so name/arguments live under "function".
+                    function_call = {
+                        "type": "function_call",
+                        "name": tool_call["function"]["name"],
+                        "arguments": tool_call["function"]["arguments"],
+                        "call_id": tool_call["id"],
+                        "id": function_call_ids[tool_call["id"]],
+                    }
+                    input_.append(function_call)
+            if msg.get("content"):
+                input_.append(msg)
+        elif msg["role"] == "user":
+            if isinstance(msg["content"], list):
+                for block in msg["content"]:
+                    # chat api: {"type": "text", "text": "..."}
+                    # response api: {"type": "input_text", "text": "..."}
+                    if block["type"] == "text":
+                        block["type"] = "input_text"
+                    # chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}}  # noqa: E501
+                    # response api: {"type": "input_image", "image_url": "...", "detail": "...", "file_id": "..."}  # noqa: E501
+                    elif block["type"] == "image_url":
+                        block["type"] = "input_image"
+                        if isinstance(block.get("image_url"), dict):
+                            image_url = block.pop("image_url")
+                            block["image_url"] = image_url["url"]
+                            if image_url.get("detail"):
+                                block["detail"] = image_url["detail"]
+                    # other block types are passed through unchanged
+            input_.append(msg)
+        else:
+            input_.append(msg)
+    return input_
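For illustration, a tool-calling turn converts as follows (a hand-written sketch; the ids are made up, and the expected output assumes `_convert_message_to_dict` emits Chat Completions-style dicts as noted in the comment above):

```python
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage("What's the weather in SF?"),
    AIMessage(
        "",
        tool_calls=[
            {"name": "get_weather", "args": {"city": "SF"}, "id": "call_123"}
        ],
        # call_id -> Responses API item id, as recorded by
        # _construct_lc_result_from_response_api on the previous turn.
        additional_kwargs={"__openai_function_call_ids__": {"call_123": "fc_abc"}},
    ),
    ToolMessage("72F and sunny", tool_call_id="call_123"),
]

# _construct_response_api_input(messages) should then yield roughly:
# [
#     {"role": "user", "content": "What's the weather in SF?"},
#     {"type": "function_call", "name": "get_weather",
#      "arguments": '{"city": "SF"}', "call_id": "call_123", "id": "fc_abc"},
#     {"type": "function_call_output", "output": "72F and sunny",
#      "call_id": "call_123"},
# ]
```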
+
+
+def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
+    """Construct a ChatResult from an OpenAI Responses API response."""
+    if response.error:
+        raise ValueError(response.error)
+
+    response_metadata = {
+        k: v
+        for k, v in response.model_dump(exclude_none=True, mode="json").items()
+        if k
+        in (
+            "created_at",
+            "id",
+            "incomplete_details",
+            "metadata",
+            "object",
+            "status",
+            "user",
+            "model",
+        )
+    }
+    # for compatibility with chat completion calls.
+    response_metadata["model_name"] = response_metadata.get("model")
+    if response.usage:
+        usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
+    else:
+        usage_metadata = None
+
+    content_blocks = []
+    tool_calls = []
+    invalid_tool_calls = []
+    additional_kwargs: dict = {}
+    msg_id = None
+    for output in response.output:
+        if output.type == "message":
+            for content in output.content:
+                if content.type == "output_text":
+                    block = {
+                        "type": "text",
+                        "text": content.text,
+                        "annotations": [
+                            annotation.model_dump()
+                            for annotation in content.annotations
+                        ],
+                    }
+                    content_blocks.append(block)
+                elif content.type == "refusal":
+                    additional_kwargs["refusal"] = content.refusal
+            msg_id = output.id
+        elif output.type == "function_call":
+            try:
+                args = json.loads(output.arguments, strict=False)
+                error = None
+            except JSONDecodeError as e:
+                args = output.arguments
+                error = str(e)
+            if error is None:
+                tool_call = {
+                    "type": "tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                }
+                tool_calls.append(tool_call)
+            else:
+                tool_call = {
+                    "type": "invalid_tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                    "error": error,
+                }
+                invalid_tool_calls.append(tool_call)
+            if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
+                additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
+            additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
+        elif output.type == "reasoning":
+            additional_kwargs["reasoning"] = output.model_dump(
+                exclude_none=True, mode="json"
+            )
+        else:
+            tool_output = output.model_dump(exclude_none=True, mode="json")
+            if "tool_outputs" in additional_kwargs:
+                additional_kwargs["tool_outputs"].append(tool_output)
+            else:
+                additional_kwargs["tool_outputs"] = [tool_output]
+    message = AIMessage(
+        content=content_blocks or None,  # type: ignore[arg-type]
+        id=msg_id,
+        usage_metadata=usage_metadata,
+        response_metadata=response_metadata,
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
+    )
+    return ChatResult(generations=[ChatGeneration(message=message)])
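The assembled `AIMessage` thus carries everything the Responses API returned in one place. Roughly, for a web-search call (hypothetical values, sketching the shapes the function above produces):

```python
# msg = _construct_lc_result_from_response_api(response).generations[0].message

# Text arrives as content blocks, each keeping its citation annotations:
expected_content = [
    {
        "type": "text",
        "text": "One positive story today ...",
        "annotations": [{"type": "url_citation", "url": "https://example.com"}],
    }
]

# Built-in tool activity is preserved verbatim under "tool_outputs", and
# function-call item ids are stashed for reuse on the next request:
expected_additional_kwargs = {
    "tool_outputs": [
        {"type": "web_search_call", "id": "ws_123", "status": "completed"}
    ],
    "__openai_function_call_ids__": {"call_123": "fc_abc"},
}
```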

Changed file 2 of 2: integration tests

@@ -1258,14 +1258,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
-    for tool_output in response.response_metadata["tool_outputs"]:
+    for tool_output in response.additional_kwargs["tool_outputs"]:
         assert tool_output["id"]
         assert tool_output["status"]
         assert tool_output["type"]
 
 
 def test_web_search() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -1283,9 +1283,16 @@ def test_web_search() -> None:
         full = chunk if full is None else full + chunk
     _check_response(full)
 
+    response = llm.invoke(
+        "what about a negative one",
+        tools=[{"type": "web_search_preview"}],
+        response_id=response.response_metadata["id"],
+    )
+    _check_response(response)
+
 
 async def test_web_search_async() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = await llm.ainvoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -1307,7 +1314,7 @@ async def test_web_search_async() -> None:
 
 def test_file_search() -> None:
     pytest.skip()  # TODO: set up infra
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     tool = {
         "type": "file_search",
         "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],