xai[patch]: support reasoning content (#30758)

https://docs.x.ai/docs/guides/reasoning

```python
from langchain.chat_models import init_chat_model

llm = init_chat_model(
    "xai:grok-3-mini-beta",
    reasoning_effort="low"
)
response = llm.invoke("Hello, world!")
```
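The reasoning trace is surfaced on the message itself, alongside the final answer. A minimal sketch of reading it back, assuming `XAI_API_KEY` is set and the model returns a `reasoning_content` field:

```python
from langchain.chat_models import init_chat_model

llm = init_chat_model(
    "xai:grok-3-mini-beta",
    reasoning_effort="low",
)
response = llm.invoke("What is 3^3?")

# The final answer lives in `content`; the reasoning trace, when the API
# returns one, is stored under `additional_kwargs["reasoning_content"]`.
print(response.content)
print(response.additional_kwargs.get("reasoning_content"))
```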
ccurme 2025-04-11 10:00:27 -04:00 committed by GitHub
parent 89f28a24d3
commit 9cfb95e621
2 changed files with 62 additions and 5 deletions

View File

@@ -13,6 +13,8 @@ from langchain_core.language_models.chat_models import (
     LangSmithParams,
     LanguageModelInput,
 )
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
 from langchain_core.utils import secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -369,6 +371,44 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         )
         return self

+    def _create_chat_result(
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[dict] = None,
+    ) -> ChatResult:
+        rtn = super()._create_chat_result(response, generation_info)
+
+        if not isinstance(response, openai.BaseModel):
+            return rtn
+
+        if hasattr(response.choices[0].message, "reasoning_content"):  # type: ignore
+            rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
+                response.choices[0].message.reasoning_content  # type: ignore
+            )
+
+        return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if isinstance(generation_chunk.message, AIMessageChunk):
+                if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+
+        return generation_chunk
+
     def with_structured_output(
         self,
         schema: Optional[_DictOrPydanticClass] = None,

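On the streaming side, `_convert_chunk_to_generation_chunk` copies each delta's `reasoning_content` onto the emitted `AIMessageChunk`. A sketch of consuming that path, assuming the API emits `reasoning_content` deltas for this model:

```python
from langchain_xai import ChatXAI

llm = ChatXAI(model="grok-3-mini-beta", reasoning_effort="low")

# Summing chunks merges `content` and `additional_kwargs`, so the full
# reasoning trace accumulates as the stream is consumed.
full = None
for chunk in llm.stream("What is 3^3?"):
    full = chunk if full is None else full + chunk

print(full.additional_kwargs.get("reasoning_content"))
```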
View File

@@ -1,7 +1,9 @@
 """Standard LangChain interface tests"""

+from typing import Optional
+
 import pytest  # type: ignore[import-not-found]
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
     ChatModelIntegrationTests,  # type: ignore[import-not-found]
@@ -24,10 +26,25 @@ class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
         return {
-            "model": "grok-2",
+            "model": "grok-3",
             "rate_limiter": rate_limiter,
             "stream_usage": True,
         }

     @pytest.mark.xfail(reason="Not yet supported.")
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         super().test_usage_metadata_streaming(model)
+
+
+def test_reasoning_content() -> None:
+    """Test reasoning content."""
+    chat_model = ChatXAI(
+        model="grok-3-mini-beta",
+        reasoning_effort="low",
+    )
+    response = chat_model.invoke("What is 3^3?")
+    assert response.content
+    assert response.additional_kwargs["reasoning_content"]
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is 3^3?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
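The streaming assertion leans on `AIMessageChunk` addition: `additional_kwargs` dicts are merged, and string values such as `reasoning_content` are concatenated, so the trace accumulates across chunks. A minimal illustration, with hypothetical trace fragments:

```python
from langchain_core.messages import AIMessageChunk

# Hypothetical fragments: `+` merges additional_kwargs, concatenating
# string values such as reasoning_content across chunks.
a = AIMessageChunk(content="", additional_kwargs={"reasoning_content": "3^3 means "})
b = AIMessageChunk(content="27", additional_kwargs={"reasoning_content": "3 * 3 * 3 = 27."})
merged = a + b
assert merged.content == "27"
assert merged.additional_kwargs["reasoning_content"] == "3^3 means 3 * 3 * 3 = 27."
```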