fix(xai): inject model_provider in response_metadata (#33543)

Also includes a minor refactor of the tests.
This commit is contained in:
Mason Daugherty
2025-10-17 16:11:03 -04:00
committed by GitHub
parent 8fd54f13b5
commit 8efa75d04c
6 changed files with 132 additions and 47 deletions

View File

@@ -0,0 +1,96 @@
"""Integration tests for ChatXAI specific features."""
from __future__ import annotations
from typing import Literal
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from langchain_xai import ChatXAI
# Model name passed to ChatXAI by test_web_search.
MODEL_NAME = "grok-4-fast-reasoning"
@pytest.mark.parametrize("output_version", ["", "v1"])
def test_reasoning(output_version: Literal["", "v1"]) -> None:
    """Exercise ChatXAI reasoning output, with and without `output_version`.

    Covers a plain invoke, streaming, reasoning content blocks, and passing a
    message that carries reasoning back to the model.

    Note: `grok-4` does not return `reasoning_content`, but may optionally return
    encrypted reasoning content if `use_encrypted_content` is set to True.
    """
    # Only forward `output_version` when the parametrized value is non-empty;
    # the "" case builds the model without that argument at all.
    chat_model = (
        ChatXAI(
            model="grok-3-mini",
            reasoning_effort="low",
            output_version=output_version,
        )
        if output_version
        else ChatXAI(
            model="grok-3-mini",
            reasoning_effort="low",
        )
    )

    # Plain invocation must yield both content and reasoning content.
    question = "What is 3^3?"
    first_reply = chat_model.invoke(question)
    assert first_reply.content
    assert first_reply.additional_kwargs["reasoning_content"]

    # Streaming: fold all chunks into a single aggregated message chunk.
    aggregated: BaseMessageChunk | None = None
    for piece in chat_model.stream(question):
        if aggregated is None:
            aggregated = piece
        else:
            aggregated = aggregated + piece
    assert isinstance(aggregated, AIMessageChunk)
    assert aggregated.additional_kwargs["reasoning_content"]

    # Reasoning must also be exposed through the content blocks.
    assert first_reply.content_blocks
    reasoning_blocks = [
        block for block in first_reply.content_blocks if block["type"] == "reasoning"
    ]
    assert len(reasoning_blocks) >= 1

    # A conversation that includes the earlier reasoning message must still work.
    second_reply = chat_model.invoke(
        [question, first_reply, "Based on your reasoning, what is 4^4?"]
    )
    assert second_reply.content
    assert second_reply.additional_kwargs["reasoning_content"]
    second_reasoning_blocks = [
        block for block in second_reply.content_blocks if block["type"] == "reasoning"
    ]
    assert len(second_reasoning_blocks) >= 1

    # Rebuild an AIMessage from the content blocks (including the reasoning
    # block) and send it back; `output_version` is recorded only when set.
    metadata = (
        {"model_provider": "xai", "output_version": output_version}
        if output_version
        else {"model_provider": "xai"}
    )
    reasoning_message = AIMessage(
        content_blocks=first_reply.content_blocks,
        response_metadata=metadata,
    )
    third_reply = chat_model.invoke(
        [reasoning_message, "Based on your reasoning, what is 5^5?"]
    )
    assert third_reply.content
    assert third_reply.additional_kwargs["reasoning_content"]
def test_web_search() -> None:
    """Check citations on search-enabled ChatXAI responses.

    Sends the same prompt through `invoke` and `stream` with search switched on
    and `max_search_results` set to 3, asserting that citations are present and
    that there are no more than three of them.
    """
    search_config = {"mode": "on", "max_search_results": 3}
    search_model = ChatXAI(
        model=MODEL_NAME,
        search_parameters=search_config,
    )
    prompt = "Provide me a digest of world news in the last 24 hours."

    # Non-streaming call.
    result = search_model.invoke(prompt)
    assert result.content
    invoke_citations = result.additional_kwargs["citations"]
    assert invoke_citations
    assert len(invoke_citations) <= 3

    # Streaming call: accumulate every chunk into one message chunk.
    merged = None
    for piece in search_model.stream(prompt):
        if merged is None:
            merged = piece
        else:
            merged = merged + piece
    assert isinstance(merged, AIMessageChunk)
    stream_citations = merged.additional_kwargs["citations"]
    assert stream_citations
    assert len(stream_citations) <= 3