mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 03:01:29 +00:00
xai: support live search (#31379)
https://docs.x.ai/docs/guides/live-search
This commit is contained in:
parent
443341a20d
commit
8b1f54c419
@ -275,6 +275,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
"""
|
||||
xai_api_base: str = Field(default="https://api.x.ai/v1/")
|
||||
"""Base URL path for API requests."""
|
||||
search_parameters: Optional[dict[str, Any]] = None
|
||||
"""Parameters for search requests. Example: ``{"mode": "auto"}``."""
|
||||
|
||||
openai_api_key: Optional[SecretStr] = None
|
||||
openai_api_base: Optional[str] = None
|
||||
@ -371,6 +373,18 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
)
|
||||
return self
|
||||
|
||||
@property
def _default_params(self) -> dict[str, Any]:
    """Default request parameters, extended with xAI live-search options.

    Starts from the parent class's defaults and, when ``search_parameters``
    is configured on the model, injects them under ``extra_body`` so they
    are forwarded verbatim to the xAI API.
    """
    merged = super()._default_params
    if self.search_parameters:
        # setdefault covers both the "extra_body already present" and the
        # "extra_body missing" cases in one step.
        merged.setdefault("extra_body", {})["search_parameters"] = (
            self.search_parameters
        )
    return merged
|
||||
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: Union[dict, openai.BaseModel],
|
||||
@ -386,6 +400,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
response.choices[0].message.reasoning_content # type: ignore
|
||||
)
|
||||
|
||||
if hasattr(response, "citations"):
|
||||
rtn.generations[0].message.additional_kwargs["citations"] = (
|
||||
response.citations
|
||||
)
|
||||
|
||||
return rtn
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
@ -407,6 +426,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
reasoning_content
|
||||
)
|
||||
|
||||
if (citations := chunk.get("citations")) and generation_chunk:
|
||||
if isinstance(generation_chunk.message, AIMessageChunk):
|
||||
generation_chunk.message.additional_kwargs["citations"] = citations
|
||||
|
||||
return generation_chunk
|
||||
|
||||
def with_structured_output(
|
||||
|
@ -48,3 +48,24 @@ def test_reasoning_content() -> None:
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
def test_web_search() -> None:
    """Integration test for xAI live search via ``search_parameters``.

    Hits the real API: verifies that both ``invoke`` and ``stream`` surface
    citations in ``additional_kwargs`` and respect ``max_search_results``.
    """
    chat = ChatXAI(
        model="grok-3-latest",
        search_parameters={"mode": "auto", "max_search_results": 3},
    )
    prompt = "Provide me a digest of world news in the last 24 hours."

    # Non-streaming path: citations ride along on the response message.
    result = chat.invoke(prompt)
    assert result.content
    assert result.additional_kwargs["citations"]
    assert len(result.additional_kwargs["citations"]) <= 3

    # Streaming path: fold the chunks together, then check the merged message.
    aggregated = None
    for piece in chat.stream(prompt):
        aggregated = piece if aggregated is None else aggregated + piece
    assert isinstance(aggregated, AIMessageChunk)
    assert aggregated.additional_kwargs["citations"]
    assert len(aggregated.additional_kwargs["citations"]) <= 3
|
||||
|
Loading…
Reference in New Issue
Block a user