xai: support live search (#31379)

https://docs.x.ai/docs/guides/live-search
This commit is contained in:
ccurme 2025-05-27 14:08:59 -04:00 committed by GitHub
parent 443341a20d
commit 8b1f54c419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 0 deletions

View File

@ -275,6 +275,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
""" """
xai_api_base: str = Field(default="https://api.x.ai/v1/") xai_api_base: str = Field(default="https://api.x.ai/v1/")
"""Base URL path for API requests.""" """Base URL path for API requests."""
search_parameters: Optional[dict[str, Any]] = None
"""Parameters for search requests. Example: ``{"mode": "auto"}``."""
openai_api_key: Optional[SecretStr] = None openai_api_key: Optional[SecretStr] = None
openai_api_base: Optional[str] = None openai_api_base: Optional[str] = None
@ -371,6 +373,18 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
) )
return self return self
@property
def _default_params(self) -> dict[str, Any]:
    """Default request parameters, including xAI live-search settings.

    Starts from the parent class defaults and, when ``search_parameters``
    is configured on the model, merges it under ``extra_body`` so the
    OpenAI-compatible client forwards it to the xAI chat completions API
    (e.g. ``{"mode": "auto"}``).
    """
    defaults = super()._default_params
    if self.search_parameters:
        # setdefault covers both cases: reuse an existing extra_body dict
        # or create a fresh one, then attach the search parameters.
        defaults.setdefault("extra_body", {})["search_parameters"] = (
            self.search_parameters
        )
    return defaults
def _create_chat_result( def _create_chat_result(
self, self,
response: Union[dict, openai.BaseModel], response: Union[dict, openai.BaseModel],
@ -386,6 +400,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
response.choices[0].message.reasoning_content # type: ignore response.choices[0].message.reasoning_content # type: ignore
) )
if hasattr(response, "citations"):
rtn.generations[0].message.additional_kwargs["citations"] = (
response.citations
)
return rtn return rtn
def _convert_chunk_to_generation_chunk( def _convert_chunk_to_generation_chunk(
@ -407,6 +426,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
reasoning_content reasoning_content
) )
if (citations := chunk.get("citations")) and generation_chunk:
if isinstance(generation_chunk.message, AIMessageChunk):
generation_chunk.message.additional_kwargs["citations"] = citations
return generation_chunk return generation_chunk
def with_structured_output( def with_structured_output(

View File

@ -48,3 +48,24 @@ def test_reasoning_content() -> None:
full = chunk if full is None else full + chunk full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk) assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"] assert full.additional_kwargs["reasoning_content"]
def test_web_search() -> None:
    """Integration test for xAI live search.

    Exercises both ``invoke`` and ``stream`` with ``search_parameters``
    set, and checks that citations are surfaced in ``additional_kwargs``
    and capped by ``max_search_results``. Requires live xAI API access.
    """
    chat = ChatXAI(
        model="grok-3-latest",
        search_parameters={"mode": "auto", "max_search_results": 3},
    )
    prompt = "Provide me a digest of world news in the last 24 hours."

    # Non-streaming path: response carries content and at most 3 citations.
    result = chat.invoke(prompt)
    assert result.content
    citations = result.additional_kwargs["citations"]
    assert citations
    assert len(citations) <= 3

    # Streaming path: aggregate chunks, then check the same invariants.
    aggregated = None
    for piece in chat.stream(prompt):
        aggregated = piece if aggregated is None else aggregated + piece
    assert isinstance(aggregated, AIMessageChunk)
    assert aggregated.additional_kwargs["citations"]
    assert len(aggregated.additional_kwargs["citations"]) <= 3