diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 2d580d08ccd..1a82d3d9893 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -8,10 +8,12 @@ import logging
 import os
 import sys
 import warnings
+from functools import cached_property
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -91,10 +93,20 @@ from langchain_core.utils.pydantic import (
     is_basemodel_subclass,
 )
 from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
-from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    model_validator,
+)
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self
 
+if TYPE_CHECKING:
+    import httpx
+
 logger = logging.getLogger(__name__)
@@ -491,6 +503,7 @@ class BaseChatOpenAI(BaseChatModel):
     However this does not prevent a user from directly passed in the parameter
     during invocation.
     """
+    _client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)
 
     model_config = ConfigDict(populate_by_name=True)
@@ -526,7 +539,7 @@ class BaseChatOpenAI(BaseChatModel):
             or os.getenv("OPENAI_ORGANIZATION")
         )
         self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
-        client_params: dict = {
+        self._client_params: dict = {
             "api_key": (
                 self.openai_api_key.get_secret_value() if self.openai_api_key else None
             ),
@@ -537,7 +550,7 @@ class BaseChatOpenAI(BaseChatModel):
             "default_query": self.default_query,
         }
         if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
+            self._client_params["max_retries"] = self.max_retries
 
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
@@ -548,37 +561,81 @@ class BaseChatOpenAI(BaseChatModel):
                 "'http_client'/'http_async_client' is already specified. Received:\n"
                 f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
             )
-        if not self.client:
-            if self.openai_proxy and not self.http_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    raise ImportError(
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
-                    ) from e
-                self.http_client = httpx.Client(proxy=self.openai_proxy)
-            sync_specific = {"http_client": self.http_client}
-            self.root_client = openai.OpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
-            self.client = self.root_client.chat.completions
-        if not self.async_client:
-            if self.openai_proxy and not self.http_async_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    raise ImportError(
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
-                    ) from e
-                self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
-            async_specific = {"http_client": self.http_async_client}
-            self.root_async_client = openai.AsyncOpenAI(
-                **client_params,
-                **async_specific,  # type: ignore[arg-type]
-            )
-            self.async_client = self.root_async_client.chat.completions
         return self
 
+    @cached_property
+    def _http_client(self) -> Optional[httpx.Client]:
+        """Optional httpx.Client. Only used for sync invocations.
+
+        Must specify http_async_client as well if you'd like a custom client for
+        async invocations.
+        """
+        # Configure a custom httpx client. See the
+        # [httpx documentation](https://www.python-httpx.org/api/#client) for more
+        # details.
+        if self.http_client is not None:
+            return self.http_client
+        if not self.openai_proxy:
+            return None
+        try:
+            import httpx
+        except ImportError as e:
+            raise ImportError(
+                "Could not import httpx python package. "
+                "Please install it with `pip install httpx`."
+            ) from e
+        return httpx.Client(proxy=self.openai_proxy)
+
+    @cached_property
+    def _http_async_client(self) -> Optional[httpx.AsyncClient]:
+        """Optional httpx.AsyncClient. Only used for async invocations.
+
+        Must specify http_client as well if you'd like a custom client for sync
+        invocations.
+        """
+        if self.http_async_client is not None:
+            return self.http_async_client
+        if not self.openai_proxy:
+            return None
+        try:
+            import httpx
+        except ImportError as e:
+            raise ImportError(
+                "Could not import httpx python package. "
+                "Please install it with `pip install httpx`."
+            ) from e
+        return httpx.AsyncClient(proxy=self.openai_proxy)
+
+    @cached_property
+    def _root_client(self) -> openai.OpenAI:
+        if self.root_client is not None:
+            return self.root_client
+        sync_specific = {"http_client": self._http_client}
+        return openai.OpenAI(**self._client_params, **sync_specific)  # type: ignore[arg-type]
+
+    @cached_property
+    def _root_async_client(self) -> openai.AsyncOpenAI:
+        if self.root_async_client is not None:
+            return self.root_async_client
+        async_specific = {"http_client": self._http_async_client}
+        return openai.AsyncOpenAI(
+            **self._client_params,
+            **async_specific,  # type: ignore[arg-type]
+        )
+
+    @cached_property
+    def _client(self) -> Any:
+        if self.client is not None:
+            return self.client
+        return self._root_client.chat.completions
+
+    @cached_property
+    def _async_client(self) -> Any:
+        if self.async_client is not None:
+            return self.async_client
+        return self._root_async_client.chat.completions
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
@@ -704,15 +761,15 @@ class BaseChatOpenAI(BaseChatModel):
                     "specified."
                 )
             payload.pop("stream")
-            response_stream = self.root_client.beta.chat.completions.stream(**payload)
+            response_stream = self._root_client.beta.chat.completions.stream(**payload)
             context_manager = response_stream
         else:
             if self.include_response_headers:
-                raw_response = self.client.with_raw_response.create(**payload)
+                raw_response = self._client.with_raw_response.create(**payload)
                 response = raw_response.parse()
                 base_generation_info = {"headers": dict(raw_response.headers)}
             else:
-                response = self.client.create(**payload)
+                response = self._client.create(**payload)
             context_manager = response
         try:
             with context_manager as response:
@@ -772,15 +829,15 @@ class BaseChatOpenAI(BaseChatModel):
             )
             payload.pop("stream")
             try:
-                response = self.root_client.beta.chat.completions.parse(**payload)
+                response = self._root_client.beta.chat.completions.parse(**payload)
             except openai.BadRequestError as e:
                 _handle_openai_bad_request(e)
         elif self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
+            raw_response = self._client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
-            response = self.client.create(**payload)
+            response = self._client.create(**payload)
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
@@ -868,19 +925,19 @@ class BaseChatOpenAI(BaseChatModel):
                     "specified."
                 )
             payload.pop("stream")
-            response_stream = self.root_async_client.beta.chat.completions.stream(
+            response_stream = self._root_async_client.beta.chat.completions.stream(
                 **payload
             )
             context_manager = response_stream
         else:
             if self.include_response_headers:
-                raw_response = await self.async_client.with_raw_response.create(
+                raw_response = await self._async_client.with_raw_response.create(
                     **payload
                 )
                 response = raw_response.parse()
                 base_generation_info = {"headers": dict(raw_response.headers)}
             else:
-                response = await self.async_client.create(**payload)
+                response = await self._async_client.create(**payload)
             context_manager = response
         try:
             async with context_manager as response:
@@ -940,17 +997,17 @@ class BaseChatOpenAI(BaseChatModel):
             )
             payload.pop("stream")
             try:
-                response = await self.root_async_client.beta.chat.completions.parse(
+                response = await self._root_async_client.beta.chat.completions.parse(
                     **payload
                 )
             except openai.BadRequestError as e:
                 _handle_openai_bad_request(e)
         elif self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
+            raw_response = await self._async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
-            response = await self.async_client.create(**payload)
+            response = await self._async_client.create(**payload)
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
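
The core change above moves client construction out of `validate_environment` and behind `functools.cached_property` accessors: the validator now only records connection kwargs in the `_client_params` private attribute, and the `openai.OpenAI`/`openai.AsyncOpenAI` clients (plus any proxy `httpx` clients) are built once, on first use, with user-supplied clients still taking precedence. A minimal sketch of that pattern, using a hypothetical `LazyClientModel` rather than `BaseChatOpenAI` itself:

```python
# Sketch of the lazy-client pattern this diff introduces. `LazyClientModel`
# and its fields are hypothetical stand-ins that only mirror the shape of
# BaseChatOpenAI; this is not LangChain's actual code.
from functools import cached_property
from typing import Any, Dict, Optional

import openai
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
from typing_extensions import Self


class LazyClientModel(BaseModel):
    # Let pydantic accept the non-pydantic openai.OpenAI field type below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    openai_api_key: Optional[str] = None
    openai_api_base: Optional[str] = None
    # A user-supplied client always wins over lazy construction.
    root_client: Optional[openai.OpenAI] = None
    # Private storage for connection kwargs, as in the diff.
    _client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        # Cheap: only record the params; no SDK client is built here.
        self._client_params = {
            "api_key": self.openai_api_key,
            "base_url": self.openai_api_base,
        }
        return self

    @cached_property
    def _root_client(self) -> openai.OpenAI:
        # Runs once, on first access; the result is cached in the instance
        # __dict__, so later accesses skip construction entirely.
        if self.root_client is not None:
            return self.root_client
        return openai.OpenAI(**self._client_params)
```

With this shape, `LazyClientModel(openai_api_key="...")` performs no client setup at construction time; the first read of `_root_client` pays that cost once and caches the result. That is why the hot paths in the diff (`_stream`, `_generate`, and their async counterparts) switch from the public `self.client`/`self.root_client` attributes to the `self._client`/`self._root_client` properties.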