From 8efa75d04ce8822ce7f6ef5f3aab1a1437f32631 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Fri, 17 Oct 2025 16:11:03 -0400 Subject: [PATCH] fix(xai): inject `model_provider` in `response_metadata` (#33543) plus tests minor rfc --- .../partners/xai/langchain_xai/chat_models.py | 7 ++ .../integration_tests/test_chat_models.py | 96 +++++++++++++++++++ .../test_chat_models_standard.py | 63 ++++-------- .../test_chat_models_standard.ambr | 2 +- .../unit_tests/test_chat_models_standard.py | 6 +- libs/partners/xai/uv.lock | 5 +- 6 files changed, 132 insertions(+), 47 deletions(-) create mode 100644 libs/partners/xai/tests/integration_tests/test_chat_models.py diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 22701217153..d2de07c20ff 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -529,6 +529,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] ) -> ChatResult: rtn = super()._create_chat_result(response, generation_info) + for generation in rtn.generations: + generation.message.response_metadata["model_provider"] = "xai" + if not isinstance(response, openai.BaseModel): return rtn @@ -555,6 +558,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] default_chunk_class, base_generation_info, ) + + if generation_chunk: + generation_chunk.message.response_metadata["model_provider"] = "xai" + if (choices := chunk.get("choices")) and generation_chunk: top = choices[0] if isinstance(generation_chunk.message, AIMessageChunk) and ( diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py new file mode 100644 index 00000000000..049f51e59bc --- /dev/null +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -0,0 +1,96 @@ +"""Integration tests for ChatXAI specific features.""" + +from __future__ import annotations + +from typing import Literal + +import pytest +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk + +from langchain_xai import ChatXAI + +MODEL_NAME = "grok-4-fast-reasoning" + + +@pytest.mark.parametrize("output_version", ["", "v1"]) +def test_reasoning(output_version: Literal["", "v1"]) -> None: + """Test reasoning features. + + Note: `grok-4` does not return `reasoning_content`, but may optionally return + encrypted reasoning content if `use_encrypted_content` is set to True. + """ + # Test reasoning effort + if output_version: + chat_model = ChatXAI( + model="grok-3-mini", + reasoning_effort="low", + output_version=output_version, + ) + else: + chat_model = ChatXAI( + model="grok-3-mini", + reasoning_effort="low", + ) + input_message = "What is 3^3?" + response = chat_model.invoke(input_message) + assert response.content + assert response.additional_kwargs["reasoning_content"] + + # Test streaming + full: BaseMessageChunk | None = None + for chunk in chat_model.stream(input_message): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.additional_kwargs["reasoning_content"] + + # Check that we can access reasoning content blocks + assert response.content_blocks + reasoning_content = ( + block for block in response.content_blocks if block["type"] == "reasoning" + ) + assert len(list(reasoning_content)) >= 1 + + # Test that passing message with reasoning back in works + follow_up_message = "Based on your reasoning, what is 4^4?" + followup = chat_model.invoke([input_message, response, follow_up_message]) + assert followup.content + assert followup.additional_kwargs["reasoning_content"] + followup_reasoning = ( + block for block in followup.content_blocks if block["type"] == "reasoning" + ) + assert len(list(followup_reasoning)) >= 1 + + # Test passing in a ReasoningContentBlock + response_metadata = {"model_provider": "xai"} + if output_version: + response_metadata["output_version"] = output_version + msg_w_reasoning = AIMessage( + content_blocks=response.content_blocks, + response_metadata=response_metadata, + ) + followup_2 = chat_model.invoke( + [msg_w_reasoning, "Based on your reasoning, what is 5^5?"] + ) + assert followup_2.content + assert followup_2.additional_kwargs["reasoning_content"] + + +def test_web_search() -> None: + llm = ChatXAI( + model=MODEL_NAME, + search_parameters={"mode": "on", "max_search_results": 3}, + ) + + # Test invoke + response = llm.invoke("Provide me a digest of world news in the last 24 hours.") + assert response.content + assert response.additional_kwargs["citations"] + assert len(response.additional_kwargs["citations"]) <= 3 + + # Test streaming + full = None + for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.additional_kwargs["citations"] + assert len(full.additional_kwargs["citations"]) <= 3 diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index dada7d5c4bc..9b7871d6ac2 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -4,26 +4,25 @@ from __future__ import annotations from typing import TYPE_CHECKING -from langchain_core.messages import AIMessageChunk, BaseMessageChunk +import pytest +from langchain_core.messages import AIMessage from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_tests.integration_tests import ( # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found] ) +from typing_extensions import override from langchain_xai import ChatXAI if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel -# Initialize the rate limiter in global scope, so it can be re-used -# across tests. +# Initialize the rate limiter in global scope, so it can be re-used across tests rate_limiter = InMemoryRateLimiter( requests_per_second=0.5, ) - -# Not using Grok 4 since it doesn't support reasoning params (effort) or returns -# reasoning content. +MODEL_NAME = "grok-4-fast-reasoning" class TestXAIStandard(ChatModelIntegrationTests): @@ -33,48 +32,28 @@ class TestXAIStandard(ChatModelIntegrationTests): @property def chat_model_params(self) -> dict: - # TODO: bump to test new Grok once they implement other features return { - "model": "grok-3", + "model": MODEL_NAME, "rate_limiter": rate_limiter, "stream_usage": True, } - -def test_reasoning_content() -> None: - """Test reasoning content.""" - chat_model = ChatXAI( - model="grok-3-mini", - reasoning_effort="low", + @pytest.mark.xfail( + reason="Default model does not support stop sequences, using grok-3 instead" ) - response = chat_model.invoke("What is 3^3?") - assert response.content - assert response.additional_kwargs["reasoning_content"] + @override + def test_stop_sequence(self, model: BaseChatModel) -> None: + """Override to use `grok-3` which supports stop sequences.""" + params = {**self.chat_model_params, "model": "grok-3"} - # Test streaming - full: BaseMessageChunk | None = None - for chunk in chat_model.stream("What is 3^3?"): - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - assert full.additional_kwargs["reasoning_content"] + grok3_model = ChatXAI(**params) + result = grok3_model.invoke("hi", stop=["you"]) + assert isinstance(result, AIMessage) -def test_web_search() -> None: - llm = ChatXAI( - model="grok-3", - search_parameters={"mode": "auto", "max_search_results": 3}, - ) - - # Test invoke - response = llm.invoke("Provide me a digest of world news in the last 24 hours.") - assert response.content - assert response.additional_kwargs["citations"] - assert len(response.additional_kwargs["citations"]) <= 3 - - # Test streaming - full = None - for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."): - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - assert full.additional_kwargs["citations"] - assert len(full.additional_kwargs["citations"]) <= 3 + custom_model = ChatXAI( + **params, + stop_sequences=["you"], + ) + result = custom_model.invoke("hi") + assert isinstance(result, AIMessage) diff --git a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr index 4cd1261555c..6f1cd18c2e2 100644 --- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr +++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr @@ -9,7 +9,7 @@ 'kwargs': dict({ 'max_retries': 2, 'max_tokens': 100, - 'model_name': 'grok-beta', + 'model_name': 'grok-4', 'request_timeout': 60.0, 'stop': list([ ]), diff --git a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py index 7c9a947a913..7224655db79 100644 --- a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py @@ -7,6 +7,8 @@ from langchain_tests.unit_tests import ( # type: ignore[import-not-found] from langchain_xai import ChatXAI +MODEL_NAME = "grok-4" + class TestXAIStandard(ChatModelUnitTests): @property @@ -15,7 +17,7 @@ class TestXAIStandard(ChatModelUnitTests): @property def chat_model_params(self) -> dict: - return {"model": "grok-beta"} + return {"model": MODEL_NAME} @property def init_from_env_params(self) -> tuple[dict, dict, dict]: @@ -24,7 +26,7 @@ class TestXAIStandard(ChatModelUnitTests): "XAI_API_KEY": "api_key", }, { - "model": "grok-beta", + "model": MODEL_NAME, }, { "xai_api_key": "api_key", diff --git a/libs/partners/xai/uv.lock b/libs/partners/xai/uv.lock index 0e0024f50f8..ac0ffb5c3bc 100644 --- a/libs/partners/xai/uv.lock +++ b/libs/partners/xai/uv.lock @@ -621,7 +621,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.0.0a8" +version = "1.0.0rc2" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -699,6 +699,7 @@ dev = [{ name = "langchain-core", editable = "../../core" }] lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] test = [ { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, + { name = "langchain", editable = "../../langchain_v1" }, { name = "langchain-core", editable = "../../core" }, { name = "langchain-tests", editable = "../../standard-tests" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" }, @@ -728,7 +729,7 @@ typing = [ [[package]] name = "langchain-tests" -version = "1.0.0a2" +version = "1.0.0rc1" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" },