update base class

This commit is contained in:
Chester Curme 2025-02-21 15:24:49 -05:00
parent 437fe6d216
commit c289fc9ba9

View File

@ -8,10 +8,12 @@ import logging
import os import os
import sys import sys
import warnings import warnings
from functools import cached_property
from io import BytesIO from io import BytesIO
from math import ceil from math import ceil
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
AsyncIterator, AsyncIterator,
Callable, Callable,
@ -91,10 +93,20 @@ from langchain_core.utils.pydantic import (
is_basemodel_subclass, is_basemodel_subclass,
) )
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self from typing_extensions import Self
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -491,6 +503,7 @@ class BaseChatOpenAI(BaseChatModel):
However this does not prevent a user from directly passed in the parameter during However this does not prevent a user from directly passed in the parameter during
invocation. invocation.
""" """
_client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@ -526,7 +539,7 @@ class BaseChatOpenAI(BaseChatModel):
or os.getenv("OPENAI_ORGANIZATION") or os.getenv("OPENAI_ORGANIZATION")
) )
self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE") self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
client_params: dict = { self._client_params: dict = {
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None self.openai_api_key.get_secret_value() if self.openai_api_key else None
), ),
@ -537,7 +550,7 @@ class BaseChatOpenAI(BaseChatModel):
"default_query": self.default_query, "default_query": self.default_query,
} }
if self.max_retries is not None: if self.max_retries is not None:
client_params["max_retries"] = self.max_retries self._client_params["max_retries"] = self.max_retries
if self.openai_proxy and (self.http_client or self.http_async_client): if self.openai_proxy and (self.http_client or self.http_async_client):
openai_proxy = self.openai_proxy openai_proxy = self.openai_proxy
@ -548,37 +561,81 @@ class BaseChatOpenAI(BaseChatModel):
"'http_client'/'http_async_client' is already specified. Received:\n" "'http_client'/'http_async_client' is already specified. Received:\n"
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}" f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
) )
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
if self.openai_proxy and not self.http_async_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
)
self.async_client = self.root_async_client.chat.completions
return self return self
@cached_property
def _http_client(self) -> Optional[httpx.Client]:
"""Optional httpx.Client. Only used for sync invocations.
Must specify http_async_client as well if you'd like a custom client for
async invocations.
"""
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more
# details.
if self.http_client is not None:
return self.http_client
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
return httpx.Client(proxy=self.openai_proxy)
@cached_property
def _http_async_client(self) -> Optional[httpx.AsyncClient]:
"""Optional httpx.AsyncClient. Only used for async invocations.
Must specify http_client as well if you'd like a custom client for sync
invocations.
"""
if self.http_async_client is not None:
return self.http_async_client
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
return httpx.AsyncClient(proxy=self.openai_proxy)
@cached_property
def _root_client(self) -> openai.OpenAI:
if self.root_client is not None:
return self.root_client
sync_specific = {"http_client": self._http_client}
return openai.OpenAI(**self._client_params, **sync_specific) # type: ignore[arg-type]
@cached_property
def _root_async_client(self) -> openai.AsyncOpenAI:
if self.root_async_client is not None:
return self.root_async_client
async_specific = {"http_client": self._http_async_client}
return openai.AsyncOpenAI(
**self._client_params,
**async_specific, # type: ignore[arg-type]
)
@cached_property
def _client(self) -> Any:
if self.client is not None:
return self.client
return self._root_client.chat.completions
@cached_property
def _async_client(self) -> Any:
if self.async_client is not None:
return self.async_client
return self._root_async_client.chat.completions
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
@ -704,15 +761,15 @@ class BaseChatOpenAI(BaseChatModel):
"specified." "specified."
) )
payload.pop("stream") payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(**payload) response_stream = self._root_client.beta.chat.completions.stream(**payload)
context_manager = response_stream context_manager = response_stream
else: else:
if self.include_response_headers: if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload) raw_response = self._client.with_raw_response.create(**payload)
response = raw_response.parse() response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)} base_generation_info = {"headers": dict(raw_response.headers)}
else: else:
response = self.client.create(**payload) response = self._client.create(**payload)
context_manager = response context_manager = response
try: try:
with context_manager as response: with context_manager as response:
@ -772,15 +829,15 @@ class BaseChatOpenAI(BaseChatModel):
) )
payload.pop("stream") payload.pop("stream")
try: try:
response = self.root_client.beta.chat.completions.parse(**payload) response = self._root_client.beta.chat.completions.parse(**payload)
except openai.BadRequestError as e: except openai.BadRequestError as e:
_handle_openai_bad_request(e) _handle_openai_bad_request(e)
elif self.include_response_headers: elif self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload) raw_response = self._client.with_raw_response.create(**payload)
response = raw_response.parse() response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)} generation_info = {"headers": dict(raw_response.headers)}
else: else:
response = self.client.create(**payload) response = self._client.create(**payload)
return self._create_chat_result(response, generation_info) return self._create_chat_result(response, generation_info)
def _get_request_payload( def _get_request_payload(
@ -868,19 +925,19 @@ class BaseChatOpenAI(BaseChatModel):
"specified." "specified."
) )
payload.pop("stream") payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream( response_stream = self._root_async_client.beta.chat.completions.stream(
**payload **payload
) )
context_manager = response_stream context_manager = response_stream
else: else:
if self.include_response_headers: if self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create( raw_response = await self._async_client.with_raw_response.create(
**payload **payload
) )
response = raw_response.parse() response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)} base_generation_info = {"headers": dict(raw_response.headers)}
else: else:
response = await self.async_client.create(**payload) response = await self._async_client.create(**payload)
context_manager = response context_manager = response
try: try:
async with context_manager as response: async with context_manager as response:
@ -940,17 +997,17 @@ class BaseChatOpenAI(BaseChatModel):
) )
payload.pop("stream") payload.pop("stream")
try: try:
response = await self.root_async_client.beta.chat.completions.parse( response = await self._root_async_client.beta.chat.completions.parse(
**payload **payload
) )
except openai.BadRequestError as e: except openai.BadRequestError as e:
_handle_openai_bad_request(e) _handle_openai_bad_request(e)
elif self.include_response_headers: elif self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(**payload) raw_response = await self._async_client.with_raw_response.create(**payload)
response = raw_response.parse() response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)} generation_info = {"headers": dict(raw_response.headers)}
else: else:
response = await self.async_client.create(**payload) response = await self._async_client.create(**payload)
return await run_in_executor( return await run_in_executor(
None, self._create_chat_result, response, generation_info None, self._create_chat_result, response, generation_info
) )