anthropic[patch]: cache clients (#31659)

ccurme 2025-06-25 14:49:02 -04:00 committed by GitHub
parent e3f1ce0ac5
commit b02bd67788
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 123 additions and 2 deletions


@@ -0,0 +1,75 @@
"""Helpers for creating Anthropic API clients.

This module allows for the caching of httpx clients to avoid creating new instances
for each instance of ChatAnthropic.

Logic is largely replicated from anthropic._base_client.
"""

import asyncio
import os
from functools import lru_cache
from typing import Any, Optional

import anthropic

_NOT_GIVEN: Any = object()


class _SyncHttpxClientWrapper(anthropic.DefaultHttpxClient):
    """Borrowed from anthropic._base_client"""

    def __del__(self) -> None:
        if self.is_closed:
            return

        try:
            self.close()
        except Exception:
            pass


class _AsyncHttpxClientWrapper(anthropic.DefaultAsyncHttpxClient):
    """Borrowed from anthropic._base_client"""

    def __del__(self) -> None:
        if self.is_closed:
            return

        try:
            # TODO(someday): support non asyncio runtimes here
            asyncio.get_running_loop().create_task(self.aclose())
        except Exception:
            pass


@lru_cache
def _get_default_httpx_client(
    *,
    base_url: Optional[str],
    timeout: Any = _NOT_GIVEN,
) -> _SyncHttpxClientWrapper:
    kwargs: dict[str, Any] = {
        "base_url": base_url
        or os.environ.get("ANTHROPIC_BASE_URL")
        or "https://api.anthropic.com",
    }
    if timeout is not _NOT_GIVEN:
        kwargs["timeout"] = timeout
    return _SyncHttpxClientWrapper(**kwargs)


@lru_cache
def _get_default_async_httpx_client(
    *,
    base_url: Optional[str],
    timeout: Any = _NOT_GIVEN,
) -> _AsyncHttpxClientWrapper:
    kwargs: dict[str, Any] = {
        "base_url": base_url
        or os.environ.get("ANTHROPIC_BASE_URL")
        or "https://api.anthropic.com",
    }
    if timeout is not _NOT_GIVEN:
        kwargs["timeout"] = timeout
    return _AsyncHttpxClientWrapper(**kwargs)
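
A minimal usage sketch (not part of the diff) of how the lru_cache on these helpers behaves: calls made with identical keyword arguments return the same wrapper object, while a different base_url (or timeout) is a different cache key and produces a new one. The import path follows the import added in the next file; the example base URL is illustrative.

from langchain_anthropic._client_utils import _get_default_httpx_client

# Identical cache key (base_url=None, timeout omitted) -> same wrapper instance.
client_a = _get_default_httpx_client(base_url=None)
client_b = _get_default_httpx_client(base_url=None)
assert client_a is client_b

# A different base_url is a different cache key -> a fresh wrapper is built.
client_c = _get_default_httpx_client(base_url="https://example.invalid")
assert client_a is not client_c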


@@ -69,6 +69,10 @@ from pydantic import (
)
from typing_extensions import NotRequired, TypedDict

from langchain_anthropic._client_utils import (
    _get_default_async_httpx_client,
    _get_default_httpx_client,
)
from langchain_anthropic.output_parsers import extract_tool_calls

_message_type_lookups = {
@@ -1300,11 +1304,29 @@ class ChatAnthropic(BaseChatModel):
     @cached_property
     def _client(self) -> anthropic.Client:
-        return anthropic.Client(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.Client(**params)

     @cached_property
     def _async_client(self) -> anthropic.AsyncClient:
-        return anthropic.AsyncClient(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_async_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.AsyncClient(**params)

     def _get_request_payload(
         self,

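The net effect of the change above, as a hedged sketch (not part of the diff): ChatAnthropic instances created with the same base_url and timeout now resolve to the same cached httpx client, so they share one connection pool instead of each constructing a fresh transport. Assumes ANTHROPIC_API_KEY is set in the environment; the model name is illustrative.

from langchain_anthropic import ChatAnthropic

llm_a = ChatAnthropic(model="claude-3-5-haiku-latest")
llm_b = ChatAnthropic(model="claude-3-5-haiku-latest")

# Both SDK clients wrap the same cached httpx client (see the unit test below).
assert llm_a._client._client is llm_b._client._client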

@@ -1,5 +1,6 @@
"""Test ChatAnthropic chat model."""

import asyncio
import json
import os
from base64 import b64encode
@@ -1082,3 +1083,10 @@ def test_files_api_pdf(block_format: str) -> None:
        ],
    }
    _ = llm.invoke([input_message])


def test_async_shared_client() -> None:
    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
    llm._async_client  # Instantiates lazily
    _ = asyncio.run(llm.ainvoke("Hello"))
    _ = asyncio.run(llm.ainvoke("Hello"))


@@ -44,6 +44,22 @@ def test_initialization() -> None:
        assert model.anthropic_api_url == "https://api.anthropic.com"


def test_anthropic_client_caching() -> None:
    """Test that the underlying httpx client is cached."""
    llm1 = ChatAnthropic(model="claude-3-5-sonnet-latest")
    llm2 = ChatAnthropic(model="claude-3-5-sonnet-latest")
    assert llm1._client._client is llm2._client._client

    llm3 = ChatAnthropic(model="claude-3-5-sonnet-latest", base_url="foo")
    assert llm1._client._client is not llm3._client._client

    llm4 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=None)
    assert llm1._client._client is llm4._client._client

    llm5 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=3)
    assert llm1._client._client is not llm5._client._client


@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
    llm = ChatAnthropic(model_name="foo")  # type: ignore[call-arg, call-arg]