diff --git a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py index 3543c6add1a..d10b815a9ea 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py +++ b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py @@ -41,8 +41,7 @@ class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient): pass -@lru_cache -def _get_default_httpx_client( +def _build_sync_httpx_client( base_url: Optional[str], timeout: Any ) -> _SyncHttpxClientWrapper: return _SyncHttpxClientWrapper( @@ -53,8 +52,7 @@ def _get_default_httpx_client( ) -@lru_cache -def _get_default_async_httpx_client( +def _build_async_httpx_client( base_url: Optional[str], timeout: Any ) -> _AsyncHttpxClientWrapper: return _AsyncHttpxClientWrapper( @@ -63,3 +61,47 @@ def _get_default_async_httpx_client( or "https://api.openai.com/v1", timeout=timeout, ) + + +@lru_cache +def _cached_sync_httpx_client( + base_url: Optional[str], timeout: Any +) -> _SyncHttpxClientWrapper: + return _build_sync_httpx_client(base_url, timeout) + + +@lru_cache +def _cached_async_httpx_client( + base_url: Optional[str], timeout: Any +) -> _AsyncHttpxClientWrapper: + return _build_async_httpx_client(base_url, timeout) + + +def _get_default_httpx_client( + base_url: Optional[str], timeout: Any +) -> _SyncHttpxClientWrapper: + """Get default httpx client. + + Uses cached client unless timeout is ``httpx.Timeout``, which is not hashable. + """ + try: + hash(timeout) + except TypeError: + return _build_sync_httpx_client(base_url, timeout) + else: + return _cached_sync_httpx_client(base_url, timeout) + + +def _get_default_async_httpx_client( + base_url: Optional[str], timeout: Any +) -> _AsyncHttpxClientWrapper: + """Get default async httpx client. + + Uses cached client unless timeout is ``httpx.Timeout``, which is not hashable. 
+ """ + try: + hash(timeout) + except TypeError: + return _build_async_httpx_client(base_url, timeout) + else: + return _cached_async_httpx_client(base_url, timeout) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index aa7bead4a9b..c9fce46bc9e 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -6,6 +6,7 @@ from types import TracebackType from typing import Any, Literal, Optional, Union, cast from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from langchain_core.load import dumps, loads from langchain_core.messages import ( @@ -74,6 +75,30 @@ def test_openai_model_param() -> None: assert llm.max_tokens == 10 +def test_openai_client_caching() -> None: + """Test that the OpenAI client is cached.""" + llm1 = ChatOpenAI(model="gpt-4.1-mini") + llm2 = ChatOpenAI(model="gpt-4.1-mini") + assert llm1.root_client._client is llm2.root_client._client + + llm3 = ChatOpenAI(model="gpt-4.1-mini", base_url="foo") + assert llm1.root_client._client is not llm3.root_client._client + + llm4 = ChatOpenAI(model="gpt-4.1-mini", timeout=None) + assert llm1.root_client._client is llm4.root_client._client + + llm5 = ChatOpenAI(model="gpt-4.1-mini", timeout=3) + assert llm1.root_client._client is not llm5.root_client._client + + llm6 = ChatOpenAI( + model="gpt-4.1-mini", timeout=httpx.Timeout(timeout=60.0, connect=5.0) + ) + assert llm1.root_client._client is not llm6.root_client._client + + llm7 = ChatOpenAI(model="gpt-4.1-mini", timeout=(5, 1)) + assert llm1.root_client._client is not llm7.root_client._client + + def test_openai_o1_temperature() -> None: llm = ChatOpenAI(model="o1-preview") assert llm.temperature == 1