fix(xai): inject model_provider in response_metadata (#33543)

plus tests minor rfc
This commit is contained in:
Mason Daugherty
2025-10-17 16:11:03 -04:00
committed by GitHub
parent 8fd54f13b5
commit 8efa75d04c
6 changed files with 132 additions and 47 deletions

View File

@@ -529,6 +529,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
) -> ChatResult: ) -> ChatResult:
rtn = super()._create_chat_result(response, generation_info) rtn = super()._create_chat_result(response, generation_info)
for generation in rtn.generations:
generation.message.response_metadata["model_provider"] = "xai"
if not isinstance(response, openai.BaseModel): if not isinstance(response, openai.BaseModel):
return rtn return rtn
@@ -555,6 +558,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
default_chunk_class, default_chunk_class,
base_generation_info, base_generation_info,
) )
if generation_chunk:
generation_chunk.message.response_metadata["model_provider"] = "xai"
if (choices := chunk.get("choices")) and generation_chunk: if (choices := chunk.get("choices")) and generation_chunk:
top = choices[0] top = choices[0]
if isinstance(generation_chunk.message, AIMessageChunk) and ( if isinstance(generation_chunk.message, AIMessageChunk) and (

View File

@@ -0,0 +1,96 @@
"""Integration tests for ChatXAI specific features."""
from __future__ import annotations
from typing import Literal
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from langchain_xai import ChatXAI
MODEL_NAME = "grok-4-fast-reasoning"
@pytest.mark.parametrize("output_version", ["", "v1"])
def test_reasoning(output_version: Literal["", "v1"]) -> None:
"""Test reasoning features.
Note: `grok-4` does not return `reasoning_content`, but may optionally return
encrypted reasoning content if `use_encrypted_content` is set to True.
"""
# Test reasoning effort
if output_version:
chat_model = ChatXAI(
model="grok-3-mini",
reasoning_effort="low",
output_version=output_version,
)
else:
chat_model = ChatXAI(
model="grok-3-mini",
reasoning_effort="low",
)
input_message = "What is 3^3?"
response = chat_model.invoke(input_message)
assert response.content
assert response.additional_kwargs["reasoning_content"]
# Test streaming
full: BaseMessageChunk | None = None
for chunk in chat_model.stream(input_message):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
# Check that we can access reasoning content blocks
assert response.content_blocks
reasoning_content = (
block for block in response.content_blocks if block["type"] == "reasoning"
)
assert len(list(reasoning_content)) >= 1
# Test that passing message with reasoning back in works
follow_up_message = "Based on your reasoning, what is 4^4?"
followup = chat_model.invoke([input_message, response, follow_up_message])
assert followup.content
assert followup.additional_kwargs["reasoning_content"]
followup_reasoning = (
block for block in followup.content_blocks if block["type"] == "reasoning"
)
assert len(list(followup_reasoning)) >= 1
# Test passing in a ReasoningContentBlock
response_metadata = {"model_provider": "xai"}
if output_version:
response_metadata["output_version"] = output_version
msg_w_reasoning = AIMessage(
content_blocks=response.content_blocks,
response_metadata=response_metadata,
)
followup_2 = chat_model.invoke(
[msg_w_reasoning, "Based on your reasoning, what is 5^5?"]
)
assert followup_2.content
assert followup_2.additional_kwargs["reasoning_content"]
def test_web_search() -> None:
llm = ChatXAI(
model=MODEL_NAME,
search_parameters={"mode": "on", "max_search_results": 3},
)
# Test invoke
response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
assert response.content
assert response.additional_kwargs["citations"]
assert len(response.additional_kwargs["citations"]) <= 3
# Test streaming
full = None
for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["citations"]
assert len(full.additional_kwargs["citations"]) <= 3

View File

@@ -4,26 +4,25 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from langchain_core.messages import AIMessageChunk, BaseMessageChunk import pytest
from langchain_core.messages import AIMessage
from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_tests.integration_tests import ( # type: ignore[import-not-found] from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found]
) )
from typing_extensions import override
from langchain_xai import ChatXAI from langchain_xai import ChatXAI
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
# Initialize the rate limiter in global scope, so it can be re-used # Initialize the rate limiter in global scope, so it can be re-used across tests
# across tests.
rate_limiter = InMemoryRateLimiter( rate_limiter = InMemoryRateLimiter(
requests_per_second=0.5, requests_per_second=0.5,
) )
MODEL_NAME = "grok-4-fast-reasoning"
# Not using Grok 4 since it doesn't support reasoning params (effort) or returns
# reasoning content.
class TestXAIStandard(ChatModelIntegrationTests): class TestXAIStandard(ChatModelIntegrationTests):
@@ -33,48 +32,28 @@ class TestXAIStandard(ChatModelIntegrationTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
# TODO: bump to test new Grok once they implement other features
return { return {
"model": "grok-3", "model": MODEL_NAME,
"rate_limiter": rate_limiter, "rate_limiter": rate_limiter,
"stream_usage": True, "stream_usage": True,
} }
@pytest.mark.xfail(
def test_reasoning_content() -> None: reason="Default model does not support stop sequences, using grok-3 instead"
"""Test reasoning content."""
chat_model = ChatXAI(
model="grok-3-mini",
reasoning_effort="low",
) )
response = chat_model.invoke("What is 3^3?") @override
assert response.content def test_stop_sequence(self, model: BaseChatModel) -> None:
assert response.additional_kwargs["reasoning_content"] """Override to use `grok-3` which supports stop sequences."""
params = {**self.chat_model_params, "model": "grok-3"}
# Test streaming grok3_model = ChatXAI(**params)
full: BaseMessageChunk | None = None
for chunk in chat_model.stream("What is 3^3?"):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
result = grok3_model.invoke("hi", stop=["you"])
assert isinstance(result, AIMessage)
def test_web_search() -> None: custom_model = ChatXAI(
llm = ChatXAI( **params,
model="grok-3", stop_sequences=["you"],
search_parameters={"mode": "auto", "max_search_results": 3}, )
) result = custom_model.invoke("hi")
assert isinstance(result, AIMessage)
# Test invoke
response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
assert response.content
assert response.additional_kwargs["citations"]
assert len(response.additional_kwargs["citations"]) <= 3
# Test streaming
full = None
for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["citations"]
assert len(full.additional_kwargs["citations"]) <= 3

View File

@@ -9,7 +9,7 @@
'kwargs': dict({ 'kwargs': dict({
'max_retries': 2, 'max_retries': 2,
'max_tokens': 100, 'max_tokens': 100,
'model_name': 'grok-beta', 'model_name': 'grok-4',
'request_timeout': 60.0, 'request_timeout': 60.0,
'stop': list([ 'stop': list([
]), ]),

View File

@@ -7,6 +7,8 @@ from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
from langchain_xai import ChatXAI from langchain_xai import ChatXAI
MODEL_NAME = "grok-4"
class TestXAIStandard(ChatModelUnitTests): class TestXAIStandard(ChatModelUnitTests):
@property @property
@@ -15,7 +17,7 @@ class TestXAIStandard(ChatModelUnitTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {"model": "grok-beta"} return {"model": MODEL_NAME}
@property @property
def init_from_env_params(self) -> tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
@@ -24,7 +26,7 @@ class TestXAIStandard(ChatModelUnitTests):
"XAI_API_KEY": "api_key", "XAI_API_KEY": "api_key",
}, },
{ {
"model": "grok-beta", "model": MODEL_NAME,
}, },
{ {
"xai_api_key": "api_key", "xai_api_key": "api_key",

View File

@@ -621,7 +621,7 @@ wheels = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "1.0.0a8" version = "1.0.0rc2"
source = { editable = "../../core" } source = { editable = "../../core" }
dependencies = [ dependencies = [
{ name = "jsonpatch" }, { name = "jsonpatch" },
@@ -699,6 +699,7 @@ dev = [{ name = "langchain-core", editable = "../../core" }]
lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }]
test = [ test = [
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "langchain", editable = "../../langchain_v1" },
{ name = "langchain-core", editable = "../../core" }, { name = "langchain-core", editable = "../../core" },
{ name = "langchain-tests", editable = "../../standard-tests" }, { name = "langchain-tests", editable = "../../standard-tests" },
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
@@ -728,7 +729,7 @@ typing = [
[[package]] [[package]]
name = "langchain-tests" name = "langchain-tests"
version = "1.0.0a2" version = "1.0.0rc1"
source = { editable = "../../standard-tests" } source = { editable = "../../standard-tests" }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },