This commit is contained in:
Chester Curme 2025-02-21 16:54:54 -05:00
parent 619a885263
commit adcb5396d9
4 changed files with 37 additions and 25 deletions

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import os import os
from functools import cached_property
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
@ -660,35 +659,38 @@ class AzureChatOpenAI(BaseChatOpenAI):
return self return self
@cached_property @property
def _root_client(self) -> openai.AzureOpenAI: def _root_client(self) -> openai.AzureOpenAI:
if self.root_client is not None: if self.root_client is not None:
return self.root_client return self.root_client
sync_specific = {"http_client": self._http_client} sync_specific = {"http_client": self._http_client}
return openai.AzureOpenAI(**self._client_params, **sync_specific) # type: ignore[call-overload] self.root_client = openai.AzureOpenAI(**self._client_params, **sync_specific) # type: ignore[call-overload]
return self.root_client
@cached_property @property
def _root_async_client(self) -> openai.AsyncAzureOpenAI: def _root_async_client(self) -> openai.AsyncAzureOpenAI:
if self.root_async_client is not None: if self.root_async_client is not None:
return self.root_async_client return self.root_async_client
async_specific = {"http_client": self._http_async_client} async_specific = {"http_client": self._http_async_client}
self.root_async_client = openai.AsyncAzureOpenAI(
return openai.AsyncAzureOpenAI(
**self._client_params, **self._client_params,
**async_specific, # type: ignore[call-overload] **async_specific, # type: ignore[call-overload]
) )
return self._root_async_client
@cached_property @property
def _client(self) -> Any: def _client(self) -> Any:
if self.client is not None: if self.client is not None:
return self.client return self.client
return self._root_client.chat.completions self.client = self._root_client.chat.completions
return self.client
@cached_property @property
def _async_client(self) -> Any: def _async_client(self) -> Any:
if self.async_client is not None: if self.async_client is not None:
return self.async_client return self.async_client
return self._root_async_client.chat.completions self.async_client = self._root_async_client.chat.completions
return self.async_client
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:

View File

@ -8,7 +8,6 @@ import logging
import os import os
import sys import sys
import warnings import warnings
from functools import cached_property
from io import BytesIO from io import BytesIO
from math import ceil from math import ceil
from operator import itemgetter from operator import itemgetter
@ -564,7 +563,7 @@ class BaseChatOpenAI(BaseChatModel):
return self return self
@cached_property @property
def _http_client(self) -> Optional[httpx.Client]: def _http_client(self) -> Optional[httpx.Client]:
"""Optional httpx.Client. Only used for sync invocations. """Optional httpx.Client. Only used for sync invocations.
@ -585,9 +584,10 @@ class BaseChatOpenAI(BaseChatModel):
"Could not import httpx python package. " "Could not import httpx python package. "
"Please install it with `pip install httpx`." "Please install it with `pip install httpx`."
) from e ) from e
return httpx.Client(proxy=self.openai_proxy) self.http_client = httpx.Client(proxy=self.openai_proxy)
return self.http_client
@cached_property @property
def _http_async_client(self) -> Optional[httpx.AsyncClient]: def _http_async_client(self) -> Optional[httpx.AsyncClient]:
"""Optional httpx.AsyncClient. Only used for async invocations. """Optional httpx.AsyncClient. Only used for async invocations.
@ -605,36 +605,41 @@ class BaseChatOpenAI(BaseChatModel):
"Could not import httpx python package. " "Could not import httpx python package. "
"Please install it with `pip install httpx`." "Please install it with `pip install httpx`."
) from e ) from e
return httpx.AsyncClient(proxy=self.openai_proxy) self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
return self.http_async_client
@cached_property @property
def _root_client(self) -> openai.OpenAI: def _root_client(self) -> openai.OpenAI:
if self.root_client is not None: if self.root_client is not None:
return self.root_client return self.root_client
sync_specific = {"http_client": self._http_client} sync_specific = {"http_client": self._http_client}
return openai.OpenAI(**self._client_params, **sync_specific) # type: ignore[arg-type] self.root_client = openai.OpenAI(**self._client_params, **sync_specific) # type: ignore[arg-type]
return self.root_client
@cached_property @property
def _root_async_client(self) -> openai.AsyncOpenAI: def _root_async_client(self) -> openai.AsyncOpenAI:
if self.root_async_client is not None: if self.root_async_client is not None:
return self.root_async_client return self.root_async_client
async_specific = {"http_client": self._http_async_client} async_specific = {"http_client": self._http_async_client}
return openai.AsyncOpenAI( self.root_async_client = openai.AsyncOpenAI(
**self._client_params, **self._client_params,
**async_specific, # type: ignore[arg-type] **async_specific, # type: ignore[arg-type]
) )
return self.root_async_client
@cached_property @property
def _client(self) -> Any: def _client(self) -> Any:
if self.client is not None: if self.client is not None:
return self.client return self.client
return self._root_client.chat.completions self.client = self._root_client.chat.completions
return self.client
@cached_property @property
def _async_client(self) -> Any: def _async_client(self) -> Any:
if self.async_client is not None: if self.async_client is not None:
return self.async_client return self.async_client
return self._root_async_client.chat.completions self.async_client = self._root_async_client.chat.completions
return self.async_client
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:

View File

@ -660,6 +660,9 @@ def test_openai_structured_output(model: str) -> None:
def test_openai_proxy() -> None: def test_openai_proxy() -> None:
"""Test ChatOpenAI with proxy.""" """Test ChatOpenAI with proxy."""
chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080") chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080")
assert chat_openai.client is None
_ = chat_openai._client # force client to instantiate
assert chat_openai.client is not None
mounts = chat_openai.client._client._client._mounts mounts = chat_openai.client._client._client._mounts
assert len(mounts) == 1 assert len(mounts) == 1
for key, value in mounts.items(): for key, value in mounts.items():
@ -668,6 +671,9 @@ def test_openai_proxy() -> None:
assert proxy.host == b"localhost" assert proxy.host == b"localhost"
assert proxy.port == 8080 assert proxy.port == 8080
assert chat_openai.async_client is None
_ = chat_openai._async_client # force client to instantiate
assert chat_openai.async_client is not None
async_client_mounts = chat_openai.async_client._client._client._mounts async_client_mounts = chat_openai.async_client._client._client._mounts
assert len(async_client_mounts) == 1 assert len(async_client_mounts) == 1
for key, value in async_client_mounts.items(): for key, value in async_client_mounts.items():

View File

@ -541,8 +541,7 @@ def test_openai_invoke(mock_client: MagicMock) -> None:
assert "headers" not in res.response_metadata assert "headers" not in res.response_metadata
assert mock_client.create.called assert mock_client.create.called
assert llm.root_client is None assert llm.async_client is None
assert llm.root_async_client is None
async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None: async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None: