commit 1460539bd1
parent d4f2aad929
Chester Curme 2025-03-11 15:18:03 -04:00
2 changed files with 12 additions and 3 deletions


@@ -15,6 +15,7 @@ from io import BytesIO
 from math import ceil
 from operator import itemgetter
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -100,11 +101,13 @@ from langchain_core.utils.pydantic import (
     is_basemodel_subclass,
 )
 from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
-from openai.types.responses import Response
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self
 
+if TYPE_CHECKING:
+    from openai.types.responses import Response
+
 logger = logging.getLogger(__name__)
 
 # This SSL context is equivelent to the default `verify=True`.
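The `Response` import is moved under `TYPE_CHECKING`, so it is only evaluated by static type checkers and never at runtime. A minimal sketch of that pattern, assuming a hypothetical helper (`summarize` is illustrative, not code from this commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers (mypy, pyright); at runtime this
    # branch is skipped, so importing the module adds no startup cost.
    from openai.types.responses import Response


def summarize(response: "Response") -> str:
    # The quoted annotation keeps the signature valid at runtime even
    # though Response is never bound outside of type checking.
    return f"{len(response.output)} output item(s)"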
@@ -921,11 +924,15 @@ class BaseChatOpenAI(BaseChatModel):
         if error := response.error:
             raise ValueError(error)
-        token_usage = response.usage.model_dump()
+        token_usage = response.usage.model_dump() if response.usage else {}
         generation_info = {}
         for output in response.output:
             if output.type == "message":
-                joined = "".join(content.text for content in output.content)
+                joined = "".join(
+                    content.text
+                    for content in output.content
+                    if content.type == "output_text"
+                )
                 usage_metadata = _create_usage_metadata_responses(token_usage)
                 message = AIMessage(
                     content=joined, id=output.id, usage_metadata=usage_metadata
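Two guards are introduced here: `response.usage` can be None, so the token counts fall back to an empty dict, and message content parts are filtered to the "output_text" type before their `.text` fields are joined, since other part types do not carry text. A hedged sketch of the same pattern outside the class (the helper name is illustrative, not part of langchain_openai):

def collect_text_and_usage(response):
    # usage may be absent on some responses; default to {} so the
    # downstream usage-metadata builder still receives a dict.
    token_usage = response.usage.model_dump() if response.usage else {}

    texts = []
    for output in response.output:
        if output.type == "message":
            # Content parts can mix types; only "output_text" parts have
            # a .text attribute, so filter before joining.
            texts.append(
                "".join(
                    part.text
                    for part in output.content
                    if part.type == "output_text"
                )
            )
    return "".join(texts), token_usage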


@@ -1238,6 +1238,7 @@ def test_web_search() -> None:
     )
     assert isinstance(response, AIMessage)
     assert response.content
+    assert response.usage_metadata
     assert response.usage_metadata["input_tokens"] > 0
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0
@@ -1253,6 +1254,7 @@ async def test_web_search_async() -> None:
     )
     assert isinstance(response, AIMessage)
     assert response.content
+    assert response.usage_metadata
     assert response.usage_metadata["input_tokens"] > 0
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0