openai, deepseek: make _convert_chunk_to_generation_chunk an instance method (#29731)
1. Make `_convert_chunk_to_generation_chunk` an instance method on BaseChatOpenAI.
2. Override it on ChatDeepSeek to add `"reasoning_content"` to the message's additional_kwargs.

Resolves https://github.com/langchain-ai/langchain/issues/29513
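A minimal usage sketch of what this enables (not part of the diff; it mirrors the new integration test below and assumes langchain-deepseek is installed and DEEPSEEK_API_KEY is set):

    from langchain_deepseek.chat_models import ChatDeepSeek

    llm = ChatDeepSeek(model="deepseek-reasoner")

    # Accumulate streamed chunks; with this change the DeepSeek reasoning
    # trace is carried along in additional_kwargs.
    full = None
    for chunk in llm.stream("What is the square root of 256256?"):
        full = chunk if full is None else full + chunk

    assert full is not None
    print(full.content)
    print(full.additional_kwargs.get("reasoning_content"))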
@@ -1,9 +1,10 @@
 """DeepSeek chat models."""
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Type, Union
 
 import openai
-from langchain_core.outputs import ChatResult
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.utils import from_env, secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
 from pydantic import ConfigDict, Field, SecretStr, model_validator
@@ -218,3 +219,23 @@ class ChatDeepSeek(BaseChatOpenAI):
             )
 
         return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                if isinstance(generation_chunk.message, AIMessageChunk):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+        return generation_chunk
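An offline sketch of the new hook in action (not part of the diff): it feeds a hand-built streaming chunk through the override the way a unit test might. The chunk shape mimics an OpenAI-style streaming delta carrying DeepSeek's "reasoning_content" field, and the API key is a dummy because nothing is sent over the network:

    from langchain_core.messages import AIMessageChunk
    from langchain_deepseek.chat_models import ChatDeepSeek

    # Dummy credentials: constructing the model does not call the API.
    llm = ChatDeepSeek(model="deepseek-reasoner", api_key="dummy")

    # Hand-built chunk shaped like a streamed Chat Completions delta that
    # carries "reasoning_content" next to the (empty) content delta.
    chunk = {
        "choices": [{"delta": {"content": "", "reasoning_content": "Let me think..."}}]
    }

    generation_chunk = llm._convert_chunk_to_generation_chunk(chunk, AIMessageChunk, {})
    assert generation_chunk is not None
    assert (
        generation_chunk.message.additional_kwargs["reasoning_content"]
        == "Let me think..."
    )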
@@ -1,9 +1,10 @@
 """Test ChatDeepSeek chat model."""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.tools import BaseTool
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -38,3 +39,13 @@ def test_reasoning_content() -> None:
     assert response.content
     assert response.additional_kwargs["reasoning_content"]
     raise ValueError()
+
+
+@pytest.mark.xfail(reason="Takes > 30s to run.")
+def test_reasoning_content_streaming() -> None:
+    chat_model = ChatDeepSeek(model="deepseek-reasoner")
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is the square root of 256256?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
@@ -316,57 +316,6 @@ def _convert_delta_to_message_chunk(
         return default_class(content=content, id=id_)  # type: ignore
 
 
-def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
-) -> Optional[ChatGenerationChunk]:
-    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
-        return None
-    token_usage = chunk.get("usage")
-    choices = (
-        chunk.get("choices", [])
-        # from beta.chat.completions.stream
-        or chunk.get("chunk", {}).get("choices", [])
-    )
-
-    usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage) if token_usage else None
-    )
-    if len(choices) == 0:
-        # logprobs is implicitly None
-        generation_chunk = ChatGenerationChunk(
-            message=default_chunk_class(content="", usage_metadata=usage_metadata)
-        )
-        return generation_chunk
-
-    choice = choices[0]
-    if choice["delta"] is None:
-        return None
-
-    message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class
-    )
-    generation_info = {**base_generation_info} if base_generation_info else {}
-
-    if finish_reason := choice.get("finish_reason"):
-        generation_info["finish_reason"] = finish_reason
-    if model_name := chunk.get("model"):
-        generation_info["model_name"] = model_name
-    if system_fingerprint := chunk.get("system_fingerprint"):
-        generation_info["system_fingerprint"] = system_fingerprint
-
-    logprobs = choice.get("logprobs")
-    if logprobs:
-        generation_info["logprobs"] = logprobs
-
-    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
-        message_chunk.usage_metadata = usage_metadata
-
-    generation_chunk = ChatGenerationChunk(
-        message=message_chunk, generation_info=generation_info or None
-    )
-    return generation_chunk
-
-
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
@@ -692,6 +641,59 @@ class BaseChatOpenAI(BaseChatModel):
             combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
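Because the conversion helper is now a method on BaseChatOpenAI rather than a module-level function, OpenAI-compatible integrations can hook provider-specific streaming fields without copying the whole chunk parser. A hypothetical subclass sketch illustrating the extension point (not part of the diff; "provider_metadata" is a made-up delta field):

    from typing import Dict, Optional, Type

    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk
    from langchain_openai.chat_models.base import BaseChatOpenAI


    class ChatExampleProvider(BaseChatOpenAI):
        """Hypothetical subclass that captures an extra field from streamed deltas."""

        def _convert_chunk_to_generation_chunk(
            self,
            chunk: dict,
            default_chunk_class: Type,
            base_generation_info: Optional[Dict],
        ) -> Optional[ChatGenerationChunk]:
            # Let the base implementation do the standard parsing first.
            generation_chunk = super()._convert_chunk_to_generation_chunk(
                chunk, default_chunk_class, base_generation_info
            )
            # Copy the made-up "provider_metadata" delta field into
            # additional_kwargs, mirroring the ChatDeepSeek override above.
            if (choices := chunk.get("choices")) and generation_chunk:
                extra = choices[0].get("delta", {}).get("provider_metadata")
                if extra and isinstance(generation_chunk.message, AIMessageChunk):
                    generation_chunk.message.additional_kwargs["provider_metadata"] = extra
            return generation_chunk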
@@ -727,7 +729,7 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
@@ -895,7 +897,7 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},