Mirror of https://github.com/hwchase17/langchain.git
Compare commits: `langchain-…` → `cc/openai_…` (15 commits)
| SHA1 |
|---|
| 6f67aacab3 |
| b07bb53f5b |
| e19fed1fcc |
| c23d7f1c2a |
| 0c4fdeae67 |
| be587d1640 |
| 230f2e030a |
| 7506c72346 |
| 24765e49fd |
| 69400b8704 |
| 4588e06794 |
| adcb5396d9 |
| 619a885263 |
| 6c2474b220 |
| c289fc9ba9 |
@@ -178,7 +178,7 @@ class ChatDeepSeek(BaseChatOpenAI):
             self.api_key and self.api_key.get_secret_value()
         ):
             raise ValueError("If using default api base, DEEPSEEK_API_KEY must be set.")
-        client_params: dict = {
+        self._client_params: dict = {
             k: v
             for k, v in {
                 "api_key": self.api_key.get_secret_value() if self.api_key else None,
@@ -191,16 +191,6 @@ class ChatDeepSeek(BaseChatOpenAI):
             if v is not None
         }
 
-        if not (self.client or None):
-            sync_specific: dict = {"http_client": self.http_client}
-            self.client = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
-        if not (self.async_client or None):
-            async_specific: dict = {"http_client": self.http_async_client}
-            self.async_client = openai.AsyncOpenAI(
-                **client_params, **async_specific
-            ).chat.completions
         return self
 
     def _create_chat_result(
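These first two hunks set the pattern for the whole comparison: `validate_environment` stops building OpenAI clients eagerly and instead only assembles `self._client_params`; actual client construction moves into lazily evaluated, cached properties on the base class (shown in the later hunks). A minimal standalone sketch of that pattern, with illustrative names and a stub factory in place of `openai.OpenAI`:

```python
from typing import Any, Optional


class LazyClientHolder:
    """Sketch of the lazy-initialization pattern this series applies."""

    def __init__(self) -> None:
        self._client: Optional[Any] = None
        # In the real diff this dict is filled by validate_environment().
        self._client_params: dict = {"api_key": "placeholder-key"}

    @property
    def client(self) -> Any:
        # Build on first access only, then cache; later reads are cheap.
        if self._client is None:
            self._client = self._build_client(**self._client_params)
        return self._client

    @staticmethod
    def _build_client(**params: Any) -> Any:
        # Stand-in for openai.OpenAI(**params).chat.completions.
        return ("client", tuple(sorted(params)))


holder = LazyClientHolder()
assert holder._client is None  # nothing constructed yet
_ = holder.client              # first access triggers construction
assert holder._client is not None
```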
@@ -629,7 +629,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
                 "Or you can equivalently specify:\n\n"
                 'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
             )
-        client_params: dict = {
+        self._client_params: dict = {
             "api_version": self.openai_api_version,
             "azure_endpoint": self.azure_endpoint,
             "azure_deployment": self.deployment_name,
@@ -650,27 +650,43 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "default_query": self.default_query,
         }
         if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
+            self._client_params["max_retries"] = self.max_retries
 
-        if not self.client:
-            sync_specific = {"http_client": self.http_client}
-            self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
-            self.client = self.root_client.chat.completions
-        if not self.async_client:
-            async_specific = {"http_client": self.http_async_client}
-
-            if self.azure_ad_async_token_provider:
-                client_params["azure_ad_token_provider"] = (
-                    self.azure_ad_async_token_provider
-                )
-
-            self.root_async_client = openai.AsyncAzureOpenAI(
-                **client_params,
-                **async_specific,  # type: ignore[arg-type]
-            )
-            self.async_client = self.root_async_client.chat.completions
+        if self.azure_ad_async_token_provider:
+            self._client_params["azure_ad_token_provider"] = (
+                self.azure_ad_async_token_provider
+            )
 
         return self
 
+    @property
+    def root_client(self) -> Any:
+        if self._root_client is None:
+            sync_specific = {"http_client": self.http_client}
+            self._root_client = openai.AzureOpenAI(
+                **self._client_params,
+                **sync_specific,  # type: ignore[call-overload]
+            )
+        return self._root_client
+
+    @root_client.setter
+    def root_client(self, value: openai.AzureOpenAI) -> None:
+        self._root_client = value
+
+    @property
+    def root_async_client(self) -> Any:
+        if self._root_async_client is None:
+            async_specific = {"http_client": self.http_async_client}
+            self._root_async_client = openai.AsyncAzureOpenAI(
+                **self._client_params,
+                **async_specific,  # type: ignore[call-overload]
+            )
+        return self._root_async_client
+
+    @root_async_client.setter
+    def root_async_client(self, value: openai.AsyncAzureOpenAI) -> None:
+        self._root_async_client = value
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
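This hunk introduces the property/setter pairs used for every client attribute from here on: the getter lazily builds and caches the client from `_client_params`, while keeping a setter means existing code and tests that assign (for example `llm.root_client = ...`) keep working. One behavioral nuance visible in the diff: the async AD token provider is now written into the shared `_client_params` during validation rather than only inside the async branch, so the lazily built sync client would receive it as well. A reduced sketch of the pair, with a hypothetical `make_client` factory:

```python
from typing import Any, Optional


def make_client() -> object:
    """Stand-in for openai.AzureOpenAI(**client_params, **sync_specific)."""
    return object()


class Holder:
    def __init__(self) -> None:
        self._root_client: Optional[Any] = None

    @property
    def root_client(self) -> Any:
        if self._root_client is None:
            self._root_client = make_client()  # built once, on demand
        return self._root_client

    @root_client.setter
    def root_client(self, value: Any) -> None:
        # Explicit assignment (e.g. injecting a mock) bypasses lazy creation.
        self._root_client = value


h = Holder()
h.root_client = "injected"  # setter path: nothing is ever built
assert h.root_client == "injected"
```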
@@ -6,12 +6,14 @@ import base64
 import json
 import logging
 import os
+import ssl
 import sys
 import warnings
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -31,6 +33,7 @@ from typing import (
 )
 from urllib.parse import urlparse
 
+import certifi
 import openai
 import tiktoken
 from langchain_core._api.deprecation import deprecated
@@ -91,12 +94,25 @@ from langchain_core.utils.pydantic import (
     is_basemodel_subclass,
 )
 from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
-from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    model_validator,
+)
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self
 
+if TYPE_CHECKING:
+    import httpx
+
 logger = logging.getLogger(__name__)
 
+# This SSL context is equivalent to the default `verify=True`.
+global_ssl_context = ssl.create_default_context(cafile=certifi.where())
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
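The new module-level `global_ssl_context` supports the lazily created httpx clients further down: building an `ssl.SSLContext` from the certifi bundle is relatively expensive, so it is done once at import time, and `ssl.create_default_context(cafile=certifi.where())` gives the same trust behavior as httpx's default `verify=True` (which is what the comment in the hunk asserts). A standalone sketch; the proxy URL is a placeholder:

```python
import ssl

import certifi
import httpx

# Created once at import time; reused by every lazily built client.
global_ssl_context = ssl.create_default_context(cafile=certifi.where())

# Same trust behavior as httpx.Client(verify=True), without rebuilding the
# SSL context for each client instance.
client = httpx.Client(proxy="http://localhost:8080", verify=global_ssl_context)
```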
@@ -385,10 +401,10 @@ class _AllReturnType(TypedDict):
 
 
 class BaseChatOpenAI(BaseChatModel):
-    client: Any = Field(default=None, exclude=True)  #: :meta private:
-    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    root_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    _client: Any = PrivateAttr(default=None)  #: :meta private:
+    _async_client: Any = PrivateAttr(default=None)  #: :meta private:
+    _root_client: Any = PrivateAttr(default=None)  #: :meta private:
+    _root_async_client: Any = PrivateAttr(default=None)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
     temperature: Optional[float] = None
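Swapping `Field(default=None, exclude=True)` for `PrivateAttr` does more than hide these values from serialization: a private attribute is not a pydantic field at all, so it drops out of the schema and out of validation, freeing the public names (`client`, `root_client`, ...) to be redefined as properties below. A quick illustration of the difference:

```python
from typing import Any

from pydantic import BaseModel, Field, PrivateAttr


class WithField(BaseModel):
    client: Any = Field(default=None, exclude=True)


class WithPrivateAttr(BaseModel):
    _client: Any = PrivateAttr(default=None)


assert "client" in WithField.model_fields             # still a real field
assert "_client" not in WithPrivateAttr.model_fields  # not a field at all
assert "_client" in WithPrivateAttr.__private_attributes__
```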
@@ -460,11 +476,11 @@ class BaseChatOpenAI(BaseChatModel):
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
-    http_client: Union[Any, None] = Field(default=None, exclude=True)
+    _http_client: Union[Any, None] = PrivateAttr(default=None)
     """Optional httpx.Client. Only used for sync invocations. Must specify
     http_async_client as well if you'd like a custom client for async invocations.
     """
-    http_async_client: Union[Any, None] = Field(default=None, exclude=True)
+    _http_async_client: Union[Any, None] = PrivateAttr(default=None)
     """Optional httpx.AsyncClient. Only used for async invocations. Must specify
     http_client as well if you'd like a custom client for sync invocations."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
@@ -491,6 +507,7 @@ class BaseChatOpenAI(BaseChatModel):
     However this does not prevent a user from directly passing in the parameter during
     invocation.
     """
+    _client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)
 
     model_config = ConfigDict(populate_by_name=True)
 
@@ -511,6 +528,24 @@ class BaseChatOpenAI(BaseChatModel):
             values["temperature"] = 1
         return values
 
+    def __init__(
+        self,
+        client: Optional[Any] = None,
+        async_client: Optional[Any] = None,
+        root_client: Optional[Any] = None,
+        async_root_client: Optional[Any] = None,
+        http_client: Optional[Any] = None,
+        http_async_client: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._client = client
+        self._async_client = async_client
+        self._root_client = root_client
+        self._root_async_client = async_root_client
+        self._http_client = http_client
+        self._http_async_client = http_async_client
+
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exist in environment."""
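The explicit `__init__` preserves the old constructor surface: pydantic will not route a keyword argument into a `PrivateAttr`, so without this override `ChatOpenAI(client=...)` or `AzureChatOpenAI(http_client=...)` (exercised by the Azure test below) would no longer be accepted. A minimal sketch of the mechanism:

```python
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr


class Model(BaseModel):
    _client: Any = PrivateAttr(default=None)

    def __init__(self, client: Optional[Any] = None, **kwargs: Any) -> None:
        # Peel the private-attr kwarg off, validate the rest normally, then
        # assign once super().__init__ has set the instance up.
        super().__init__(**kwargs)
        self._client = client


m = Model(client="stub")
assert m._client == "stub"
```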
@@ -526,7 +561,7 @@ class BaseChatOpenAI(BaseChatModel):
             or os.getenv("OPENAI_ORGANIZATION")
         )
         self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
-        client_params: dict = {
+        self._client_params: dict = {
             "api_key": (
                 self.openai_api_key.get_secret_value() if self.openai_api_key else None
             ),
@@ -537,47 +572,122 @@ class BaseChatOpenAI(BaseChatModel):
             "default_query": self.default_query,
         }
         if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
+            self._client_params["max_retries"] = self.max_retries
 
-        if self.openai_proxy and (self.http_client or self.http_async_client):
+        if self.openai_proxy and (self._http_client or self._http_async_client):
             openai_proxy = self.openai_proxy
-            http_client = self.http_client
-            http_async_client = self.http_async_client
+            http_client = self._http_client
+            http_async_client = self._http_async_client
             raise ValueError(
                 "Cannot specify 'openai_proxy' if one of "
                 "'http_client'/'http_async_client' is already specified. Received:\n"
                 f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
             )
-        if not self.client:
-            if self.openai_proxy and not self.http_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    raise ImportError(
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
-                    ) from e
-                self.http_client = httpx.Client(proxy=self.openai_proxy)
-            sync_specific = {"http_client": self.http_client}
-            self.root_client = openai.OpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
-            self.client = self.root_client.chat.completions
-        if not self.async_client:
-            if self.openai_proxy and not self.http_async_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    raise ImportError(
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
-                    ) from e
-                self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
-            async_specific = {"http_client": self.http_async_client}
-            self.root_async_client = openai.AsyncOpenAI(
-                **client_params,
-                **async_specific,  # type: ignore[arg-type]
-            )
-            self.async_client = self.root_async_client.chat.completions
 
         return self
 
+    @property
+    def http_client(self) -> Optional[httpx.Client]:
+        """Optional httpx.Client. Only used for sync invocations.
+
+        Must specify http_async_client as well if you'd like a custom client for
+        async invocations.
+        """
+        # Configure a custom httpx client. See the
+        # [httpx documentation](https://www.python-httpx.org/api/#client) for more
+        # details.
+        if self._http_client is None:
+            if not self.openai_proxy:
+                return None
+            try:
+                import httpx
+            except ImportError as e:
+                raise ImportError(
+                    "Could not import httpx python package. "
+                    "Please install it with `pip install httpx`."
+                ) from e
+            self._http_client = httpx.Client(
+                proxy=self.openai_proxy, verify=global_ssl_context
+            )
+        return self._http_client
+
+    @http_client.setter
+    def http_client(self, value: Optional[httpx.Client]) -> None:
+        self._http_client = value
+
+    @property
+    def http_async_client(self) -> Optional[httpx.AsyncClient]:
+        """Optional httpx.AsyncClient. Only used for async invocations.
+
+        Must specify http_client as well if you'd like a custom client for sync
+        invocations.
+        """
+        if self._http_async_client is None:
+            if not self.openai_proxy:
+                return None
+            try:
+                import httpx
+            except ImportError as e:
+                raise ImportError(
+                    "Could not import httpx python package. "
+                    "Please install it with `pip install httpx`."
+                ) from e
+            self._http_async_client = httpx.AsyncClient(
+                proxy=self.openai_proxy, verify=global_ssl_context
+            )
+        return self._http_async_client
+
+    @http_async_client.setter
+    def http_async_client(self, value: Optional[httpx.AsyncClient]) -> None:
+        self._http_async_client = value
+
+    @property
+    def root_client(self) -> Any:
+        if self._root_client is None:
+            sync_specific = {"http_client": self.http_client}
+            self._root_client = openai.OpenAI(
+                **self._client_params,
+                **sync_specific,  # type: ignore[arg-type]
+            )
+        return self._root_client
+
+    @root_client.setter
+    def root_client(self, value: openai.OpenAI) -> None:
+        self._root_client = value
+
+    @property
+    def root_async_client(self) -> Any:
+        if self._root_async_client is None:
+            async_specific = {"http_client": self.http_async_client}
+            self._root_async_client = openai.AsyncOpenAI(
+                **self._client_params,
+                **async_specific,  # type: ignore[arg-type]
+            )
+        return self._root_async_client
+
+    @root_async_client.setter
+    def root_async_client(self, value: openai.AsyncOpenAI) -> None:
+        self._root_async_client = value
+
+    @property
+    def client(self) -> Any:
+        if self._client is None:
+            self._client = self.root_client.chat.completions
+        return self._client
+
+    @client.setter
+    def client(self, value: Any) -> None:
+        self._client = value
+
+    @property
+    def async_client(self) -> Any:
+        if self._async_client is None:
+            self._async_client = self.root_async_client.chat.completions
+        return self._async_client
+
+    @async_client.setter
+    def async_client(self, value: Any) -> None:
+        self._async_client = value
+
     @property
     def _default_params(self) -> Dict[str, Any]:
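The net effect of this hunk, sketched as usage against the refactored classes (model name and key are placeholders): no OpenAI client, and no httpx client, exists until something actually reads one of the properties.

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-placeholder")

# Construction is deferred: the private slots stay empty after __init__.
assert llm._client is None and llm._root_client is None

# First read builds openai.OpenAI(**self._client_params, http_client=...)
# and caches both the root client and its chat.completions handle.
_ = llm.client
assert llm._client is not None
```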
@@ -1961,6 +2071,9 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
     """Maximum number of tokens to generate."""
 
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"openai_api_key": "OPENAI_API_KEY"}
@@ -676,6 +676,16 @@ def test_openai_proxy() -> None:
         assert proxy.host == b"localhost"
         assert proxy.port == 8080
 
+    http_async_client = httpx.AsyncClient(proxy="http://localhost:8081")
+    chat_openai = ChatOpenAI(http_async_client=http_async_client)
+    mounts = chat_openai.async_client._client._client._mounts
+    assert len(mounts) == 1
+    for key, value in mounts.items():
+        proxy = value._pool._proxy_url.origin
+        assert proxy.scheme == b"http"
+        assert proxy.host == b"localhost"
+        assert proxy.port == 8081
+
 
 def test_openai_response_headers() -> None:
     """Test ChatOpenAI response headers."""
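The extended proxy test mirrors the existing sync assertions on the async side. The attribute chain `chat_openai.async_client._client._client` walks from the lazy property to openai's `AsyncCompletions` resource, to the `AsyncOpenAI` client, to its underlying `httpx.AsyncClient`; `_mounts` and `_pool._proxy_url` are httpx private internals, brittle but already the established approach in this test file. The core check in isolation (httpx only; private API, so version-dependent):

```python
import httpx

# Assumes an httpx release that accepts `proxy=` (as the test itself does).
async_client = httpx.AsyncClient(proxy="http://localhost:8081")
assert len(async_client._mounts) == 1  # one transport mounted for the proxy
```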
@@ -14,6 +14,7 @@ def test_initialize_azure_openai() -> None:
         azure_deployment="35-turbo-dev",
         openai_api_version="2023-05-15",
         azure_endpoint="my-base-url",
+        http_client=None,
     )
     assert llm.deployment_name == "35-turbo-dev"
     assert llm.openai_api_version == "2023-05-15"
@@ -298,7 +298,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:
     usage_chunk = mock_glm4_completion[-1]
 
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "async_client", mock_client):
+    with patch.object(llm, "_async_client", mock_client):
         async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
@@ -323,7 +323,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:
     usage_chunk = mock_glm4_completion[-1]
 
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "client", mock_client):
+    with patch.object(llm, "_client", mock_client):
         for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
@@ -378,7 +378,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
     mock_client.create = mock_create
     usage_chunk = mock_deepseek_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "async_client", mock_client):
+    with patch.object(llm, "_async_client", mock_client):
         async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
@@ -402,7 +402,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
     mock_client.create = mock_create
     usage_chunk = mock_deepseek_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "client", mock_client):
+    with patch.object(llm, "_client", mock_client):
         for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
@@ -446,7 +446,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
     mock_client.create = mock_create
     usage_chunk = mock_openai_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "async_client", mock_client):
+    with patch.object(llm, "_async_client", mock_client):
         async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
@@ -470,7 +470,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
     mock_client.create = mock_create
     usage_chunk = mock_openai_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
-    with patch.object(llm, "client", mock_client):
+    with patch.object(llm, "_client", mock_client):
         for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
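All six streaming tests switch from patching the public attribute to patching the private slot. A plausible reading, inferred from the refactor rather than stated in the commits: `patch.object` reads the current value first so it can restore it on exit, and reading the public `client`/`async_client` property on an un-keyed `ChatOpenAI()` would try to construct a real `openai.OpenAI` client and fail; patching `_client`/`_async_client` fills the cache directly, so the property simply returns the mock. Sketch:

```python
from unittest.mock import MagicMock, patch

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(api_key="sk-placeholder")
mock_client = MagicMock()

# The lazy `client` property sees a populated cache and never builds a
# real client while the patch is active.
with patch.object(llm, "_client", mock_client):
    assert llm.client is mock_client
```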
@@ -533,7 +533,7 @@ def mock_async_client(mock_completion: dict) -> AsyncMock:
 def test_openai_invoke(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
 
-    with patch.object(llm, "client", mock_client):
+    with patch.object(llm, "_client", mock_client):
         res = llm.invoke("bar")
         assert res.content == "Bar Baz"
@@ -541,11 +541,13 @@ def test_openai_invoke(mock_client: MagicMock) -> None:
         assert "headers" not in res.response_metadata
     assert mock_client.create.called
 
+    assert llm._async_client is None
+
 
 async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
     llm = ChatOpenAI()
 
-    with patch.object(llm, "async_client", mock_async_client):
+    with patch.object(llm, "_async_client", mock_async_client):
         res = await llm.ainvoke("bar")
         assert res.content == "Bar Baz"
@@ -573,7 +575,7 @@ def test__get_encoding_model(model: str) -> None:
 def test_openai_invoke_name(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
 
-    with patch.object(llm, "client", mock_client):
+    with patch.object(llm, "_client", mock_client):
         messages = [HumanMessage(content="Foo", name="Katie")]
         res = llm.invoke(messages)
         call_args, call_kwargs = mock_client.create.call_args
@@ -7,7 +7,6 @@ from typing import (
     Optional,
 )
 
-import openai
 from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.utils import secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -325,7 +324,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         if self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
-        client_params: dict = {
+        self._client_params: dict = {
             "api_key": (
                 self.xai_api_key.get_secret_value() if self.xai_api_key else None
             ),
@@ -335,27 +334,12 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             "default_query": self.default_query,
         }
         if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
+            self._client_params["max_retries"] = self.max_retries
 
-        if client_params["api_key"] is None:
+        if self._client_params["api_key"] is None:
             raise ValueError(
                 "xAI API key is not set. Please set it in the `xai_api_key` field or "
                 "in the `XAI_API_KEY` environment variable."
             )
 
-        if not (self.client or None):
-            sync_specific: dict = {"http_client": self.http_client}
-            self.client = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
-            self.root_client = openai.OpenAI(**client_params, **sync_specific)
-        if not (self.async_client or None):
-            async_specific: dict = {"http_client": self.http_async_client}
-            self.async_client = openai.AsyncOpenAI(
-                **client_params, **async_specific
-            ).chat.completions
-            self.root_async_client = openai.AsyncOpenAI(
-                **client_params,
-                **async_specific,
-            )
         return self
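With the eager construction removed, `ChatXAI.validate_environment` is reduced to validating the key and filling `self._client_params`; the inherited `BaseChatOpenAI` properties build the actual clients, which is also why the module-level `import openai` could be dropped. A reduced standalone sketch of that division of labor (illustrative classes, not the langchain_xai source):

```python
from typing import Any, Dict, Optional


class Base:
    """Stands in for BaseChatOpenAI: owns the lazy client machinery."""

    def __init__(self) -> None:
        self._client_params: Dict[str, Any] = {}
        self._client: Optional[Any] = None

    @property
    def client(self) -> Any:
        if self._client is None:
            # Stand-in for openai.OpenAI(**self._client_params).chat.completions
            self._client = ("client", tuple(sorted(self._client_params)))
        return self._client


class Provider(Base):
    """Stands in for ChatXAI: only validates config and fills the params."""

    def validate_environment(self) -> "Provider":
        self._client_params = {
            "api_key": "xai-placeholder",
            "base_url": "https://api.x.ai/v1",  # assumed xAI endpoint
        }
        return self


p = Provider().validate_environment()
assert p.client  # built lazily by the base from the subclass's params
```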