Commit d662b095ca (parent 9808eb2149) — mirror of https://github.com/hwchase17/langchain.git
Commit message: cr
@@ -12,6 +12,7 @@ import sys
 import warnings
 from functools import partial
 from io import BytesIO
+from json import JSONDecodeError
 from math import ceil
 from operator import itemgetter
 from typing import (
@@ -114,6 +115,8 @@ logger = logging.getLogger(__name__)
 # https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
 global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 
+_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
+
 
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
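Review note: `_FUNCTION_CALL_IDS_MAP_KEY` stores, per assistant message, a map from each tool call's `call_id` to the Responses API item `id`, so that `_construct_response_api_input` (added at the bottom of this diff) can restore both ids when the message is replayed. A minimal sketch of that round trip; the ids are invented for illustration:

    # Sketch only; not part of the diff.
    _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"

    # Stored on the AIMessage when a Responses API result is parsed:
    additional_kwargs = {_FUNCTION_CALL_IDS_MAP_KEY: {"call_abc": "fc_123"}}

    # Recovered when converting the message back into Responses API input:
    function_call_ids = additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
    function_call = {
        "type": "function_call",
        "call_id": "call_abc",
        "id": function_call_ids["call_abc"],  # -> "fc_123"
    }
    assert function_call["id"] == "fc_123"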
@@ -452,18 +455,6 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
     raise
 
 
-def _is_builtin_tool(tool: dict) -> bool:
-    return "type" in tool and tool["type"] != "function"
-
-
-def _transform_payload_for_responses(payload: dict) -> dict:
-    updated_payload = payload.copy()
-    if messages := updated_payload.pop("messages"):
-        updated_payload["input"] = messages
-
-    return updated_payload
-
-
 class _FunctionCall(TypedDict):
     name: str
 
@@ -793,8 +784,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = self.root_client.responses.create(**responses_payload)
+        context_manager = self.root_client.responses.create(**payload)
 
         with context_manager as response:
             for chunk in response:
@@ -816,10 +806,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = await self.root_async_client.responses.create(
-            **responses_payload
-        )
+        context_manager = await self.root_async_client.responses.create(**payload)
 
         async with context_manager as response:
             async for chunk in response:
@@ -926,15 +913,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = self.root_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = self.root_client.responses.create(**responses_payload)
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = self.client.create(**payload)
+            response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
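Review note: `_generate` now has a three-way dispatch: raw responses when headers are requested, the Responses API when the payload contains a built-in tool, and plain chat completions otherwise. A self-contained sketch of the predicate (copied from the helpers added at the bottom of this diff) and how it routes:

    def _is_builtin_tool(tool: dict) -> bool:
        return "type" in tool and tool["type"] != "function"

    def _use_response_api(payload: dict) -> bool:
        return "tools" in payload and any(
            _is_builtin_tool(tool) for tool in payload["tools"]
        )

    # OpenAI-hosted (built-in) tools route to the Responses API:
    assert _use_response_api({"tools": [{"type": "web_search_preview"}]})
    # Ordinary function tools and tool-less payloads stay on chat completions:
    assert not _use_response_api({"tools": [{"type": "function", "function": {}}]})
    assert not _use_response_api({"model": "gpt-4o-mini"})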
@@ -948,11 +931,12 @@ class BaseChatOpenAI(BaseChatModel):
         if stop is not None:
            kwargs["stop"] = stop
 
-        return {
-            "messages": [_convert_message_to_dict(m) for m in messages],
-            **self._default_params,
-            **kwargs,
-        }
+        payload = {**self._default_params, **kwargs}
+        if _use_response_api(payload):
+            payload["input"] = _construct_response_api_input(messages)
+        else:
+            payload["messages"] = [_convert_message_to_dict(m) for m in messages]
+        return payload
 
     def _create_chat_result(
         self,
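Review note: the payload builder now chooses the wire key by API: Responses API requests carry the converted messages under `input`, chat completions keep `messages`. A sketch of the two resulting shapes, with stand-in values and message conversion elided:

    # Chat completions branch:
    chat_payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "hi"}],
    }
    # Responses API branch (triggered by a built-in tool in the payload):
    responses_payload = {
        "model": "gpt-4o-mini",
        "tools": [{"type": "web_search_preview"}],
        "input": [{"role": "user", "content": "hi"}],
    }
    assert "messages" not in responses_payload and "input" not in chat_payload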
@@ -1003,50 +987,6 @@ class BaseChatOpenAI(BaseChatModel):
 
         return ChatResult(generations=generations, llm_output=llm_output)
 
-    def _create_chat_result_responses(
-        self, response: Response, generation_info: Optional[Dict] = None
-    ) -> ChatResult:
-        generations = []
-
-        if error := response.error:
-            raise ValueError(error)
-
-        token_usage = response.usage.model_dump() if response.usage else {}
-        generation_info = {}
-        content_blocks = []
-        for output in response.output:
-            if output.type == "message":
-                for content in output.content:
-                    if content.type == "output_text":
-                        block = {
-                            "type": "text",
-                            "text": content.text,
-                            "annotations": [
-                                annotation.model_dump()
-                                for annotation in content.annotations
-                            ],
-                        }
-                        content_blocks.append(block)
-                usage_metadata = _create_usage_metadata_responses(token_usage)
-                message = AIMessage(
-                    content=content_blocks,  # type: ignore[arg-type]
-                    id=output.id,
-                    usage_metadata=usage_metadata,
-                )
-                if output.status:
-                    generation_info["status"] = output.status
-                gen = ChatGeneration(message=message, generation_info=generation_info)
-                generations.append(gen)
-            else:
-                tool_output = output.model_dump()
-                if "tool_outputs" in generation_info:
-                    generation_info["tool_outputs"].append(tool_output)
-                else:
-                    generation_info["tool_outputs"] = [tool_output]
-        llm_output = {"model_name": response.model}
-
-        return ChatResult(generations=generations, llm_output=llm_output)
-
     async def _astream(
         self,
         messages: List[BaseMessage],
@@ -1147,17 +1087,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = await self.root_async_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = await self.root_async_client.responses.create(
-                    **responses_payload
-                )
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = await self.async_client.create(**payload)
+            response = await self.async_client.create(**payload)
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
@@ -2249,9 +2183,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             return super()._stream_responses(*args, **kwargs)
         else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2269,9 +2201,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
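Review note: the same predicate gates streaming, so callers keep the ordinary streaming interface. A hedged usage sketch; it assumes OPENAI_API_KEY is set and the account has access to the built-in web-search tool, and makes a network call:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")
    # With a built-in tool present, this streams via _stream_responses under the hood:
    for chunk in llm.stream(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    ):
        print(chunk.content)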
@@ -2818,3 +2748,162 @@ def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
         **{k: v for k, v in output_token_details.items() if v is not None}
     ),
 )
+
+
+def _is_builtin_tool(tool: dict) -> bool:
+    return "type" in tool and tool["type"] != "function"
+
+
+def _use_response_api(payload: dict) -> bool:
+    return "tools" in payload and any(
+        _is_builtin_tool(tool) for tool in payload["tools"]
+    )
+
+
+def _construct_response_api_input(messages: list[BaseMessage]) -> list:
+    input_ = []
+    for lc_msg in messages:
+        msg = _convert_message_to_dict(lc_msg)
+        if msg["role"] == "tool":
+            function_call_output = {
+                "type": "function_call_output",
+                "output": msg["content"],
+                "call_id": msg["tool_call_id"],
+            }
+            input_.append(function_call_output)
+        elif msg["role"] == "assistant":
+            if tool_calls := msg.pop("tool_calls", None):
+                if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
+                    raise ValueError(...)
+                function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
+                for tool_call in tool_calls:
+                    function_call = {
+                        "type": "function_call",
+                        "name": tool_call["function"]["name"],
+                        "arguments": tool_call["function"]["arguments"],
+                        "call_id": tool_call["id"],
+                        "id": function_call_ids[tool_call["id"]],
+                    }
+                    input_.append(function_call)
+            if msg.get("content"):
+                input_.append(msg)
+        elif msg["role"] == "user":
+            if isinstance(msg["content"], list):
+                for block in msg["content"]:
+                    # chat api: {"type": "text", "text": "..."}
+                    # response api: {"type": "input_text", "text": "..."}
+                    if block["type"] == "text":
+                        block["type"] = "input_text"
+                    # chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}}  # noqa: E501
+                    # response api: {"type": "input_image", "image_url": "...", "detail": "...", "file_id": "..."}  # noqa: E501
+                    elif block["type"] == "image_url":
+                        block["type"] = "input_image"
+                        if isinstance(block.get("image_url"), dict):
+                            image_url = block.pop("image_url")
+                            block["image_url"] = image_url["url"]
+                            if image_url.get("detail"):
+                                block["detail"] = image_url["detail"]
+                    else:
+                        pass
+            input_.append(msg)
+        else:
+            input_.append(msg)
+
+    return input_
+
+
+def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
+    """Construct ChatResult from OpenAI Response API response."""
+    if response.error:
+        raise ValueError(response.error)
+
+    response_metadata = {
+        k: v
+        for k, v in response.model_dump(exclude_none=True, mode="json").items()
+        if k
+        in (
+            "created_at",
+            "id",
+            "incomplete_details",
+            "metadata",
+            "object",
+            "status",
+            "user",
+            "model",
+        )
+    }
+    # for compatibility with chat completion calls.
+    response_metadata["model_name"] = response_metadata.get("model")
+    if response.usage:
+        usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
+    else:
+        usage_metadata = None
+
+    content_blocks = []
+    tool_calls = []
+    invalid_tool_calls = []
+    additional_kwargs: dict = {}
+    msg_id = None
+    for output in response.output:
+        if output.type == "message":
+            for content in output.content:
+                if content.type == "output_text":
+                    block = {
+                        "type": "text",
+                        "text": content.text,
+                        "annotations": [
+                            annotation.model_dump()
+                            for annotation in content.annotations
+                        ],
+                    }
+                    content_blocks.append(block)
+                if content.type == "refusal":
+                    additional_kwargs["refusal"] = content.refusal
+            msg_id = output.id
+        elif output.type == "function_call":
+            try:
+                args = json.loads(output.arguments, strict=False)
+                error = None
+            except JSONDecodeError as e:
+                args = output.arguments
+                error = str(e)
+            if error is None:
+                tool_call = {
+                    "type": "tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                }
+                tool_calls.append(tool_call)
+            else:
+                tool_call = {
+                    "type": "invalid_tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                    "error": error,
+                }
+                invalid_tool_calls.append(tool_call)
+            if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
+                additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
+            additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
+        elif output.type == "reasoning":
+            additional_kwargs["reasoning"] = output.model_dump(
+                exclude_none=True, mode="json"
+            )
+        else:
+            tool_output = output.model_dump(exclude_none=True, mode="json")
+            if "tool_outputs" in additional_kwargs:
+                additional_kwargs["tool_outputs"].append(tool_output)
+            else:
+                additional_kwargs["tool_outputs"] = [tool_output]
+    message = AIMessage(
+        content=content_blocks or None,  # type: ignore[arg-type]
+        id=msg_id,
+        usage_metadata=usage_metadata,
+        response_metadata=response_metadata,
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
+    )
+    return ChatResult(generations=[ChatGeneration(message=message)])
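Review note: `_construct_response_api_input` rewrites chat-completions user content blocks into Responses API blocks in place. A standalone, dict-level re-implementation of just that rewriting, for illustration (it returns new dicts rather than mutating):

    def convert_user_block(block: dict) -> dict:
        # Mirrors the block rewriting in _construct_response_api_input.
        if block["type"] == "text":
            return {"type": "input_text", "text": block["text"]}
        if block["type"] == "image_url":
            image_url = block["image_url"]
            out = {"type": "input_image", "image_url": image_url["url"]}
            if image_url.get("detail"):
                out["detail"] = image_url["detail"]
            return out
        return block

    assert convert_user_block({"type": "text", "text": "hi"}) == {
        "type": "input_text",
        "text": "hi",
    }
    assert convert_user_block(
        {"type": "image_url", "image_url": {"url": "https://x/a.png", "detail": "low"}}
    ) == {"type": "input_image", "image_url": "https://x/a.png", "detail": "low"}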
@@ -1258,14 +1258,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
-    for tool_output in response.response_metadata["tool_outputs"]:
+    for tool_output in response.additional_kwargs["tool_outputs"]:
         assert tool_output["id"]
         assert tool_output["status"]
         assert tool_output["type"]
 
 
 def test_web_search() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
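Review note: tool outputs moved from `response_metadata` to `additional_kwargs` on the returned message, matching `_construct_lc_result_from_response_api` above. A small sketch of the new access pattern, with invented field values:

    from langchain_core.messages import AIMessage

    msg = AIMessage(
        content="",
        additional_kwargs={
            "tool_outputs": [
                {"id": "ws_123", "status": "completed", "type": "web_search_call"}
            ]
        },
    )
    for tool_output in msg.additional_kwargs["tool_outputs"]:
        assert tool_output["id"] and tool_output["status"] and tool_output["type"]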
@@ -1283,9 +1283,16 @@ def test_web_search() -> None:
         full = chunk if full is None else full + chunk
     _check_response(full)
 
+    response = llm.invoke(
+        "what about a negative one",
+        tools=[{"type": "web_search_preview"}],
+        response_id=response.response_metadata["id"],
+    )
+    _check_response(response)
+
 
 async def test_web_search_async() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = await llm.ainvoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
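Review note: the added assertions exercise OpenAI's server-side conversation state: passing the previous response's id via `response_id` continues the thread without resending history. A usage sketch mirroring the test (requires OPENAI_API_KEY; makes network calls):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")
    first = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )
    followup = llm.invoke(
        "what about a negative one",
        tools=[{"type": "web_search_preview"}],
        response_id=first.response_metadata["id"],
    )
    print(followup.content)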
@@ -1307,7 +1314,7 @@ async def test_web_search_async() -> None:
 
 def test_file_search() -> None:
     pytest.skip()  # TODO: set up infra
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     tool = {
         "type": "file_search",
         "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],