mirror of https://github.com/hwchase17/langchain.git
anthropic[patch]: cache clients (#31659)
parent e3f1ce0ac5
commit b02bd67788

libs/partners/anthropic/langchain_anthropic/_client_utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Helpers for creating Anthropic API clients.
|
||||
|
||||
This module allows for the caching of httpx clients to avoid creating new instances
|
||||
for each instance of ChatAnthropic.
|
||||
|
||||
Logic is largely replicated from anthropic._base_client.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Any, Optional
|
||||
|
||||
import anthropic
|
||||
|
||||
_NOT_GIVEN: Any = object()
|
||||
|
||||
|
||||
class _SyncHttpxClientWrapper(anthropic.DefaultHttpxClient):
|
||||
"""Borrowed from anthropic._base_client"""
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class _AsyncHttpxClientWrapper(anthropic.DefaultAsyncHttpxClient):
|
||||
"""Borrowed from anthropic._base_client"""
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# TODO(someday): support non asyncio runtimes here
|
||||
asyncio.get_running_loop().create_task(self.aclose())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_default_httpx_client(
|
||||
*,
|
||||
base_url: Optional[str],
|
||||
timeout: Any = _NOT_GIVEN,
|
||||
) -> _SyncHttpxClientWrapper:
|
||||
kwargs: dict[str, Any] = {
|
||||
"base_url": base_url
|
||||
or os.environ.get("ANTHROPIC_BASE_URL")
|
||||
or "https://api.anthropic.com",
|
||||
}
|
||||
if timeout is not _NOT_GIVEN:
|
||||
kwargs["timeout"] = timeout
|
||||
return _SyncHttpxClientWrapper(**kwargs)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_default_async_httpx_client(
|
||||
*,
|
||||
base_url: Optional[str],
|
||||
timeout: Any = _NOT_GIVEN,
|
||||
) -> _AsyncHttpxClientWrapper:
|
||||
kwargs: dict[str, Any] = {
|
||||
"base_url": base_url
|
||||
or os.environ.get("ANTHROPIC_BASE_URL")
|
||||
or "https://api.anthropic.com",
|
||||
}
|
||||
if timeout is not _NOT_GIVEN:
|
||||
kwargs["timeout"] = timeout
|
||||
return _AsyncHttpxClientWrapper(**kwargs)
|
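Because both factories are decorated with @lru_cache, repeated calls with the same keyword arguments return the very same wrapper object, while a different base_url/timeout combination produces a new one. A minimal sketch of that behaviour (the example URL is arbitrary; no network calls are made when constructing httpx clients):

from langchain_anthropic._client_utils import (
    _get_default_async_httpx_client,
    _get_default_httpx_client,
)

# Same keyword arguments -> the lru_cache entry is reused, so the same object comes back.
sync_a = _get_default_httpx_client(base_url=None)
sync_b = _get_default_httpx_client(base_url=None)
assert sync_a is sync_b

# A different base_url is a different cache key -> a new wrapper is built.
sync_c = _get_default_httpx_client(base_url="https://example.invalid")
assert sync_a is not sync_c

# The sync and async factories keep separate caches.
async_a = _get_default_async_httpx_client(base_url=None)
assert async_a is not sync_a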

libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -69,6 +69,10 @@ from pydantic import (
 )
 from typing_extensions import NotRequired, TypedDict
 
+from langchain_anthropic._client_utils import (
+    _get_default_async_httpx_client,
+    _get_default_httpx_client,
+)
 from langchain_anthropic.output_parsers import extract_tool_calls
 
 _message_type_lookups = {
@@ -1300,11 +1304,29 @@ class ChatAnthropic(BaseChatModel):
 
     @cached_property
     def _client(self) -> anthropic.Client:
-        return anthropic.Client(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.Client(**params)
 
     @cached_property
     def _async_client(self) -> anthropic.AsyncClient:
-        return anthropic.AsyncClient(**self._client_params)
+        client_params = self._client_params
+        http_client_params = {"base_url": client_params["base_url"]}
+        if "timeout" in client_params:
+            http_client_params["timeout"] = client_params["timeout"]
+        http_client = _get_default_async_httpx_client(**http_client_params)
+        params = {
+            **client_params,
+            "http_client": http_client,
+        }
+        return anthropic.AsyncClient(**params)
 
     def _get_request_payload(
         self,
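Only base_url and timeout are forwarded to the cached httpx factories, so two ChatAnthropic instances that agree on those values share a single transport even though each still builds its own anthropic.Client. A rough sketch of that relationship, assuming ANTHROPIC_API_KEY is set and using the same model name as the tests in this commit (it peeks at the SDK's private _client attribute, as the unit test below does):

from langchain_anthropic import ChatAnthropic

llm_a = ChatAnthropic(model="claude-3-5-haiku-latest")
llm_b = ChatAnthropic(model="claude-3-5-haiku-latest")

# Each instance lazily creates its own anthropic.Client via @cached_property ...
assert llm_a._client is not llm_b._client

# ... but both wrap the same cached httpx client, so the connection pool is shared.
assert llm_a._client._client is llm_b._client._client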

libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -1,5 +1,6 @@
 """Test ChatAnthropic chat model."""
 
+import asyncio
 import json
 import os
 from base64 import b64encode

@@ -1082,3 +1083,10 @@ def test_files_api_pdf(block_format: str) -> None:
         ],
     }
     _ = llm.invoke([input_message])
+
+
+def test_async_shared_client() -> None:
+    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
+    llm._async_client  # Instantiates lazily
+    _ = asyncio.run(llm.ainvoke("Hello"))
+    _ = asyncio.run(llm.ainvoke("Hello"))
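The async path mirrors the sync one: _async_client is rebuilt per ChatAnthropic instance, while the _AsyncHttpxClientWrapper underneath comes from the process-wide cache, which is what the test above exercises by calling asyncio.run twice against one model object. A small sketch under the same assumptions as before (valid ANTHROPIC_API_KEY; it inspects the SDK's private _client attribute the way the sync unit test does):

from langchain_anthropic import ChatAnthropic

llm1 = ChatAnthropic(model="claude-3-5-haiku-latest")
llm2 = ChatAnthropic(model="claude-3-5-haiku-latest")

# Separate anthropic.AsyncClient objects, one shared cached async httpx transport.
assert llm1._async_client is not llm2._async_client
assert llm1._async_client._client is llm2._async_client._client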

libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -44,6 +44,22 @@ def test_initialization() -> None:
     assert model.anthropic_api_url == "https://api.anthropic.com"
 
 
+def test_anthropic_client_caching() -> None:
+    """Test that the underlying httpx client is cached."""
+    llm1 = ChatAnthropic(model="claude-3-5-sonnet-latest")
+    llm2 = ChatAnthropic(model="claude-3-5-sonnet-latest")
+    assert llm1._client._client is llm2._client._client
+
+    llm3 = ChatAnthropic(model="claude-3-5-sonnet-latest", base_url="foo")
+    assert llm1._client._client is not llm3._client._client
+
+    llm4 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=None)
+    assert llm1._client._client is llm4._client._client
+
+    llm5 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=3)
+    assert llm1._client._client is not llm5._client._client
+
+
 @pytest.mark.requires("anthropic")
 def test_anthropic_model_name_param() -> None:
     llm = ChatAnthropic(model_name="foo")  # type: ignore[call-arg, call-arg]