xai[patch]: support reasoning content (#30758)
https://docs.x.ai/docs/guides/reasoning

```python
from langchain.chat_models import init_chat_model

llm = init_chat_model(
    "xai:grok-3-mini-beta",
    reasoning_effort="low",
)
response = llm.invoke("Hello, world!")
```
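Once the call returns, the reasoning trace is surfaced alongside the answer in the message's `additional_kwargs` (that is what the diff below wires up). A minimal, self-contained sketch of reading it back, based on the integration test in this commit:

```python
from langchain.chat_models import init_chat_model

llm = init_chat_model("xai:grok-3-mini-beta", reasoning_effort="low")

response = llm.invoke("What is 3^3?")
print(response.content)  # the final answer
# Populated by the _create_chat_result override added in this commit:
print(response.additional_kwargs["reasoning_content"])
```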
This commit is contained in:
parent 89f28a24d3 · commit 9cfb95e621
In `ChatXAI` (the `langchain_xai` chat model implementation):

```diff
@@ -13,6 +13,8 @@ from langchain_core.language_models.chat_models import (
     LangSmithParams,
     LanguageModelInput,
 )
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
 from langchain_core.utils import secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -369,6 +371,44 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         )
         return self
 
+    def _create_chat_result(
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[dict] = None,
+    ) -> ChatResult:
+        rtn = super()._create_chat_result(response, generation_info)
+
+        if not isinstance(response, openai.BaseModel):
+            return rtn
+
+        if hasattr(response.choices[0].message, "reasoning_content"):  # type: ignore
+            rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
+                response.choices[0].message.reasoning_content  # type: ignore
+            )
+
+        return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if isinstance(generation_chunk.message, AIMessageChunk):
+                if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+
+        return generation_chunk
+
     def with_structured_output(
         self,
         schema: Optional[_DictOrPydanticClass] = None,
```
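For streaming, `_convert_chunk_to_generation_chunk` lifts the per-chunk reasoning delta out of the raw chunk dict. A sketch of that extraction with an assumed chunk payload (the exact wire shape is whatever xAI sends; this mirrors the keys the override reads):

```python
# Hypothetical streamed chunk, shaped like the dict the override inspects:
chunk = {
    "choices": [
        {"delta": {"content": "27", "reasoning_content": "3^3 = 3*3*3 = 27"}}
    ]
}

# The same walrus-based extraction the override performs before stashing
# the value on the AIMessageChunk's additional_kwargs:
if choices := chunk.get("choices"):
    if reasoning := choices[0].get("delta", {}).get("reasoning_content"):
        print(reasoning)  # -> 3^3 = 3*3*3 = 27
```

The second file in the diff updates the package's standard integration tests: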
```diff
@@ -1,7 +1,9 @@
 """Standard LangChain interface tests"""
 
 import pytest  # type: ignore[import-not-found]
+from typing import Optional
 
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
     ChatModelIntegrationTests,  # type: ignore[import-not-found]
@@ -24,10 +26,25 @@ class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
         return {
-            "model": "grok-2",
+            "model": "grok-3",
             "rate_limiter": rate_limiter,
             "stream_usage": True,
         }
 
     @pytest.mark.xfail(reason="Not yet supported.")
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         super().test_usage_metadata_streaming(model)
+
+
+def test_reasoning_content() -> None:
+    """Test reasoning content."""
+    chat_model = ChatXAI(
+        model="grok-3-mini-beta",
+        reasoning_effort="low",
+    )
+    response = chat_model.invoke("What is 3^3?")
+    assert response.content
+    assert response.additional_kwargs["reasoning_content"]
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is 3^3?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
```
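The streaming assertion at the end works because adding `AIMessageChunk`s merges their `additional_kwargs` (string values are concatenated), so the per-chunk reasoning deltas accumulate into one trace. A minimal, self-contained sketch of that merge behavior (the literal strings are made up for illustration):

```python
from langchain_core.messages import AIMessageChunk

# Two fake chunks carrying reasoning deltas in additional_kwargs.
a = AIMessageChunk(content="", additional_kwargs={"reasoning_content": "3^3 means "})
b = AIMessageChunk(content="27", additional_kwargs={"reasoning_content": "3*3*3 = 27"})

merged = a + b
print(merged.content)  # "27"
print(merged.additional_kwargs["reasoning_content"])  # "3^3 means 3*3*3 = 27"
```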