diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 0081a734cf3..1e4d308f9af 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -7,7 +7,6 @@ from typing import ( Optional, ) -import openai from langchain_core.language_models.chat_models import LangSmithParams from langchain_core.utils import secret_from_env from langchain_openai.chat_models.base import BaseChatOpenAI @@ -325,7 +324,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] if self.n is not None and self.n > 1 and self.streaming: raise ValueError("n must be 1 when streaming.") - client_params: dict = { + self._client_params: dict = { "api_key": ( self.xai_api_key.get_secret_value() if self.xai_api_key else None ), @@ -335,27 +334,12 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] "default_query": self.default_query, } if self.max_retries is not None: - client_params["max_retries"] = self.max_retries + self._client_params["max_retries"] = self.max_retries - if client_params["api_key"] is None: + if self._client_params["api_key"] is None: raise ValueError( "xAI API key is not set. Please set it in the `xai_api_key` field or " "in the `XAI_API_KEY` environment variable." ) - if not (self.client or None): - sync_specific: dict = {"http_client": self.http_client} - self.client = openai.OpenAI( - **client_params, **sync_specific - ).chat.completions - self.root_client = openai.OpenAI(**client_params, **sync_specific) - if not (self.async_client or None): - async_specific: dict = {"http_client": self.http_async_client} - self.async_client = openai.AsyncOpenAI( - **client_params, **async_specific - ).chat.completions - self.root_async_client = openai.AsyncOpenAI( - **client_params, - **async_specific, - ) return self