Compare commits

...

15 Commits

Author SHA1 Message Date
Chester Curme
6f67aacab3 add test 2025-02-22 09:28:45 -05:00
Chester Curme
b07bb53f5b setters for http clients 2025-02-21 23:32:31 -05:00
Chester Curme
e19fed1fcc lint 2025-02-21 21:23:07 -05:00
Chester Curme
c23d7f1c2a cr 2025-02-21 21:22:45 -05:00
Chester Curme
0c4fdeae67 update deepseek 2025-02-21 21:14:19 -05:00
Chester Curme
be587d1640 update xai 2025-02-21 21:13:30 -05:00
Chester Curme
230f2e030a add setters 2025-02-21 21:12:12 -05:00
Chester Curme
7506c72346 remove redundant properties on azure 2025-02-21 21:11:49 -05:00
Chester Curme
24765e49fd set global ssl context 2025-02-21 19:45:43 -05:00
Chester Curme
69400b8704 refactor 2025-02-21 19:32:09 -05:00
Chester Curme
4588e06794 clients -> private attributes 2025-02-21 19:08:41 -05:00
Chester Curme
adcb5396d9 update 2025-02-21 16:54:54 -05:00
Chester Curme
619a885263 add test 2025-02-21 15:47:37 -05:00
Chester Curme
6c2474b220 update azure 2025-02-21 15:25:37 -05:00
Chester Curme
c289fc9ba9 update base class 2025-02-21 15:24:49 -05:00
7 changed files with 211 additions and 95 deletions

View File

