diff --git a/libs/partners/deepseek/langchain_deepseek/chat_models.py b/libs/partners/deepseek/langchain_deepseek/chat_models.py
index 969e52aecff..e7f3f3e3a7b 100644
--- a/libs/partners/deepseek/langchain_deepseek/chat_models.py
+++ b/libs/partners/deepseek/langchain_deepseek/chat_models.py
@@ -1,9 +1,10 @@
 """DeepSeek chat models."""
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Type, Union
 
 import openai
-from langchain_core.outputs import ChatResult
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.utils import from_env, secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
 from pydantic import ConfigDict, Field, SecretStr, model_validator
@@ -218,3 +219,23 @@ class ChatDeepSeek(BaseChatOpenAI):
             )
 
         return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                if isinstance(generation_chunk.message, AIMessageChunk):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+        return generation_chunk
diff --git a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
index cf5dce6d521..521296c3ef6 100644
--- a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
@@ -1,9 +1,10 @@
 """Test ChatDeepSeek chat model."""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.tools import BaseTool
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -38,3 +39,13 @@ def test_reasoning_content() -> None:
     assert response.content
     assert response.additional_kwargs["reasoning_content"]
     raise ValueError()
+
+
+@pytest.mark.xfail(reason="Takes > 30s to run.")
+def test_reasoning_content_streaming() -> None:
+    chat_model = ChatDeepSeek(model="deepseek-reasoner")
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is the square root of 256256?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index bbda248d0b5..cfd5131c5db 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -316,57 +316,6 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
-def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
-) -> Optional[ChatGenerationChunk]:
-    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
-        return None
-    token_usage = chunk.get("usage")
-    choices = (
-        chunk.get("choices", [])
-        # from beta.chat.completions.stream
-        or chunk.get("chunk", {}).get("choices", [])
-    )
-
-    usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage) if token_usage else None
-    )
-    if len(choices) == 0:
-        # logprobs is implicitly None
-        generation_chunk = ChatGenerationChunk(
-            message=default_chunk_class(content="", usage_metadata=usage_metadata)
-        )
-        return generation_chunk
-
-    choice = choices[0]
-    if choice["delta"] is None:
-        return None
-
-    message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class
-    )
-    generation_info = {**base_generation_info} if base_generation_info else {}
-
-    if finish_reason := choice.get("finish_reason"):
-        generation_info["finish_reason"] = finish_reason
-    if model_name := chunk.get("model"):
-        generation_info["model_name"] = model_name
-    if system_fingerprint := chunk.get("system_fingerprint"):
-        generation_info["system_fingerprint"] = system_fingerprint
-
-    logprobs = choice.get("logprobs")
-    if logprobs:
-        generation_info["logprobs"] = logprobs
-
-    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
-        message_chunk.usage_metadata = usage_metadata
-
-    generation_chunk = ChatGenerationChunk(
-        message=message_chunk, generation_info=generation_info or None
-    )
-    return generation_chunk
-
-
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
@@ -692,6 +641,59 @@ class BaseChatOpenAI(BaseChatModel):
             combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -727,7 +729,7 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
@@ -895,7 +897,7 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
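For reviewers, a minimal usage sketch of what this change enables, mirroring the new integration test rather than going beyond it. It assumes `DEEPSEEK_API_KEY` is exported and the account can reach the `deepseek-reasoner` model. Since `AIMessageChunk` addition merges `additional_kwargs` (string values are concatenated), summing the streamed chunks should accumulate the full reasoning trace alongside the final answer.

```python
# Sketch: consuming streamed reasoning_content surfaced by the ChatDeepSeek override.
# Assumes DEEPSEEK_API_KEY is set and "deepseek-reasoner" is available to the account.
from langchain_deepseek.chat_models import ChatDeepSeek

chat_model = ChatDeepSeek(model="deepseek-reasoner")

full = None
for chunk in chat_model.stream("What is the square root of 256256?"):
    # Each AIMessageChunk may carry a reasoning_content delta in additional_kwargs;
    # adding chunks merges the dicts, so the reasoning text accumulates.
    full = chunk if full is None else full + chunk

if full is not None:
    print(full.additional_kwargs.get("reasoning_content", ""))  # reasoning tokens
    print(full.content)  # final answer
```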