openai, deepseek: make _convert_chunk_to_generation_chunk an instance method (#29731)

1. Make `_convert_chunk_to_generation_chunk` an instance method on `BaseChatOpenAI`.
2. Override it on `ChatDeepSeek` to add `"reasoning_content"` to the message's `additional_kwargs`.

Resolves https://github.com/langchain-ai/langchain/issues/29513
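
For illustration (not part of the diff): after this change, reasoning tokens surface on every streamed chunk, not just on `invoke` results. A minimal sketch, assuming a valid `DEEPSEEK_API_KEY` and access to the `deepseek-reasoner` model:

```python
# Sketch: stream reasoning tokens from ChatDeepSeek.
from langchain_deepseek import ChatDeepSeek

chat = ChatDeepSeek(model="deepseek-reasoner")
for chunk in chat.stream("What is 3^3?"):
    # The ChatDeepSeek override copies "reasoning_content" from each streamed
    # delta into the chunk's additional_kwargs.
    if reasoning := chunk.additional_kwargs.get("reasoning_content"):
        print(reasoning, end="", flush=True)
```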
Author: ccurme
Date: 2025-02-11 11:13:23 -08:00
Committed by: GitHub
Parent: 1edd27d860
Commit: 9477f49409

3 changed files with 90 additions and 56 deletions

libs/partners/deepseek/langchain_deepseek/chat_models.py

@@ -1,9 +1,10 @@
 """DeepSeek chat models."""
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Type, Union
 
 import openai
-from langchain_core.outputs import ChatResult
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.utils import from_env, secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
 from pydantic import ConfigDict, Field, SecretStr, model_validator
@@ -218,3 +219,23 @@ class ChatDeepSeek(BaseChatOpenAI):
         )
         return rtn
 
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                if isinstance(generation_chunk.message, AIMessageChunk):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+        return generation_chunk
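
A hand-rolled check of the override (an illustrative sketch only: it pokes the private method directly with a hand-built chunk dict mirroring DeepSeek's streaming payload, and the `api_key` is a dummy):

```python
from langchain_core.messages import AIMessageChunk
from langchain_deepseek import ChatDeepSeek

chat = ChatDeepSeek(model="deepseek-reasoner", api_key="dummy")
chunk = {
    "model": "deepseek-reasoner",
    "choices": [
        {
            "index": 0,
            "delta": {"content": "", "reasoning_content": "Factor 256256..."},
            "finish_reason": None,
        }
    ],
}
gen = chat._convert_chunk_to_generation_chunk(chunk, AIMessageChunk, {})
# super() builds the base ChatGenerationChunk; the override then copies
# delta["reasoning_content"] into message.additional_kwargs.
assert gen is not None
assert gen.message.additional_kwargs["reasoning_content"] == "Factor 256256..."
```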

libs/partners/deepseek/tests/integration_tests/test_chat_models.py

@@ -1,9 +1,10 @@
 """Test ChatDeepSeek chat model."""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.tools import BaseTool
 from langchain_tests.integration_tests import ChatModelIntegrationTests
@@ -38,3 +39,13 @@ def test_reasoning_content() -> None:
     assert response.content
     assert response.additional_kwargs["reasoning_content"]
     raise ValueError()
+
+
+@pytest.mark.xfail(reason="Takes > 30s to run.")
+def test_reasoning_content_streaming() -> None:
+    chat_model = ChatDeepSeek(model="deepseek-reasoner")
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is the square root of 256256?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
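
The streaming test folds chunks together with `+`; this works because `AIMessageChunk` addition merges `additional_kwargs`, concatenating string values key by key. A standalone sketch of that merge rule (no API access needed; the values are made up):

```python
from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(content="", additional_kwargs={"reasoning_content": "First, "})
b = AIMessageChunk(content="506.2", additional_kwargs={"reasoning_content": "then..."})
merged = a + b
# Content and reasoning_content are each concatenated across chunks.
assert merged.content == "506.2"
assert merged.additional_kwargs["reasoning_content"] == "First, then..."
```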

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -316,57 +316,6 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
-def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
-) -> Optional[ChatGenerationChunk]:
-    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
-        return None
-    token_usage = chunk.get("usage")
-    choices = (
-        chunk.get("choices", [])
-        # from beta.chat.completions.stream
-        or chunk.get("chunk", {}).get("choices", [])
-    )
-    usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage) if token_usage else None
-    )
-    if len(choices) == 0:
-        # logprobs is implicitly None
-        generation_chunk = ChatGenerationChunk(
-            message=default_chunk_class(content="", usage_metadata=usage_metadata)
-        )
-        return generation_chunk
-    choice = choices[0]
-    if choice["delta"] is None:
-        return None
-    message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class
-    )
-    generation_info = {**base_generation_info} if base_generation_info else {}
-    if finish_reason := choice.get("finish_reason"):
-        generation_info["finish_reason"] = finish_reason
-    if model_name := chunk.get("model"):
-        generation_info["model_name"] = model_name
-    if system_fingerprint := chunk.get("system_fingerprint"):
-        generation_info["system_fingerprint"] = system_fingerprint
-    logprobs = choice.get("logprobs")
-    if logprobs:
-        generation_info["logprobs"] = logprobs
-    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
-        message_chunk.usage_metadata = usage_metadata
-    generation_chunk = ChatGenerationChunk(
-        message=message_chunk, generation_info=generation_info or None
-    )
-    return generation_chunk
-
-
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
@@ -692,6 +641,59 @@ class BaseChatOpenAI(BaseChatModel):
                 combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
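
Because the converter is now a method, subclasses can hook into chunk conversion with a plain `super()` call instead of monkeypatching a module-level function. An illustrative run against the (private) method with a hand-built OpenAI-style chunk; the model name and fake key are placeholders:

```python
from langchain_core.messages import AIMessageChunk
from langchain_openai import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o-mini", api_key="sk-dummy")
chunk = {
    "model": "gpt-4o-mini",
    "system_fingerprint": "fp_abc",
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Hello"},
            "finish_reason": None,
            "logprobs": None,
        }
    ],
}
gen = chat._convert_chunk_to_generation_chunk(chunk, AIMessageChunk, {})
# The delta becomes the message; model name and fingerprint land in generation_info.
assert gen.message.content == "Hello"
assert gen.generation_info["model_name"] == "gpt-4o-mini"
```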
@@ -727,7 +729,7 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
@@ -895,7 +897,7 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
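
Both the sync and async paths now route through `self._convert_chunk_to_generation_chunk`, so the DeepSeek override applies to `astream` as well. A usage sketch, assuming a valid `DEEPSEEK_API_KEY`:

```python
import asyncio

from langchain_deepseek import ChatDeepSeek

async def main() -> None:
    chat = ChatDeepSeek(model="deepseek-reasoner")
    full = None
    async for chunk in chat.astream("What is the square root of 256256?"):
        full = chunk if full is None else full + chunk
    # reasoning_content accumulates across chunks via the additional_kwargs merge.
    print(full.additional_kwargs.get("reasoning_content", ""))

asyncio.run(main())
```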