diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index 5690843f261..650ea5dd8bf 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -13,6 +13,8 @@ from langchain_core.language_models.chat_models import (
     LangSmithParams,
     LanguageModelInput,
 )
+from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
 from langchain_core.utils import secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -369,6 +371,44 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         )
         return self
 
+    def _create_chat_result(
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[dict] = None,
+    ) -> ChatResult:
+        rtn = super()._create_chat_result(response, generation_info)
+
+        if not isinstance(response, openai.BaseModel):
+            return rtn
+
+        if hasattr(response.choices[0].message, "reasoning_content"):  # type: ignore
+            rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
+                response.choices[0].message.reasoning_content  # type: ignore
+            )
+
+        return rtn
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        generation_chunk = super()._convert_chunk_to_generation_chunk(
+            chunk,
+            default_chunk_class,
+            base_generation_info,
+        )
+        if (choices := chunk.get("choices")) and generation_chunk:
+            top = choices[0]
+            if isinstance(generation_chunk.message, AIMessageChunk):
+                if reasoning_content := top.get("delta", {}).get("reasoning_content"):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = (
+                        reasoning_content
+                    )
+
+        return generation_chunk
+
     def with_structured_output(
         self,
         schema: Optional[_DictOrPydanticClass] = None,
diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
index 632b518f71d..1e14116c05d 100644
--- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
+++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
@@ -1,7 +1,9 @@
 """Standard LangChain interface tests"""
 
-import pytest  # type: ignore[import-not-found]
+from typing import Optional
+
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
     ChatModelIntegrationTests,  # type: ignore[import-not-found]
@@ -24,10 +26,25 @@ class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
         return {
-            "model": "grok-2",
+            "model": "grok-3",
             "rate_limiter": rate_limiter,
+            "stream_usage": True,
         }
 
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
+
+def test_reasoning_content() -> None:
+    """Test reasoning content."""
+    chat_model = ChatXAI(
+        model="grok-3-mini-beta",
+        reasoning_effort="low",
+    )
+    response = chat_model.invoke("What is 3^3?")
+    assert response.content
+    assert response.additional_kwargs["reasoning_content"]
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_model.stream("What is 3^3?"):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]