anthropic: less pydantic for client (#28823)

This commit is contained in:
Erick Friis 2024-12-19 08:00:02 -08:00 committed by GitHub
parent f1d783748a
commit ff7b01af88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
import copy
import re
import warnings
from functools import cached_property
from operator import itemgetter
from typing import (
Any,
@ -68,11 +69,10 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import NotRequired, Self
from typing_extensions import NotRequired
from langchain_anthropic.output_parsers import extract_tool_calls
@ -541,9 +541,6 @@ class ChatAnthropic(BaseChatModel):
populate_by_name=True,
)
_client: anthropic.Client = PrivateAttr(default=None) # type: ignore[assignment]
_async_client: anthropic.AsyncClient = PrivateAttr(default=None) # type: ignore[assignment]
model: str = Field(alias="model_name")
"""Model name to use."""
@ -661,13 +658,11 @@ class ChatAnthropic(BaseChatModel):
values = _build_model_kwargs(values, all_required_field_names)
return values
@model_validator(mode="after")
def post_init(self) -> Self:
api_key = self.anthropic_api_key.get_secret_value()
api_url = self.anthropic_api_url
@cached_property
def _client_params(self) -> Dict[str, Any]:
client_params: Dict[str, Any] = {
"api_key": api_key,
"base_url": api_url,
"api_key": self.anthropic_api_key.get_secret_value(),
"base_url": self.anthropic_api_url,
"max_retries": self.max_retries,
"default_headers": (self.default_headers or None),
}
@ -677,9 +672,15 @@ class ChatAnthropic(BaseChatModel):
if self.default_request_timeout is None or self.default_request_timeout > 0:
client_params["timeout"] = self.default_request_timeout
self._client = anthropic.Client(**client_params)
self._async_client = anthropic.AsyncClient(**client_params)
return self
return client_params
@cached_property
def _client(self) -> anthropic.Client:
return anthropic.Client(**self._client_params)
@cached_property
def _async_client(self) -> anthropic.AsyncClient:
return anthropic.AsyncClient(**self._client_params)
def _get_request_payload(
self,