Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-21 10:26:57 +00:00)

commit d662b095ca (parent 9808eb2149)

    cr
@@ -12,6 +12,7 @@ import sys
import warnings
from functools import partial
from io import BytesIO
from json import JSONDecodeError
from math import ceil
from operator import itemgetter
from typing import (
@@ -114,6 +115,8 @@ logger = logging.getLogger(__name__)
# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
global_ssl_context = ssl.create_default_context(cafile=certifi.where())

_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.
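Reviewer note: based on `_construct_lc_result_from_response_api` and `_construct_response_api_input` later in this diff, this key appears to hold a mapping from Responses API `call_id` values to the ids of the corresponding `function_call` output items, stored on `AIMessage.additional_kwargs` so tool calls can be round-tripped. A minimal sketch of the assumed shape (ids invented):

```python
# Hypothetical illustration; the ids are invented.
additional_kwargs = {
    "__openai_function_call_ids__": {
        "call_abc123": "fc_def456",  # Responses API call_id -> output item id
    }
}
```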
@@ -452,18 +455,6 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
        raise


def _is_builtin_tool(tool: dict) -> bool:
    return "type" in tool and tool["type"] != "function"


def _transform_payload_for_responses(payload: dict) -> dict:
    updated_payload = payload.copy()
    if messages := updated_payload.pop("messages"):
        updated_payload["input"] = messages

    return updated_payload


class _FunctionCall(TypedDict):
    name: str

@@ -793,8 +784,7 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> Iterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        responses_payload = _transform_payload_for_responses(payload)
        context_manager = self.root_client.responses.create(**responses_payload)
        context_manager = self.root_client.responses.create(**payload)

        with context_manager as response:
            for chunk in response:
@@ -816,10 +806,7 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> AsyncIterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        responses_payload = _transform_payload_for_responses(payload)
        context_manager = await self.root_async_client.responses.create(
            **responses_payload
        )
        context_manager = await self.root_async_client.responses.create(**payload)

        async with context_manager as response:
            async for chunk in response:
@@ -926,13 +913,9 @@ class BaseChatOpenAI(BaseChatModel):
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            generation_info = {"headers": dict(raw_response.headers)}
        else:
            if "tools" in payload and any(
                _is_builtin_tool(tool) for tool in payload["tools"]
            ):
                responses_payload = _transform_payload_for_responses(payload)
                response = self.root_client.responses.create(**responses_payload)
                return self._create_chat_result_responses(response, generation_info)
        elif _use_response_api(payload):
            response = self.root_client.responses.create(**payload)
            return _construct_lc_result_from_response_api(response)
        else:
            response = self.client.create(**payload)
        return self._create_chat_result(response, generation_info)
@@ -948,11 +931,12 @@ class BaseChatOpenAI(BaseChatModel):
        if stop is not None:
            kwargs["stop"] = stop

        return {
            "messages": [_convert_message_to_dict(m) for m in messages],
            **self._default_params,
            **kwargs,
        }
        payload = {**self._default_params, **kwargs}
        if _use_response_api(payload):
            payload["input"] = _construct_response_api_input(messages)
        else:
            payload["messages"] = [_convert_message_to_dict(m) for m in messages]
        return payload

    def _create_chat_result(
        self,
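Reviewer sketch of the two payload shapes `_get_request_payload` now produces, assuming only the presence of a built-in tool flips the branch (field values invented):

```python
# Chat Completions path: no built-in tools, so the payload keeps "messages".
chat_completions_payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],
}

# Responses API path: a built-in tool such as web_search_preview makes
# _use_response_api(payload) true, so messages are converted into "input".
responses_payload = {
    "model": "gpt-4o-mini",
    "tools": [{"type": "web_search_preview"}],
    "input": [{"role": "user", "content": "hi"}],  # string content passes through unchanged
}
```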
@@ -1003,50 +987,6 @@ class BaseChatOpenAI(BaseChatModel):

        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_chat_result_responses(
        self, response: Response, generation_info: Optional[Dict] = None
    ) -> ChatResult:
        generations = []

        if error := response.error:
            raise ValueError(error)

        token_usage = response.usage.model_dump() if response.usage else {}
        generation_info = {}
        content_blocks = []
        for output in response.output:
            if output.type == "message":
                for content in output.content:
                    if content.type == "output_text":
                        block = {
                            "type": "text",
                            "text": content.text,
                            "annotations": [
                                annotation.model_dump()
                                for annotation in content.annotations
                            ],
                        }
                        content_blocks.append(block)
                usage_metadata = _create_usage_metadata_responses(token_usage)
                message = AIMessage(
                    content=content_blocks, # type: ignore[arg-type]
                    id=output.id,
                    usage_metadata=usage_metadata,
                )
                if output.status:
                    generation_info["status"] = output.status
                gen = ChatGeneration(message=message, generation_info=generation_info)
                generations.append(gen)
            else:
                tool_output = output.model_dump()
                if "tool_outputs" in generation_info:
                    generation_info["tool_outputs"].append(tool_output)
                else:
                    generation_info["tool_outputs"] = [tool_output]
        llm_output = {"model_name": response.model}

        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
@@ -1147,15 +1087,9 @@ class BaseChatOpenAI(BaseChatModel):
            raw_response = await self.async_client.with_raw_response.create(**payload)
            response = raw_response.parse()
            generation_info = {"headers": dict(raw_response.headers)}
        else:
            if "tools" in payload and any(
                _is_builtin_tool(tool) for tool in payload["tools"]
            ):
                responses_payload = _transform_payload_for_responses(payload)
                response = await self.root_async_client.responses.create(
                    **responses_payload
                )
                return self._create_chat_result_responses(response, generation_info)
        elif _use_response_api(payload):
            response = await self.root_async_client.responses.create(**payload)
            return _construct_lc_result_from_response_api(response)
        else:
            response = await self.async_client.create(**payload)
        return await run_in_executor(
@@ -2249,9 +2183,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Set default stream_options."""
        if "tools" in kwargs and any(
            _is_builtin_tool(tool) for tool in kwargs["tools"]
        ):
        if _use_response_api(kwargs):
            return super()._stream_responses(*args, **kwargs)
        else:
            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2269,9 +2201,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Set default stream_options."""
        if "tools" in kwargs and any(
            _is_builtin_tool(tool) for tool in kwargs["tools"]
        ):
        if _use_response_api(kwargs):
            async for chunk in super()._astream_responses(*args, **kwargs):
                yield chunk
        else:
@@ -2818,3 +2748,162 @@ def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )


def _is_builtin_tool(tool: dict) -> bool:
    return "type" in tool and tool["type"] != "function"


def _use_response_api(payload: dict) -> bool:
    return "tools" in payload and any(
        _is_builtin_tool(tool) for tool in payload["tools"]
    )


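A quick sketch of how the new routing predicates behave, assuming they are importable from `langchain_openai.chat_models.base`; the `get_weather` tool is an invented example, while `web_search_preview` comes from the tests below:

```python
# Assumed import path; these are private helpers added in this commit.
from langchain_openai.chat_models.base import _is_builtin_tool, _use_response_api

# Built-in (server-side) tools have a non-"function" type and route to the Responses API.
assert _is_builtin_tool({"type": "web_search_preview"}) is True
assert _is_builtin_tool({"type": "function", "function": {"name": "get_weather"}}) is False

# _use_response_api is true only when some tool in the payload is a built-in tool.
assert _use_response_api({"tools": [{"type": "web_search_preview"}]}) is True
assert _use_response_api({"tools": [{"type": "function", "function": {"name": "get_weather"}}]}) is False
assert _use_response_api({"messages": []}) is False  # no "tools" key at all
```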
def _construct_response_api_input(messages: list[BaseMessage]) -> list:
    input_ = []
    for lc_msg in messages:
        msg = _convert_message_to_dict(lc_msg)
        if msg["role"] == "tool":
            function_call_output = {
                "type": "function_call_output",
                "output": msg["content"],
                "call_id": msg["tool_call_id"],
            }
            input_.append(function_call_output)
        elif msg["role"] == "assistant":
            if tool_calls := msg.pop("tool_calls", None):
                if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
                    raise ValueError(...)
                function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
                for tool_call in tool_calls:
                    function_call = {
                        "type": "function_call",
                        "name": tool_call["name"],
                        "arguments": tool_call["arguments"],
                        "call_id": tool_call["id"],
                        "id": function_call_ids[tool_call["id"]],
                    }
                    input_.append(function_call)
            if msg.get("content"):
                input_.append(msg)
        elif msg["role"] == "user":
            if isinstance(msg["content"], list):
                for block in msg["content"]:
                    # chat api: {"type": "text", "text": "..."}
                    # response api: {"type": "input_text", "text": "..."}
                    if block["type"] == "text":
                        block["type"] = "input_text"
                    # chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}} # noqa: E501
                    # response api: {"type": "image_url", "image_url": "...", "detail": "...", "file_id": "..."} # noqa: E501
                    elif block["type"] == "image_url":
                        block["type"] = "input_image"
                        if isinstance(block.get("image_url"), dict):
                            image_url = block.pop("image_url")
                            block["image_url"] = image_url["url"]
                            if image_url.get("detail"):
                                block["detail"] = image_url["detail"]
                    else:
                        pass
                input_.append(msg)
            else:
                input_.append(msg)

    return input_


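Reviewer sketch of the user-content conversion performed above, following the in-code comments (URLs and text invented):

```python
# Chat Completions style text block ...
chat_text = {"type": "text", "text": "describe this image"}
# ... becomes a Responses API input block:
responses_text = {"type": "input_text", "text": "describe this image"}

# Image blocks are flattened from the nested Chat Completions form ...
chat_image = {
    "type": "image_url",
    "image_url": {"url": "https://example.com/cat.png", "detail": "low"},
}
# ... to the Responses API form:
responses_image = {
    "type": "input_image",
    "image_url": "https://example.com/cat.png",
    "detail": "low",
}
```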
def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
    """Construct ChatResponse from OpenAI Response API response."""
    if response.error:
        raise ValueError(response.error)

    response_metadata = {
        k: v
        for k, v in response.model_dump(exclude_none=True, mode="json").items()
        if k
        in (
            "created_at",
            "id",
            "incomplete_details",
            "metadata",
            "object",
            "status",
            "user",
            "model",
        )
    }
    # for compatibility with chat completion calls.
    response_metadata["model_name"] = response.get("model")
    if response.usage:
        usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
    else:
        usage_metadata = None

    content_blocks = []
    tool_calls = []
    invalid_tool_calls = []
    additional_kwargs: dict = {}
    msg_id = None
    for output in response.output:
        if output.type == "message":
            for content in output.content:
                if content.type == "output_text":
                    block = {
                        "type": "text",
                        "text": content.text,
                        "annotations": [
                            annotation.model_dump()
                            for annotation in content.annotations
                        ],
                    }
                    content_blocks.append(block)
                if content.type == "refusal":
                    additional_kwargs["refusal"] = content.refusal
            msg_id = output.id
        elif output.type == "function_call":
            try:
                args = json.loads(output.arguments, strict=False)
                error = None
            except JSONDecodeError as e:
                args = output.arguments
                error = str(e)
            if error is None:
                tool_call = {
                    "type": "tool_call",
                    "name": output.name,
                    "args": args,
                    "id": output.call_id,
                }
                tool_calls.append(tool_call)
            else:
                tool_call = {
                    "type": "invalid_tool_call",
                    "name": output.name,
                    "args": args,
                    "id": output.call_id,
                    "error": error,
                }
                invalid_tool_calls.append(tool_call)
            if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
                additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
            additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
        elif output.type == "reasoning":
            additional_kwargs["reasoning"] = output.model_dump(
                exclude_none=True, mode="json"
            )
        else:
            tool_output = output.model_dump(exclude_none=True, mode="json")
            if "tool_outputs" in additional_kwargs:
                additional_kwargs["tool_outputs"].append(tool_output)
            else:
                additional_kwargs["tool_outputs"] = [tool_output]
    message = AIMessage(
        content=content_blocks or None, # type: ignore[arg-type]
        id=msg_id,
        usage_metadata=usage_metadata,
        response_metadata=response_metadata,
        additional_kwargs=additional_kwargs,
        tool_calls=tool_calls,
        invalid_tool_calls=invalid_tool_calls,
    )
    return ChatResult(generations=[ChatGeneration(message=message)])
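A rough sketch of the message fields this constructor should populate for a response with one text output and one function call, using invented values (not taken from a real API response):

```python
# Hypothetical field values for illustration only.
expected_message_fields = {
    "content": [{"type": "text", "text": "It is sunny in Paris.", "annotations": []}],
    "tool_calls": [
        {"type": "tool_call", "name": "get_weather", "args": {"city": "Paris"}, "id": "call_def"}
    ],
    "response_metadata": {"id": "resp_jkl", "status": "completed"},
}
```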
@@ -1258,14 +1258,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
    assert response.usage_metadata["output_tokens"] > 0
    assert response.usage_metadata["total_tokens"] > 0
    assert response.response_metadata["model_name"]
    for tool_output in response.response_metadata["tool_outputs"]:
    for tool_output in response.additional_kwargs["tool_outputs"]:
        assert tool_output["id"]
        assert tool_output["status"]
        assert tool_output["type"]


def test_web_search() -> None:
    llm = ChatOpenAI(model="gpt-4o")
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
@@ -1283,9 +1283,16 @@ def test_web_search() -> None:
        full = chunk if full is None else full + chunk
    _check_response(full)

    llm.invoke(
        "what about a negative one",
        tools=[{"type": "web_search_preview"}],
        response_id=response.response_metadata["id"]
    )
    _check_response(response)


async def test_web_search_async() -> None:
    llm = ChatOpenAI(model="gpt-4o")
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = await llm.ainvoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
@@ -1307,7 +1314,7 @@ async def test_web_search_async() -> None:

def test_file_search() -> None:
    pytest.skip()  # TODO: set up infra
    llm = ChatOpenAI(model="gpt-4o")
    llm = ChatOpenAI(model="gpt-4o-mini")
    tool = {
        "type": "file_search",
        "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],