diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index 650ea5dd8bf..1742d6df4f6 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -275,6 +275,8 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     """
     xai_api_base: str = Field(default="https://api.x.ai/v1/")
     """Base URL path for API requests."""
+    search_parameters: Optional[dict[str, Any]] = None
+    """Parameters for search requests. Example: ``{"mode": "auto"}``."""
 
     openai_api_key: Optional[SecretStr] = None
     openai_api_base: Optional[str] = None
@@ -371,6 +373,18 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         )
         return self
 
+    @property
+    def _default_params(self) -> dict[str, Any]:
+        """Get default parameters."""
+        params = super()._default_params
+        if self.search_parameters:
+            if "extra_body" in params:
+                params["extra_body"]["search_parameters"] = self.search_parameters
+            else:
+                params["extra_body"] = {"search_parameters": self.search_parameters}
+
+        return params
+
     def _create_chat_result(
         self,
         response: Union[dict, openai.BaseModel],
@@ -386,6 +400,11 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
                 response.choices[0].message.reasoning_content  # type: ignore
             )
 
+        if hasattr(response, "citations"):
+            rtn.generations[0].message.additional_kwargs["citations"] = (
+                response.citations
+            )
+
         return rtn
 
     def _convert_chunk_to_generation_chunk(
@@ -407,6 +426,10 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
                         reasoning_content
                     )
 
+        if (citations := chunk.get("citations")) and generation_chunk:
+            if isinstance(generation_chunk.message, AIMessageChunk):
+                generation_chunk.message.additional_kwargs["citations"] = citations
+
         return generation_chunk
 
     def with_structured_output(
diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
index 1e14116c05d..67e5bd1707a 100644
--- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
+++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
@@ -48,3 +48,24 @@ def test_reasoning_content() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert full.additional_kwargs["reasoning_content"]
+
+
+def test_web_search() -> None:
+    llm = ChatXAI(
+        model="grok-3-latest",
+        search_parameters={"mode": "auto", "max_search_results": 3},
+    )
+
+    # Test invoke
+    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
+    assert response.content
+    assert response.additional_kwargs["citations"]
+    assert len(response.additional_kwargs["citations"]) <= 3
+
+    # Test streaming
+    full = None
+    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["citations"]
+    assert len(full.additional_kwargs["citations"]) <= 3
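For reviewers, here is a minimal, self-contained sketch of the merge behavior introduced in `_default_params`: `search_parameters` is injected under `extra_body` without clobbering any `extra_body` entries already present in the request params. The `merge_search_parameters` helper and the `{"foo": "bar"}` payload are hypothetical stand-ins used only to illustrate the branch, not part of the patch.

```python
from typing import Any


def merge_search_parameters(
    params: dict[str, Any], search_parameters: dict[str, Any]
) -> dict[str, Any]:
    # Mirrors the branch in ChatXAI._default_params: reuse an existing
    # extra_body dict if one is present, otherwise create it.
    if "extra_body" in params:
        params["extra_body"]["search_parameters"] = search_parameters
    else:
        params["extra_body"] = {"search_parameters": search_parameters}
    return params


# Pre-existing extra_body entries survive the merge.
params = merge_search_parameters(
    {"model": "grok-3-latest", "extra_body": {"foo": "bar"}},
    {"mode": "auto"},
)
assert params["extra_body"] == {
    "foo": "bar",
    "search_parameters": {"mode": "auto"},
}

# Without a prior extra_body, the key is created from scratch.
params = merge_search_parameters({"model": "grok-3-latest"}, {"mode": "auto"})
assert params["extra_body"] == {"search_parameters": {"mode": "auto"}}
```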