Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-02 11:39:18 +00:00
openai, deepseek: make _convert_chunk_to_generation_chunk an instance method (#29731)
1. Make `_convert_chunk_to_generation_chunk` an instance method on `BaseChatOpenAI`.
2. Override it on `ChatDeepSeek` to add `"reasoning_content"` to the message's `additional_kwargs`.

Resolves https://github.com/langchain-ai/langchain/issues/29513
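In practice, the override means that streaming from ChatDeepSeek surfaces the reasoner's trace through each chunk's additional_kwargs. A minimal usage sketch, assuming the langchain-deepseek package is installed, the deepseek-reasoner model is used, and DEEPSEEK_API_KEY is set in the environment:

    from langchain_deepseek import ChatDeepSeek

    chat_model = ChatDeepSeek(model="deepseek-reasoner")

    full = None
    for chunk in chat_model.stream("What is the square root of 256256?"):
        # Each streamed AIMessageChunk may carry a reasoning delta.
        print(chunk.additional_kwargs.get("reasoning_content", ""), end="", flush=True)
        full = chunk if full is None else full + chunk

    # The accumulated chunk also keeps reasoning_content in additional_kwargs.
    assert full is not None and full.additional_kwargs.get("reasoning_content")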
@@ -1,9 +1,10 @@
 """DeepSeek chat models."""
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Type, Union
 
 import openai
-from langchain_core.outputs import ChatResult
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.utils import from_env, secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
 from pydantic import ConfigDict, Field, SecretStr, model_validator
@@ -218,3 +219,23 @@ class ChatDeepSeek(BaseChatOpenAI):
         )
 
         return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                if isinstance(generation_chunk.message, AIMessageChunk):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+        return generation_chunk
@@ -1,9 +1,10 @@
 """Test ChatDeepSeek chat model."""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.tools import BaseTool
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -38,3 +39,13 @@ def test_reasoning_content() -> None:
     assert response.content
     assert response.additional_kwargs["reasoning_content"]
     raise ValueError()
+
+
+@pytest.mark.xfail(reason="Takes > 30s to run.")
+def test_reasoning_content_streaming() -> None:
+    chat_model = ChatDeepSeek(model="deepseek-reasoner")
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is the square root of 256256?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
@@ -316,57 +316,6 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
-def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
-) -> Optional[ChatGenerationChunk]:
-    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
-        return None
-    token_usage = chunk.get("usage")
-    choices = (
-        chunk.get("choices", [])
-        # from beta.chat.completions.stream
-        or chunk.get("chunk", {}).get("choices", [])
-    )
-
-    usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage) if token_usage else None
-    )
-    if len(choices) == 0:
-        # logprobs is implicitly None
-        generation_chunk = ChatGenerationChunk(
-            message=default_chunk_class(content="", usage_metadata=usage_metadata)
-        )
-        return generation_chunk
-
-    choice = choices[0]
-    if choice["delta"] is None:
-        return None
-
-    message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class
-    )
-    generation_info = {**base_generation_info} if base_generation_info else {}
-
-    if finish_reason := choice.get("finish_reason"):
-        generation_info["finish_reason"] = finish_reason
-    if model_name := chunk.get("model"):
-        generation_info["model_name"] = model_name
-    if system_fingerprint := chunk.get("system_fingerprint"):
-        generation_info["system_fingerprint"] = system_fingerprint
-
-    logprobs = choice.get("logprobs")
-    if logprobs:
-        generation_info["logprobs"] = logprobs
-
-    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
-        message_chunk.usage_metadata = usage_metadata
-
-    generation_chunk = ChatGenerationChunk(
-        message=message_chunk, generation_info=generation_info or None
-    )
-    return generation_chunk
-
-
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
@@ -692,6 +641,59 @@ class BaseChatOpenAI(BaseChatModel):
             combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: Type,
+        base_generation_info: Optional[Dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
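Because the conversion hook is now an instance method, any BaseChatOpenAI subclass can post-process streamed chunks by overriding it and delegating to super(), exactly as ChatDeepSeek does above. A rough sketch of that pattern (ChatExampleProvider and the provider_metadata field are hypothetical and only illustrate the hook, not part of this change):

    from typing import Dict, Optional, Type

    from langchain_core.outputs import ChatGenerationChunk
    from langchain_openai.chat_models.base import BaseChatOpenAI


    class ChatExampleProvider(BaseChatOpenAI):
        """Hypothetical wrapper used only to illustrate the new hook."""

        def _convert_chunk_to_generation_chunk(
            self,
            chunk: dict,
            default_chunk_class: Type,
            base_generation_info: Optional[Dict],
        ) -> Optional[ChatGenerationChunk]:
            # Reuse the shared OpenAI-compatible chunk parsing first.
            generation_chunk = super()._convert_chunk_to_generation_chunk(
                chunk, default_chunk_class, base_generation_info
            )
            # Then copy a provider-specific field off the raw chunk, if present.
            if generation_chunk and (extra := chunk.get("provider_metadata")):
                generation_chunk.message.additional_kwargs["provider_metadata"] = extra
            return generation_chunk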
@@ -727,7 +729,7 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},
@@ -895,7 +897,7 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                generation_chunk = _convert_chunk_to_generation_chunk(
+                generation_chunk = self._convert_chunk_to_generation_chunk(
                     chunk,
                     default_chunk_class,
                     base_generation_info if is_first_chunk else {},