"""Integration tests for ChatXAI specific features.""" from __future__ import annotations from typing import Literal import pytest from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk from langchain_xai import ChatXAI MODEL_NAME = "grok-4-fast-reasoning" @pytest.mark.parametrize("output_version", ["", "v1"]) def test_reasoning(output_version: Literal["", "v1"]) -> None: """Test reasoning features. !!! note `grok-4` does not return `reasoning_content`, but may optionally return encrypted reasoning content if `use_encrypted_content` is set to `True`. """ # Test reasoning effort if output_version: chat_model = ChatXAI( model="grok-3-mini", reasoning_effort="low", temperature=0, output_version=output_version, ) else: chat_model = ChatXAI( model="grok-3-mini", reasoning_effort="low", temperature=0, ) input_message = "What is 3^3?" response = chat_model.invoke(input_message) assert response.content assert response.additional_kwargs["reasoning_content"] ## Check output tokens usage_metadata = response.usage_metadata assert usage_metadata reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning") total_tokens = usage_metadata.get("output_tokens") assert total_tokens assert reasoning_tokens assert total_tokens > reasoning_tokens # Test streaming full: BaseMessageChunk | None = None for chunk in chat_model.stream(input_message): full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) assert full.additional_kwargs["reasoning_content"] ## Check output tokens usage_metadata = full.usage_metadata assert usage_metadata reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning") total_tokens = usage_metadata.get("output_tokens") assert total_tokens assert reasoning_tokens assert total_tokens > reasoning_tokens # Check that we can access reasoning content blocks assert response.content_blocks reasoning_content = ( block for block in response.content_blocks if block["type"] == "reasoning" ) assert len(list(reasoning_content)) >= 1 # Test that passing message with reasoning back in works follow_up_message = "Based on your reasoning, what is 4^4?" followup = chat_model.invoke([input_message, response, follow_up_message]) assert followup.content assert followup.additional_kwargs["reasoning_content"] followup_reasoning = ( block for block in followup.content_blocks if block["type"] == "reasoning" ) assert len(list(followup_reasoning)) >= 1 # Test passing in a ReasoningContentBlock response_metadata = {"model_provider": "xai"} if output_version: response_metadata["output_version"] = output_version msg_w_reasoning = AIMessage( content_blocks=response.content_blocks, response_metadata=response_metadata, ) followup_2 = chat_model.invoke( [msg_w_reasoning, "Based on your reasoning, what is 5^5?"] ) assert followup_2.content assert followup_2.additional_kwargs["reasoning_content"] def test_web_search() -> None: llm = ChatXAI(model=MODEL_NAME, temperature=0).bind_tools([{"type": "web_search"}]) # xAI may emit additional block types (e.g. `citation`, `reasoning`) alongside # the core set, so assert each required type is present individually rather # than checking set equality. expected_types = ("server_tool_call", "server_tool_result", "text") def _assert_web_search_block(blocks: list) -> None: server_tool_calls = [b for b in blocks if b["type"] == "server_tool_call"] assert server_tool_calls, "expected at least one server_tool_call block" assert server_tool_calls[0]["name"] == "web_search" # Test invoke response = llm.invoke("Look up the current time in Boston, MA.") assert response.content content_types = {block["type"] for block in response.content_blocks} for expected in expected_types: assert expected in content_types, f"missing {expected!r} in {content_types}" _assert_web_search_block(response.content_blocks) # Test streaming full: AIMessageChunk | None = None for chunk in llm.stream("Look up the current time in Boston, MA."): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) content_types = {block["type"] for block in full.content_blocks} for expected in expected_types: assert expected in content_types, f"missing {expected!r} in {content_types}" _assert_web_search_block(full.content_blocks)