Chester Curme 2025-03-11 15:08:32 -04:00
parent 81d1653a30
commit aee80f0e48
3 changed files with 79 additions and 7 deletions

View File

@@ -100,6 +100,7 @@ from langchain_core.utils.pydantic import (
     is_basemodel_subclass,
 )
 from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
+from openai.types.responses import Response
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self
@@ -405,6 +406,22 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
     raise


+def _is_builtin_tool(tool: dict) -> bool:
+    return set(tool.keys()) == {"type"}
+
+
+def _transform_payload_for_responses(payload: dict) -> dict:
+    updated_payload = payload.copy()
+    # Default to None so a payload without "messages" doesn't raise KeyError.
+    if messages := updated_payload.pop("messages", None):
+        last_user_message = next(
+            (m for m in reversed(messages) if m.get("role") == "user"), None
+        )
+        if last_user_message:
+            updated_payload["input"] = last_user_message["content"]
+    return updated_payload
+
+
 class _FunctionCall(TypedDict):
     name: str
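
A quick illustration of the two helpers above (sample payload and values are mine, not from the commit):

    payload = {
        "model": "gpt-4o",
        "tools": [{"type": "web_search_preview"}],
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "What's a good news story from today?"},
        ],
    }

    _is_builtin_tool({"type": "web_search_preview"})                   # True: only a "type" key
    _is_builtin_tool({"type": "function", "function": {"name": "f"}})  # False: extra keys

    _transform_payload_for_responses(payload)
    # {"model": "gpt-4o",
    #  "tools": [{"type": "web_search_preview"}],
    #  "input": "What's a good news story from today?"}
    # Note: only the last user message survives; prior turns are dropped.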
@@ -654,7 +671,7 @@ class BaseChatOpenAI(BaseChatModel):
             if output is None:
                 # Happens in streaming
                 continue
-            token_usage = output["token_usage"]
+            token_usage = output.get("token_usage")
             if token_usage is not None:
                 for k, v in token_usage.items():
                     if v is None:
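
The switch to `.get` matters because the Responses path below builds its `llm_output` without a `token_usage` key; a minimal sketch:

    output = {"model_name": "gpt-4o"}  # shape produced by _create_chat_result_responses
    output["token_usage"]              # KeyError
    output.get("token_usage")          # None, so the usage-merging loop is skipped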
@@ -820,7 +837,14 @@ class BaseChatOpenAI(BaseChatModel):
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
-            response = self.client.create(**payload)
+            if "tools" in payload and any(
+                _is_builtin_tool(tool) for tool in payload["tools"]
+            ):
+                responses_payload = _transform_payload_for_responses(payload)
+                response = self.root_client.responses.create(**responses_payload)
+                return self._create_chat_result_responses(response, generation_info)
+            else:
+                response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)

     def _get_request_payload(
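
A sketch of how the new routing would be hit from user code (model name and tool spec are illustrative; `web_search_preview` is one of the Responses API built-in tools):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o")
    # A tool dict whose only key is "type" makes _is_builtin_tool return True,
    # so _generate calls self.root_client.responses.create instead of
    # self.client.create (Chat Completions).
    bound = llm.bind(tools=[{"type": "web_search_preview"}])
    result = bound.invoke("What's a positive news story from today?")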
@@ -889,6 +913,31 @@
         return ChatResult(generations=generations, llm_output=llm_output)

+    def _create_chat_result_responses(
+        self, response: Response, generation_info: Optional[Dict] = None
+    ) -> ChatResult:
+        generations = []
+        if error := response.error:
+            raise ValueError(error)
+        # usage can be None on some responses; fall back to an empty dict.
+        token_usage = response.usage.model_dump() if response.usage else {}
+        # Merge into any generation_info passed in (e.g. response headers)
+        # instead of discarding it.
+        generation_info = generation_info or {}
+        for output in response.output:
+            if output.type == "message":
+                # Skip non-text content (e.g. refusals), which has no .text.
+                joined = "".join(
+                    content.text
+                    for content in output.content
+                    if content.type == "output_text"
+                )
+                usage_metadata = _create_usage_metadata_responses(token_usage)
+                message = AIMessage(
+                    content=joined, id=output.id, usage_metadata=usage_metadata
+                )
+                if output.status:
+                    generation_info["status"] = output.status
+                gen = ChatGeneration(message=message, generation_info=generation_info)
+                generations.append(gen)
+        llm_output = {"model_name": response.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
     async def _astream(
         self,
         messages: List[BaseMessage],
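
For reference, a rough sketch of the Responses object shape this method consumes (field and type names per openai-python 1.66; values illustrative):

    # response.error   -> None on success, otherwise an error object
    # response.model   -> "gpt-4o-2024-08-06"
    # response.usage   -> ResponseUsage(input_tokens=36, output_tokens=87,
    #                                   total_tokens=123, ...)
    # response.output  -> [ResponseOutputMessage(type="message", id="msg_...",
    #                         status="completed",
    #                         content=[ResponseOutputText(type="output_text",
    #                                                     text="...")])]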
@@ -2617,3 +2666,26 @@ def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
             **{k: v for k, v in output_token_details.items() if v is not None}
         ),
     )
+
+
+def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
+    input_tokens = oai_token_usage.get("input_tokens", 0)
+    output_tokens = oai_token_usage.get("output_tokens", 0)
+    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
+    output_token_details: dict = {
+        # Responses usage has no completion_tokens_details; this stays None
+        # and is filtered out below.
+        "audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
+            "audio_tokens"
+        ),
+        # The Responses API spells this key "output_tokens_details".
+        "reasoning": (oai_token_usage.get("output_tokens_details") or {}).get(
+            "reasoning_tokens"
+        ),
+    }
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+        output_token_details=OutputTokenDetails(
+            **{k: v for k, v in output_token_details.items() if v is not None}
+        ),
+    )
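
A worked example of the usage mapping (the input dict mirrors `ResponseUsage.model_dump()`; numbers are made up). Since `UsageMetadata` and `OutputTokenDetails` are TypedDicts, the result is a plain dict:

    usage = {
        "input_tokens": 36,
        "output_tokens": 87,
        "total_tokens": 123,
        "output_tokens_details": {"reasoning_tokens": 64},
    }
    _create_usage_metadata_responses(usage)
    # {"input_tokens": 36, "output_tokens": 87, "total_tokens": 123,
    #  "output_token_details": {"reasoning": 64}}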

View File

@@ -8,7 +8,7 @@ license = { text = "MIT" }
 requires-python = "<4.0,>=3.9"
 dependencies = [
     "langchain-core<1.0.0,>=0.3.43",
-    "openai<2.0.0,>=1.58.1",
+    "openai<2.0.0,>=1.66.0",
     "tiktoken<1,>=0.7",
 ]
 name = "langchain-openai"
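
(The floor moves from 1.58.1 to 1.66.0 because, as far as I can tell from the openai-python changelog, 1.66.0 is the first release that ships the Responses API types imported above, i.e. `openai.types.responses`.)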

View File

@@ -566,7 +566,7 @@ typing = [
 [package.metadata]
 requires-dist = [
     { name = "langchain-core", editable = "../../core" },
-    { name = "openai", specifier = ">=1.58.1,<2.0.0" },
+    { name = "openai", specifier = ">=1.66.0,<2.0.0" },
     { name = "tiktoken", specifier = ">=0.7,<1" },
 ]
@@ -751,7 +751,7 @@ wheels = [
 [[package]]
 name = "openai"
-version = "1.61.1"
+version = "1.66.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio" },
@@ -763,9 +763,9 @@ dependencies = [
     { name = "tqdm" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/d9/cf/61e71ce64cf0a38f029da0f9a5f10c9fa0e69a7a977b537126dac50adfea/openai-1.61.1.tar.gz", hash = "sha256:ce1851507218209961f89f3520e06726c0aa7d0512386f0f977e3ac3e4f2472e", size = 350784 }
+sdist = { url = "https://files.pythonhosted.org/packages/84/c5/3c422ca3ccc81c063955e7c20739d7f8f37fea0af865c4a60c81e6225e14/openai-1.66.0.tar.gz", hash = "sha256:8a9e672bc6eadec60a962f0b40d7d1c09050010179c919ed65322e433e2d1025", size = 396819 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/9a/b6/2e2a011b2dc27a6711376808b4cd8c922c476ea0f1420b39892117fa8563/openai-1.61.1-py3-none-any.whl", hash = "sha256:72b0826240ce26026ac2cd17951691f046e5be82ad122d20a8e1b30ca18bd11e", size = 463126 },
+    { url = "https://files.pythonhosted.org/packages/d7/f1/d52960dac9519c9de64593460826a0fe2e19159389ec97ecf3e931d2e6a3/openai-1.66.0-py3-none-any.whl", hash = "sha256:43e4a3c0c066cc5809be4e6aac456a3ebc4ec1848226ef9d1340859ac130d45a", size = 566389 },
 ]

 [[package]]