openai[patch]: fix ChatOpenAI model's openai proxy (#19559)

Due to changes in the OpenAI SDK, the previous method of setting the
OpenAI proxy in ChatOpenAI no longer works. This PR fixes this issue,
making the previous way of setting the OpenAI proxy in ChatOpenAI
effective again.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Shuqian
2024-03-28 14:16:55 +08:00
committed by William Fu-Hinthorn
parent f4e43a5b8a
commit cde9dc23cd
2 changed files with 41 additions and 0 deletions

View File

@@ -380,12 +380,31 @@ class ChatOpenAI(BaseChatModel):
"default_query": values["default_query"],
}
openai_proxy = values["openai_proxy"]
if not values.get("client"):
if openai_proxy and not values["http_client"]:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
values["http_client"] = httpx.Client(proxy=openai_proxy)
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not values.get("async_client"):
if openai_proxy and not values["http_async_client"]:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
values["http_async_client"] = httpx.AsyncClient(proxy=openai_proxy)
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific

View File

@@ -501,3 +501,25 @@ def test_openai_structured_output() -> None:
assert isinstance(result, MyModel)
assert result.name == "Erick"
assert result.age == 27
def test_openai_proxy() -> None:
"""Test ChatOpenAI with proxy."""
chat_openai = ChatOpenAI(
openai_proxy="http://localhost:8080",
)
mounts = chat_openai.client._client._client._mounts
assert len(mounts) == 1
for key, value in mounts.items():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
assert proxy.port == 8080
async_client_mounts = chat_openai.async_client._client._client._mounts
assert len(async_client_mounts) == 1
for key, value in async_client_mounts.items():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
assert proxy.port == 8080