From 58e72febeb6cf91f92c05227e50ea8515ee73c94 Mon Sep 17 00:00:00 2001
From: Hyman
Date: Sat, 24 Aug 2024 07:59:14 +0800
Subject: [PATCH] openai: compatible with other LLM usage metadata (#24500)

- [ ] **PR message**:
    - **Description:** Compatible with usage metadata from other LLMs (e.g. deepseek-chat, glm-4)
    - **Issue:** N/A
    - **Dependencies:** no new dependencies added
- [ ] **Add tests and docs**: libs/partners/openai/tests/unit_tests/chat_models/test_base.py

```shell
cd libs/partners/openai
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_stream
```

---------

Co-authored-by: hyman
Co-authored-by: Erick Friis
---
 .../langchain_openai/chat_models/base.py      | 150 +++++-----
 .../tests/unit_tests/chat_models/test_base.py | 281 ++++++++++++++++++
 2 files changed, 351 insertions(+), 80 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index ef7a2559b6b..ca54ce32e6e 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -284,6 +284,57 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
+def _convert_chunk_to_generation_chunk(
+    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+) -> Optional[ChatGenerationChunk]:
+    token_usage = chunk.get("usage")
+    choices = chunk.get("choices", [])
+    usage_metadata: Optional[UsageMetadata] = (
+        UsageMetadata(
+            input_tokens=token_usage.get("prompt_tokens", 0),
+            output_tokens=token_usage.get("completion_tokens", 0),
+            total_tokens=token_usage.get("total_tokens", 0),
+        )
+        if token_usage
+        else None
+    )
+
+    if len(choices) == 0:
+        # logprobs is implicitly None
+        generation_chunk = ChatGenerationChunk(
+            message=default_chunk_class(content="", usage_metadata=usage_metadata)
+        )
+        return generation_chunk
+
+    choice = choices[0]
+    if choice["delta"] is None:
+        return None
+
+    message_chunk = _convert_delta_to_message_chunk(
+        choice["delta"], default_chunk_class
+    )
+    generation_info = {**base_generation_info} if base_generation_info else {}
+
+    if finish_reason := choice.get("finish_reason"):
+        generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+    logprobs = choice.get("logprobs")
+    if logprobs:
+        generation_info["logprobs"] = logprobs
+
+    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+        message_chunk.usage_metadata = usage_metadata
+
+    generation_chunk = ChatGenerationChunk(
+        message=message_chunk, generation_info=generation_info or None
+    )
+    return generation_chunk
+
+
 class _FunctionCall(TypedDict):
     name: str
 
@@ -561,43 +612,15 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                if len(chunk["choices"]) == 0:
-                    if token_usage := chunk.get("usage"):
-                        usage_metadata = UsageMetadata(
-                            input_tokens=token_usage.get("prompt_tokens", 0),
-                            output_tokens=token_usage.get("completion_tokens", 0),
-                            total_tokens=token_usage.get("total_tokens", 0),
-                        )
-                        generation_chunk = ChatGenerationChunk(
-                            message=default_chunk_class(  # type: ignore[call-arg]
-                                content="", usage_metadata=usage_metadata
-                            )
-                        )
-                        logprobs = None
-                    else:
-                        continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                    logprobs = choice.get("logprobs")
-                    if logprobs:
-                        generation_info["logprobs"] = logprobs
-                    default_chunk_class = message_chunk.__class__
-                    generation_chunk = ChatGenerationChunk(
-                        message=message_chunk, generation_info=generation_info or None
-                    )
+                generation_chunk = _convert_chunk_to_generation_chunk(
+                    chunk,
+                    default_chunk_class,
+                    base_generation_info if is_first_chunk else {},
+                )
+                if generation_chunk is None:
+                    continue
+                default_chunk_class = generation_chunk.message.__class__
+                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                 if run_manager:
                     run_manager.on_llm_new_token(
                         generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
@@ -744,51 +767,18 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                if len(chunk["choices"]) == 0:
-                    if token_usage := chunk.get("usage"):
-                        usage_metadata = UsageMetadata(
-                            input_tokens=token_usage.get("prompt_tokens", 0),
-                            output_tokens=token_usage.get("completion_tokens", 0),
-                            total_tokens=token_usage.get("total_tokens", 0),
-                        )
-                        generation_chunk = ChatGenerationChunk(
-                            message=default_chunk_class(  # type: ignore[call-arg]
-                                content="", usage_metadata=usage_metadata
-                            )
-                        )
-                        logprobs = None
-                    else:
-                        continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = await run_in_executor(
-                        None,
-                        _convert_delta_to_message_chunk,
-                        choice["delta"],
-                        default_chunk_class,
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                    logprobs = choice.get("logprobs")
-                    if logprobs:
-                        generation_info["logprobs"] = logprobs
-                    default_chunk_class = message_chunk.__class__
-                    generation_chunk = ChatGenerationChunk(
-                        message=message_chunk, generation_info=generation_info or None
-                    )
+                generation_chunk = _convert_chunk_to_generation_chunk(
+                    chunk,
+                    default_chunk_class,
+                    base_generation_info if is_first_chunk else {},
+                )
+                if generation_chunk is None:
+                    continue
+                default_chunk_class = generation_chunk.message.__class__
+                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=generation_chunk.text,
-                        chunk=generation_chunk,
-                        logprobs=logprobs,
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
                 is_first_chunk =
False yield generation_chunk diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 4f635ad7aaa..4041718368c 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1,12 +1,14 @@ """Test OpenAI Chat API wrapper.""" import json +from types import TracebackType from typing import Any, Dict, List, Literal, Optional, Type, Union from unittest.mock import AsyncMock, MagicMock, patch import pytest from langchain_core.messages import ( AIMessage, + AIMessageChunk, FunctionMessage, HumanMessage, InvalidToolCall, @@ -14,6 +16,7 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) +from langchain_core.messages.ai import UsageMetadata from langchain_core.pydantic_v1 import BaseModel from langchain_openai import ChatOpenAI @@ -172,6 +175,284 @@ def test__convert_dict_to_message_tool_call() -> None: assert reverted_message_dict == message +class MockAsyncContextManager: + def __init__(self, chunk_list: list): + self.current_chunk = 0 + self.chunk_list = chunk_list + self.chunk_num = len(chunk_list) + + async def __aenter__(self) -> "MockAsyncContextManager": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + pass + + def __aiter__(self) -> "MockAsyncContextManager": + return self + + async def __anext__(self) -> dict: + if self.current_chunk < self.chunk_num: + chunk = self.chunk_list[self.current_chunk] + self.current_chunk += 1 + return chunk + else: + raise StopAsyncIteration + + +class MockSyncContextManager: + def __init__(self, chunk_list: list): + self.current_chunk = 0 + self.chunk_list = chunk_list + self.chunk_num = len(chunk_list) + + def __enter__(self) -> "MockSyncContextManager": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + pass + + def __iter__(self) -> "MockSyncContextManager": + return self + + def __next__(self) -> dict: + if self.current_chunk < self.chunk_num: + chunk = self.chunk_list[self.current_chunk] + self.current_chunk += 1 + return chunk + else: + raise StopIteration + + +GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]} 
+{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"。"}}]} +{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"finish_reason":"stop","delta":{"role":"assistant","content":""}}],"usage":{"prompt_tokens":13,"completion_tokens":10,"total_tokens":23}} +[DONE]""" # noqa: E501 + + +@pytest.fixture +def mock_glm4_completion() -> list: + list_chunk_data = GLM4_STREAM_META.split("\n") + result_list = [] + for msg in list_chunk_data: + if msg != "[DONE]": + result_list.append(json.loads(msg)) + + return result_list + + +async def test_glm4_astream(mock_glm4_completion: list) -> None: + llm_name = "glm-4" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = AsyncMock() + + async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager: + return MockAsyncContextManager(mock_glm4_completion) + + mock_client.create = mock_create + usage_chunk = mock_glm4_completion[-1] + + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "async_client", mock_client): + async for chunk in llm.astream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + +def test_glm4_stream(mock_glm4_completion: list) -> None: + llm_name = "glm-4" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(mock_glm4_completion) + + mock_client.create = mock_create + usage_chunk = mock_glm4_completion[-1] + + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "client", mock_client): + for chunk in llm.stream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + +DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"index":0,"delta":{"content":"","role":"assistant"},"finish_reason":null,"logprobs":null}],"created":1721630271,"model":"deepseek-chat","system_fingerprint":"fp_7e0991cad4","object":"chat.completion.chunk","usage":null} +{"choices":[{"delta":{"content":"我是","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} 
+{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"由","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"求","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"索","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"公司","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"开发的","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"智能","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} 
+{"choices":[{"delta":{"content":"助手","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"。","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null} +{"choices":[{"delta":{"content":"","role":null},"finish_reason":"stop","index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":{"completion_tokens":15,"prompt_tokens":11,"total_tokens":26}} +[DONE]""" # noqa: E501 + + +@pytest.fixture +def mock_deepseek_completion() -> List[Dict]: + list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n") + result_list = [] + for msg in list_chunk_data: + if msg != "[DONE]": + result_list.append(json.loads(msg)) + + return result_list + + +async def test_deepseek_astream(mock_deepseek_completion: list) -> None: + llm_name = "deepseek-chat" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = AsyncMock() + + async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager: + return MockAsyncContextManager(mock_deepseek_completion) + + mock_client.create = mock_create + usage_chunk = mock_deepseek_completion[-1] + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "async_client", mock_client): + async for chunk in llm.astream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + +def test_deepseek_stream(mock_deepseek_completion: list) -> None: + llm_name = "deepseek-chat" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(mock_deepseek_completion) + + mock_client.create = mock_create + usage_chunk = mock_deepseek_completion[-1] + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "client", mock_client): + for chunk in llm.stream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + +OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null} 
+{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"我是"},"logprobs":null,"finish_reason":null}],"usage":null} +{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"助手"},"logprobs":null,"finish_reason":null}],"usage":null} +{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"。"},"logprobs":null,"finish_reason":null}],"usage":null} +{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} +{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[],"usage":{"prompt_tokens":14,"completion_tokens":3,"total_tokens":17}} +[DONE]""" # noqa: E501 + + +@pytest.fixture +def mock_openai_completion() -> List[Dict]: + list_chunk_data = OPENAI_STREAM_DATA.split("\n") + result_list = [] + for msg in list_chunk_data: + if msg != "[DONE]": + result_list.append(json.loads(msg)) + + return result_list + + +async def test_openai_astream(mock_openai_completion: list) -> None: + llm_name = "gpt-4o" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = AsyncMock() + + async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager: + return MockAsyncContextManager(mock_openai_completion) + + mock_client.create = mock_create + usage_chunk = mock_openai_completion[-1] + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "async_client", mock_client): + async for chunk in llm.astream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + +def test_openai_stream(mock_openai_completion: list) -> None: + llm_name = "gpt-4o" + llm = ChatOpenAI(model=llm_name, stream_usage=True) + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(mock_openai_completion) + + mock_client.create = mock_create + usage_chunk = mock_openai_completion[-1] + usage_metadata: Optional[UsageMetadata] = None + with patch.object(llm, "client", mock_client): + for chunk in llm.stream("你的名字叫什么?只回答名字"): + assert isinstance(chunk, AIMessageChunk) + if chunk.usage_metadata is not None: + usage_metadata = chunk.usage_metadata + + assert usage_metadata is not None + + assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"] + assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"] + assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"] + + @pytest.fixture def mock_completion() -> dict: return {