diff --git a/libs/partners/anthropic/langchain_anthropic/_client_utils.py b/libs/partners/anthropic/langchain_anthropic/_client_utils.py
new file mode 100644
index 00000000000..b98b9e16433
--- /dev/null
+++ b/libs/partners/anthropic/langchain_anthropic/_client_utils.py
@@ -0,0 +1,75 @@
+"""Helpers for creating Anthropic API clients.
+
+This module allows for the caching of httpx clients to avoid creating new instances
+for each instance of ChatAnthropic.
+
+Logic is largely replicated from anthropic._base_client.
+"""
+
+import asyncio
+import os
+from functools import lru_cache
+from typing import Any, Optional
+
+import anthropic
+
+_NOT_GIVEN: Any = object()
+
+
+class _SyncHttpxClientWrapper(anthropic.DefaultHttpxClient):
+    """Borrowed from anthropic._base_client."""
+
+    def __del__(self) -> None:
+        if self.is_closed:
+            return
+
+        try:
+            self.close()
+        except Exception:
+            pass
+
+
+class _AsyncHttpxClientWrapper(anthropic.DefaultAsyncHttpxClient):
+    """Borrowed from anthropic._base_client."""
+
+    def __del__(self) -> None:
+        if self.is_closed:
+            return
+
+        try:
+            # TODO(someday): support non asyncio runtimes here
+            asyncio.get_running_loop().create_task(self.aclose())
+        except Exception:
+            pass
+
+
+@lru_cache
+def _get_default_httpx_client(
+    *,
+    base_url: Optional[str],
+    timeout: Any = _NOT_GIVEN,
+) -> _SyncHttpxClientWrapper:
+    kwargs: dict[str, Any] = {
+        "base_url": base_url
+        or os.environ.get("ANTHROPIC_BASE_URL")
+        or "https://api.anthropic.com",
+    }
+    if timeout is not _NOT_GIVEN:
+        kwargs["timeout"] = timeout
+    return _SyncHttpxClientWrapper(**kwargs)
+
+
+@lru_cache
+def _get_default_async_httpx_client(
+    *,
+    base_url: Optional[str],
+    timeout: Any = _NOT_GIVEN,
+) -> _AsyncHttpxClientWrapper:
+    kwargs: dict[str, Any] = {
+        "base_url": base_url
+        or os.environ.get("ANTHROPIC_BASE_URL")
+        or "https://api.anthropic.com",
+    }
+    if timeout is not _NOT_GIVEN:
+        kwargs["timeout"] = timeout
+    return _AsyncHttpxClientWrapper(**kwargs)
diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 417aeb7f44f..e30bf35a435 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -69,6 +69,10 @@ from pydantic import (
 )
 from typing_extensions import NotRequired, TypedDict
 
+from langchain_anthropic._client_utils import (
+    _get_default_async_httpx_client,
+    _get_default_httpx_client,
+)
 from langchain_anthropic.output_parsers import extract_tool_calls
 
 _message_type_lookups = {
@@ -1300,11 +1304,29 @@ class ChatAnthropic(BaseChatModel):
 
     @cached_property
     def _client(self) -> anthropic.Client:
-        return anthropic.Client(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.Client(**params)
 
     @cached_property
     def _async_client(self) -> anthropic.AsyncClient:
-        return anthropic.AsyncClient(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_async_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.AsyncClient(**params)
 
     def _get_request_payload(
         self,
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index c40c5a4ab0a..93550c1fb55 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -1,5 +1,6 @@
 """Test ChatAnthropic chat model."""
 
+import asyncio
 import json
 import os
 from base64 import b64encode
@@ -1082,3 +1083,10 @@ def test_files_api_pdf(block_format: str) -> None:
         ],
     }
     _ = llm.invoke([input_message])
+
+
+def test_async_shared_client() -> None:
+    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
+    llm._async_client  # Instantiates lazily
+    _ = asyncio.run(llm.ainvoke("Hello"))
+    _ = asyncio.run(llm.ainvoke("Hello"))
diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
index 2d418b6bddd..b0b0cc82b09 100644
--- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -44,6 +44,22 @@ def test_initialization() -> None:
     assert model.anthropic_api_url == "https://api.anthropic.com"
 
 
+def test_anthropic_client_caching() -> None:
+    """Test that the underlying Anthropic httpx client is cached."""
+    llm1 = ChatAnthropic(model="claude-3-5-sonnet-latest")
+    llm2 = ChatAnthropic(model="claude-3-5-sonnet-latest")
+    assert llm1._client._client is llm2._client._client
+
+    llm3 = ChatAnthropic(model="claude-3-5-sonnet-latest", base_url="foo")
+    assert llm1._client._client is not llm3._client._client
+
+    llm4 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=None)
+    assert llm1._client._client is llm4._client._client
+
+    llm5 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=3)
+    assert llm1._client._client is not llm5._client._client
+
+
 @pytest.mark.requires("anthropic")
 def test_anthropic_model_name_param() -> None:
     llm = ChatAnthropic(model_name="foo")  # type: ignore[call-arg, call-arg]
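Usage sketch (not part of the patch): a minimal illustration of the caching behavior this diff introduces, assuming this branch of langchain_anthropic is installed and ANTHROPIC_API_KEY is set in the environment; the model name below is only illustrative. Two ChatAnthropic instances with the same base_url/timeout reuse one cached httpx client, while a different timeout produces a distinct one.

# Sketch of the client-caching behavior added above (assumptions as noted).
from langchain_anthropic import ChatAnthropic

# Same base_url/timeout -> same lru_cache key -> shared underlying httpx client.
llm_a = ChatAnthropic(model="claude-3-5-haiku-latest")
llm_b = ChatAnthropic(model="claude-3-5-haiku-latest")
assert llm_a._client._client is llm_b._client._client

# A different timeout changes the cache key, so a separate httpx client is built.
llm_c = ChatAnthropic(model="claude-3-5-haiku-latest", timeout=10)
assert llm_a._client._client is not llm_c._client._client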