diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 213e036c6dc..84459cc4d29 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -58,6 +58,7 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
@@ -483,23 +484,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text, chunk=chunk, logprobs=logprobs
@@ -589,23 +603,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.text, chunk=chunk, logprobs=logprobs
@@ -1135,6 +1162,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
diff --git a/libs/partners/openai/poetry.lock b/libs/partners/openai/poetry.lock
index a45ada7a989..3c00a51e271 100644
--- a/libs/partners/openai/poetry.lock
+++ b/libs/partners/openai/poetry.lock
@@ -1268,4 +1268,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "0d142c2f1c4d08dbdcbd15b8005ffa606d4d36781c52e9411993cabd41de261b"
+content-hash = "62f0a24221c64dc8035ccf7cca3f8ac2eaf47d653a441645c8021120833ecb52"
diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml
index bbed53a4a84..aed9d0274f1 100644
--- a/libs/partners/openai/pyproject.toml
+++ b/libs/partners/openai/pyproject.toml
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = {version =">=0.2.2rc1,<0.3", allow-prereleases=true}
-openai = "^1.24.0"
+openai = "^1.26.0"
 tiktoken = ">=0.7,<1"
 
 [tool.poetry.group.test]
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index e86c457932c..db87949eb64 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -346,9 +346,34 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+        assert isinstance(chunk, AIMessageChunk)
+        if chunk.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert aggregate.usage_metadata is None
 
 
 async def test_astream() -> None:
@@ -356,9 +381,34 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+        assert isinstance(chunk, AIMessageChunk)
+        if chunk.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert aggregate.usage_metadata is None
 
 
 async def test_abatch() -> None:
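Not part of the patch: a minimal caller-side sketch of the behavior the integration tests above exercise. It assumes langchain-core >= 0.2.2 (which adds `AIMessageChunk.usage_metadata`) and an `OPENAI_API_KEY` in the environment; with this change `ChatOpenAI` sends `stream_options={"include_usage": True}` by default, so the stream ends with one content-free chunk carrying token counts, and passing `include_usage: False` opts back out.

    from typing import Optional

    from langchain_core.messages import AIMessageChunk, BaseMessageChunk
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI()

    # Aggregate the stream; the final empty chunk contributes usage_metadata,
    # which survives chunk aggregation.
    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream("I'm Pickle Rick"):
        full = chunk if full is None else full + chunk

    assert isinstance(full, AIMessageChunk)
    print(full.usage_metadata)  # input_tokens / output_tokens / total_tokens

    # Opting out restores the previous behavior: no chunk carries usage_metadata.
    opted_out: Optional[BaseMessageChunk] = None
    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
        opted_out = chunk if opted_out is None else opted_out + chunk

    assert isinstance(opted_out, AIMessageChunk)
    assert opted_out.usage_metadata is None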