diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 4b8ec3e016e..f11c13ce96d 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -380,12 +380,31 @@ class ChatOpenAI(BaseChatModel):
             "default_query": values["default_query"],
         }
 
+        openai_proxy = values["openai_proxy"]
         if not values.get("client"):
+            if openai_proxy and not values["http_client"]:
+                try:
+                    import httpx
+                except ImportError as e:
+                    raise ImportError(
+                        "Could not import httpx python package. "
+                        "Please install it with `pip install httpx`."
+                    ) from e
+                values["http_client"] = httpx.Client(proxy=openai_proxy)
             sync_specific = {"http_client": values["http_client"]}
             values["client"] = openai.OpenAI(
                 **client_params, **sync_specific
             ).chat.completions
         if not values.get("async_client"):
+            if openai_proxy and not values["http_async_client"]:
+                try:
+                    import httpx
+                except ImportError as e:
+                    raise ImportError(
+                        "Could not import httpx python package. "
+                        "Please install it with `pip install httpx`."
+                    ) from e
+                values["http_async_client"] = httpx.AsyncClient(proxy=openai_proxy)
             async_specific = {"http_client": values["http_async_client"]}
             values["async_client"] = openai.AsyncOpenAI(
                 **client_params, **async_specific
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 50baef64d14..b65e0f7ce55 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -501,3 +501,25 @@ def test_openai_structured_output() -> None:
     assert isinstance(result, MyModel)
     assert result.name == "Erick"
     assert result.age == 27
+
+
+def test_openai_proxy() -> None:
+    """Test ChatOpenAI with proxy."""
+    chat_openai = ChatOpenAI(
+        openai_proxy="http://localhost:8080",
+    )
+    mounts = chat_openai.client._client._client._mounts
+    assert len(mounts) == 1
+    for key, value in mounts.items():
+        proxy = value._pool._proxy_url.origin
+        assert proxy.scheme == b"http"
+        assert proxy.host == b"localhost"
+        assert proxy.port == 8080
+
+    async_client_mounts = chat_openai.async_client._client._client._mounts
+    assert len(async_client_mounts) == 1
+    for key, value in async_client_mounts.items():
+        proxy = value._pool._proxy_url.origin
+        assert proxy.scheme == b"http"
+        assert proxy.host == b"localhost"
+        assert proxy.port == 8080