mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
xai: support live search (#31379)
https://docs.x.ai/docs/guides/live-search
This commit is contained in:
parent
443341a20d
commit
8b1f54c419
@ -275,6 +275,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
"""
|
"""
|
||||||
xai_api_base: str = Field(default="https://api.x.ai/v1/")
|
xai_api_base: str = Field(default="https://api.x.ai/v1/")
|
||||||
"""Base URL path for API requests."""
|
"""Base URL path for API requests."""
|
||||||
|
search_parameters: Optional[dict[str, Any]] = None
|
||||||
|
"""Parameters for search requests. Example: ``{"mode": "auto"}``."""
|
||||||
|
|
||||||
openai_api_key: Optional[SecretStr] = None
|
openai_api_key: Optional[SecretStr] = None
|
||||||
openai_api_base: Optional[str] = None
|
openai_api_base: Optional[str] = None
|
||||||
@ -371,6 +373,18 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> dict[str, Any]:
|
||||||
|
"""Get default parameters."""
|
||||||
|
params = super()._default_params
|
||||||
|
if self.search_parameters:
|
||||||
|
if "extra_body" in params:
|
||||||
|
params["extra_body"]["search_parameters"] = self.search_parameters
|
||||||
|
else:
|
||||||
|
params["extra_body"] = {"search_parameters": self.search_parameters}
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
def _create_chat_result(
|
def _create_chat_result(
|
||||||
self,
|
self,
|
||||||
response: Union[dict, openai.BaseModel],
|
response: Union[dict, openai.BaseModel],
|
||||||
@ -386,6 +400,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
response.choices[0].message.reasoning_content # type: ignore
|
response.choices[0].message.reasoning_content # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hasattr(response, "citations"):
|
||||||
|
rtn.generations[0].message.additional_kwargs["citations"] = (
|
||||||
|
response.citations
|
||||||
|
)
|
||||||
|
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
def _convert_chunk_to_generation_chunk(
|
def _convert_chunk_to_generation_chunk(
|
||||||
@ -407,6 +426,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
reasoning_content
|
reasoning_content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (citations := chunk.get("citations")) and generation_chunk:
|
||||||
|
if isinstance(generation_chunk.message, AIMessageChunk):
|
||||||
|
generation_chunk.message.additional_kwargs["citations"] = citations
|
||||||
|
|
||||||
return generation_chunk
|
return generation_chunk
|
||||||
|
|
||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
|
@ -48,3 +48,24 @@ def test_reasoning_content() -> None:
|
|||||||
full = chunk if full is None else full + chunk
|
full = chunk if full is None else full + chunk
|
||||||
assert isinstance(full, AIMessageChunk)
|
assert isinstance(full, AIMessageChunk)
|
||||||
assert full.additional_kwargs["reasoning_content"]
|
assert full.additional_kwargs["reasoning_content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_search() -> None:
|
||||||
|
llm = ChatXAI(
|
||||||
|
model="grok-3-latest",
|
||||||
|
search_parameters={"mode": "auto", "max_search_results": 3},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
|
||||||
|
assert response.content
|
||||||
|
assert response.additional_kwargs["citations"]
|
||||||
|
assert len(response.additional_kwargs["citations"]) <= 3
|
||||||
|
|
||||||
|
# Test streaming
|
||||||
|
full = None
|
||||||
|
for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
assert isinstance(full, AIMessageChunk)
|
||||||
|
assert full.additional_kwargs["citations"]
|
||||||
|
assert len(full.additional_kwargs["citations"]) <= 3
|
||||||
|
Loading…
Reference in New Issue
Block a user