@@ -178,7 +178,7 @@ class ChatDeepSeek(BaseChatOpenAI):
self.api_key and self.api_key.get_secret_value()
):
raise ValueError("If using default api base, DEEPSEEK_API_KEY must be set.")
client_params: dict = {
self._client_params: dict = {
k: v
for k, v in {
"api_key": self.api_key.get_secret_value() if self.api_key else None,
@@ -191,16 +191,6 @@ class ChatDeepSeek(BaseChatOpenAI):
if v is not None
}
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
return self
def _create_chat_result(

View File

@@ -629,7 +629,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
"Or you can equivalently specify:\n\n"
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
)
client_params: dict = {
self._client_params: dict = {
"api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name,
@@ -650,27 +650,43 @@ class AzureChatOpenAI(BaseChatOpenAI):
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
self._client_params["max_retries"] = self.max_retries
if not self.client:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
async_specific = {"http_client": self.http_async_client}
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.root_async_client = openai.AsyncAzureOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
if self.azure_ad_async_token_provider:
self._client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.async_client = self.root_async_client.chat.completions
return self
@property
def root_client(self) -> Any:
if self._root_client is None:
sync_specific = {"http_client": self.http_client}
self._root_client = openai.AzureOpenAI(
**self._client_params,
**sync_specific, # type: ignore[call-overload]
)
return self._root_client
@root_client.setter
def root_client(self, value: openai.AzureOpenAI) -> None:
self._root_client = value
@property
def root_async_client(self) -> Any:
if self._root_async_client is None:
async_specific = {"http_client": self.http_async_client}
self._root_async_client = openai.AsyncAzureOpenAI(
**self._client_params,
**async_specific, # type: ignore[call-overload]
)
return self._root_async_client
@root_async_client.setter
def root_async_client(self, value: openai.AsyncAzureOpenAI) -> None:
self._root_async_client = value
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""

View File

@@ -6,12 +6,14 @@ import base64
import json
import logging
import os
import ssl
import sys
import warnings
from io import BytesIO
from math import ceil
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
@@ -31,6 +33,7 @@ from typing import (
)
from urllib.parse import urlparse
import certifi
import openai
import tiktoken
from langchain_core._api.deprecation import deprecated
@@ -91,12 +94,25 @@ from langchain_core.utils.pydantic import (
is_basemodel_subclass,
)
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__)
# This SSL context is equivelent to the default `verify=True`.
global_ssl_context = ssl.create_default_context(cafile=certifi.where())
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
@@ -385,10 +401,10 @@ class _AllReturnType(TypedDict):
class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
root_client: Any = Field(default=None, exclude=True) #: :meta private:
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
_client: Any = PrivateAttr(default=None) #: :meta private:
_async_client: Any = PrivateAttr(default=None) #: :meta private:
_root_client: Any = PrivateAttr(default=None) #: :meta private:
_root_async_client: Any = PrivateAttr(default=None) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: Optional[float] = None
@@ -460,11 +476,11 @@ class BaseChatOpenAI(BaseChatModel):
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = Field(default=None, exclude=True)
_http_client: Union[Any, None] = PrivateAttr(default=None)
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
_http_async_client: Union[Any, None] = PrivateAttr(default=None)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
@@ -491,6 +507,7 @@ class BaseChatOpenAI(BaseChatModel):
However this does not prevent a user from directly passed in the parameter during
invocation.
"""
_client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)
model_config = ConfigDict(populate_by_name=True)
@@ -511,6 +528,24 @@ class BaseChatOpenAI(BaseChatModel):
values["temperature"] = 1
return values
def __init__(
self,
client: Optional[Any] = None,
async_client: Optional[Any] = None,
root_client: Optional[Any] = None,
async_root_client: Optional[Any] = None,
http_client: Optional[Any] = None,
http_async_client: Optional[Any] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._client = client
self._async_client = async_client
self._root_client = root_client
self._async_root_client = async_root_client
self._http_client = http_client
self._http_async_client = http_async_client
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
@@ -526,7 +561,7 @@ class BaseChatOpenAI(BaseChatModel):
or os.getenv("OPENAI_ORGANIZATION")
)
self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
client_params: dict = {
self._client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
@@ -537,47 +572,122 @@ class BaseChatOpenAI(BaseChatModel):
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
self._client_params["max_retries"] = self.max_retries
if self.openai_proxy and (self.http_client or self.http_async_client):
if self.openai_proxy and (self._http_client or self._http_async_client):
openai_proxy = self.openai_proxy
http_client = self.http_client
http_async_client = self.http_async_client
http_client = self._http_client
http_async_client = self._http_async_client
raise ValueError(
"Cannot specify 'openai_proxy' if one of "
"'http_client'/'http_async_client' is already specified. Received:\n"
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
return self
@property
def http_client(self) -> Optional[httpx.Client]:
"""Optional httpx.Client. Only used for sync invocations.
Must specify http_async_client as well if you'd like a custom client for
async invocations.
"""
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more
# details.
if self._http_client is None:
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self._http_client = httpx.Client(
proxy=self.openai_proxy, verify=global_ssl_context
)
return self._http_client
@http_client.setter
def http_client(self, value: Optional[httpx.Client]) -> None:
self._http_client = value
@property
def http_async_client(self) -> Optional[httpx.AsyncClient]:
"""Optional httpx.AsyncClient. Only used for async invocations.
Must specify http_client as well if you'd like a custom client for sync
invocations.
"""
if self._http_async_client is None:
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self._http_async_client = httpx.AsyncClient(
proxy=self.openai_proxy, verify=global_ssl_context
)
return self._http_async_client
@http_async_client.setter
def http_async_client(self, value: Optional[httpx.AsyncClient]) -> None:
self._http_async_client = value
@property
def root_client(self) -> Any:
if self._root_client is None:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
if self.openai_proxy and not self.http_async_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
self._root_client = openai.OpenAI(
**self._client_params,
**sync_specific, # type: ignore[arg-type]
)
return self._root_client
@root_client.setter
def root_client(self, value: openai.OpenAI) -> None:
self._root_client = value
@property
def root_async_client(self) -> Any:
if self._root_async_client is None:
async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncOpenAI(
**client_params,
self._root_async_client = openai.AsyncOpenAI(
**self._client_params,
**async_specific, # type: ignore[arg-type]
)
self.async_client = self.root_async_client.chat.completions
return self
return self._root_async_client
@root_async_client.setter
def root_async_client(self, value: openai.AsyncOpenAI) -> None:
self._root_async_client = value
@property
def client(self) -> Any:
if self._client is None:
self._client = self.root_client.chat.completions
return self._client
@client.setter
def client(self, value: Any) -> None:
self._client = value
@property
def async_client(self) -> Any:
if self._async_client is None:
self._async_client = self.root_async_client.chat.completions
return self._async_client
@async_client.setter
def async_client(self, value: Any) -> None:
self._async_client = value
@property
def _default_params(self) -> Dict[str, Any]:
@@ -1961,6 +2071,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
"""Maximum number of tokens to generate."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

View File

@@ -676,6 +676,16 @@ def test_openai_proxy() -> None:
assert proxy.host == b"localhost"
assert proxy.port == 8080
http_async_client = httpx.AsyncClient(proxy="http://localhost:8081")
chat_openai = ChatOpenAI(http_async_client=http_async_client)
mounts = chat_openai.async_client._client._client._mounts
assert len(mounts) == 1
for key, value in mounts.items():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
assert proxy.port == 8081
def test_openai_response_headers() -> None:
"""Test ChatOpenAI response headers."""

View File

@@ -14,6 +14,7 @@ def test_initialize_azure_openai() -> None:
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
azure_endpoint="my-base-url",
http_client=None,
)
assert llm.deployment_name == "35-turbo-dev"
assert llm.openai_api_version == "2023-05-15"

View File

@@ -298,7 +298,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:
usage_chunk = mock_glm4_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
with patch.object(llm, "_async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -323,7 +323,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:
usage_chunk = mock_glm4_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
with patch.object(llm, "_client", mock_client):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -378,7 +378,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
mock_client.create = mock_create
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
with patch.object(llm, "_async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -402,7 +402,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
mock_client.create = mock_create
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
with patch.object(llm, "_client", mock_client):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -446,7 +446,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
mock_client.create = mock_create
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
with patch.object(llm, "_async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -470,7 +470,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
mock_client.create = mock_create
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
with patch.object(llm, "_client", mock_client):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
@@ -533,7 +533,7 @@ def mock_async_client(mock_completion: dict) -> AsyncMock:
def test_openai_invoke(mock_client: MagicMock) -> None:
llm = ChatOpenAI()
with patch.object(llm, "client", mock_client):
with patch.object(llm, "_client", mock_client):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
@@ -541,11 +541,13 @@ def test_openai_invoke(mock_client: MagicMock) -> None:
assert "headers" not in res.response_metadata
assert mock_client.create.called
assert llm._async_client is None
async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
llm = ChatOpenAI()
with patch.object(llm, "async_client", mock_async_client):
with patch.object(llm, "_async_client", mock_async_client):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
@@ -573,7 +575,7 @@ def test__get_encoding_model(model: str) -> None:
def test_openai_invoke_name(mock_client: MagicMock) -> None:
llm = ChatOpenAI()
with patch.object(llm, "client", mock_client):
with patch.object(llm, "_client", mock_client):
messages = [HumanMessage(content="Foo", name="Katie")]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args

View File

@@ -7,7 +7,6 @@ from typing import (
Optional,
)
import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -325,7 +324,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
if self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
client_params: dict = {
self._client_params: dict = {
"api_key": (
self.xai_api_key.get_secret_value() if self.xai_api_key else None
),
@@ -335,27 +334,12 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
self._client_params["max_retries"] = self.max_retries
if client_params["api_key"] is None:
if self._client_params["api_key"] is None:
raise ValueError(
"xAI API key is not set. Please set it in the `xai_api_key` field or "
"in the `XAI_API_KEY` environment variable."
)
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
self.root_client = openai.OpenAI(**client_params, **sync_specific)
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
self.root_async_client = openai.AsyncOpenAI(
**client_params,
**async_specific,
)
return self