mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
update base class
This commit is contained in:
parent
437fe6d216
commit
c289fc9ba9
@ -8,10 +8,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from functools import cached_property
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
@ -91,10 +93,20 @@ from langchain_core.utils.pydantic import (
|
|||||||
is_basemodel_subclass,
|
is_basemodel_subclass,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
|
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
PrivateAttr,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -491,6 +503,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
However this does not prevent a user from directly passed in the parameter during
|
However this does not prevent a user from directly passed in the parameter during
|
||||||
invocation.
|
invocation.
|
||||||
"""
|
"""
|
||||||
|
_client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
@ -526,7 +539,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
or os.getenv("OPENAI_ORGANIZATION")
|
or os.getenv("OPENAI_ORGANIZATION")
|
||||||
)
|
)
|
||||||
self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
|
self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
|
||||||
client_params: dict = {
|
self._client_params: dict = {
|
||||||
"api_key": (
|
"api_key": (
|
||||||
self.openai_api_key.get_secret_value() if self.openai_api_key else None
|
self.openai_api_key.get_secret_value() if self.openai_api_key else None
|
||||||
),
|
),
|
||||||
@ -537,7 +550,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"default_query": self.default_query,
|
"default_query": self.default_query,
|
||||||
}
|
}
|
||||||
if self.max_retries is not None:
|
if self.max_retries is not None:
|
||||||
client_params["max_retries"] = self.max_retries
|
self._client_params["max_retries"] = self.max_retries
|
||||||
|
|
||||||
if self.openai_proxy and (self.http_client or self.http_async_client):
|
if self.openai_proxy and (self.http_client or self.http_async_client):
|
||||||
openai_proxy = self.openai_proxy
|
openai_proxy = self.openai_proxy
|
||||||
@ -548,37 +561,81 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"'http_client'/'http_async_client' is already specified. Received:\n"
|
"'http_client'/'http_async_client' is already specified. Received:\n"
|
||||||
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
|
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
|
||||||
)
|
)
|
||||||
if not self.client:
|
|
||||||
if self.openai_proxy and not self.http_client:
|
|
||||||
try:
|
|
||||||
import httpx
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import httpx python package. "
|
|
||||||
"Please install it with `pip install httpx`."
|
|
||||||
) from e
|
|
||||||
self.http_client = httpx.Client(proxy=self.openai_proxy)
|
|
||||||
sync_specific = {"http_client": self.http_client}
|
|
||||||
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
|
|
||||||
self.client = self.root_client.chat.completions
|
|
||||||
if not self.async_client:
|
|
||||||
if self.openai_proxy and not self.http_async_client:
|
|
||||||
try:
|
|
||||||
import httpx
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import httpx python package. "
|
|
||||||
"Please install it with `pip install httpx`."
|
|
||||||
) from e
|
|
||||||
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
|
|
||||||
async_specific = {"http_client": self.http_async_client}
|
|
||||||
self.root_async_client = openai.AsyncOpenAI(
|
|
||||||
**client_params,
|
|
||||||
**async_specific, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
self.async_client = self.root_async_client.chat.completions
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _http_client(self) -> Optional[httpx.Client]:
|
||||||
|
"""Optional httpx.Client. Only used for sync invocations.
|
||||||
|
|
||||||
|
Must specify http_async_client as well if you'd like a custom client for
|
||||||
|
async invocations.
|
||||||
|
"""
|
||||||
|
# Configure a custom httpx client. See the
|
||||||
|
# [httpx documentation](https://www.python-httpx.org/api/#client) for more
|
||||||
|
# details.
|
||||||
|
if self.http_client is not None:
|
||||||
|
return self.http_client
|
||||||
|
if not self.openai_proxy:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import httpx python package. "
|
||||||
|
"Please install it with `pip install httpx`."
|
||||||
|
) from e
|
||||||
|
return httpx.Client(proxy=self.openai_proxy)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _http_async_client(self) -> Optional[httpx.AsyncClient]:
|
||||||
|
"""Optional httpx.AsyncClient. Only used for async invocations.
|
||||||
|
|
||||||
|
Must specify http_client as well if you'd like a custom client for sync
|
||||||
|
invocations.
|
||||||
|
"""
|
||||||
|
if self.http_async_client is not None:
|
||||||
|
return self.http_async_client
|
||||||
|
if not self.openai_proxy:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import httpx python package. "
|
||||||
|
"Please install it with `pip install httpx`."
|
||||||
|
) from e
|
||||||
|
return httpx.AsyncClient(proxy=self.openai_proxy)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _root_client(self) -> openai.OpenAI:
|
||||||
|
if self.root_client is not None:
|
||||||
|
return self.root_client
|
||||||
|
sync_specific = {"http_client": self._http_client}
|
||||||
|
return openai.OpenAI(**self._client_params, **sync_specific) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _root_async_client(self) -> openai.AsyncOpenAI:
|
||||||
|
if self.root_async_client is not None:
|
||||||
|
return self.root_async_client
|
||||||
|
async_specific = {"http_client": self._http_async_client}
|
||||||
|
return openai.AsyncOpenAI(
|
||||||
|
**self._client_params,
|
||||||
|
**async_specific, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _client(self) -> Any:
|
||||||
|
if self.client is not None:
|
||||||
|
return self.client
|
||||||
|
return self._root_client.chat.completions
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _async_client(self) -> Any:
|
||||||
|
if self.async_client is not None:
|
||||||
|
return self.async_client
|
||||||
|
return self._root_async_client.chat.completions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling OpenAI API."""
|
"""Get the default parameters for calling OpenAI API."""
|
||||||
@ -704,15 +761,15 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"specified."
|
"specified."
|
||||||
)
|
)
|
||||||
payload.pop("stream")
|
payload.pop("stream")
|
||||||
response_stream = self.root_client.beta.chat.completions.stream(**payload)
|
response_stream = self._root_client.beta.chat.completions.stream(**payload)
|
||||||
context_manager = response_stream
|
context_manager = response_stream
|
||||||
else:
|
else:
|
||||||
if self.include_response_headers:
|
if self.include_response_headers:
|
||||||
raw_response = self.client.with_raw_response.create(**payload)
|
raw_response = self._client.with_raw_response.create(**payload)
|
||||||
response = raw_response.parse()
|
response = raw_response.parse()
|
||||||
base_generation_info = {"headers": dict(raw_response.headers)}
|
base_generation_info = {"headers": dict(raw_response.headers)}
|
||||||
else:
|
else:
|
||||||
response = self.client.create(**payload)
|
response = self._client.create(**payload)
|
||||||
context_manager = response
|
context_manager = response
|
||||||
try:
|
try:
|
||||||
with context_manager as response:
|
with context_manager as response:
|
||||||
@ -772,15 +829,15 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
payload.pop("stream")
|
payload.pop("stream")
|
||||||
try:
|
try:
|
||||||
response = self.root_client.beta.chat.completions.parse(**payload)
|
response = self._root_client.beta.chat.completions.parse(**payload)
|
||||||
except openai.BadRequestError as e:
|
except openai.BadRequestError as e:
|
||||||
_handle_openai_bad_request(e)
|
_handle_openai_bad_request(e)
|
||||||
elif self.include_response_headers:
|
elif self.include_response_headers:
|
||||||
raw_response = self.client.with_raw_response.create(**payload)
|
raw_response = self._client.with_raw_response.create(**payload)
|
||||||
response = raw_response.parse()
|
response = raw_response.parse()
|
||||||
generation_info = {"headers": dict(raw_response.headers)}
|
generation_info = {"headers": dict(raw_response.headers)}
|
||||||
else:
|
else:
|
||||||
response = self.client.create(**payload)
|
response = self._client.create(**payload)
|
||||||
return self._create_chat_result(response, generation_info)
|
return self._create_chat_result(response, generation_info)
|
||||||
|
|
||||||
def _get_request_payload(
|
def _get_request_payload(
|
||||||
@ -868,19 +925,19 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"specified."
|
"specified."
|
||||||
)
|
)
|
||||||
payload.pop("stream")
|
payload.pop("stream")
|
||||||
response_stream = self.root_async_client.beta.chat.completions.stream(
|
response_stream = self._root_async_client.beta.chat.completions.stream(
|
||||||
**payload
|
**payload
|
||||||
)
|
)
|
||||||
context_manager = response_stream
|
context_manager = response_stream
|
||||||
else:
|
else:
|
||||||
if self.include_response_headers:
|
if self.include_response_headers:
|
||||||
raw_response = await self.async_client.with_raw_response.create(
|
raw_response = await self._async_client.with_raw_response.create(
|
||||||
**payload
|
**payload
|
||||||
)
|
)
|
||||||
response = raw_response.parse()
|
response = raw_response.parse()
|
||||||
base_generation_info = {"headers": dict(raw_response.headers)}
|
base_generation_info = {"headers": dict(raw_response.headers)}
|
||||||
else:
|
else:
|
||||||
response = await self.async_client.create(**payload)
|
response = await self._async_client.create(**payload)
|
||||||
context_manager = response
|
context_manager = response
|
||||||
try:
|
try:
|
||||||
async with context_manager as response:
|
async with context_manager as response:
|
||||||
@ -940,17 +997,17 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
payload.pop("stream")
|
payload.pop("stream")
|
||||||
try:
|
try:
|
||||||
response = await self.root_async_client.beta.chat.completions.parse(
|
response = await self._root_async_client.beta.chat.completions.parse(
|
||||||
**payload
|
**payload
|
||||||
)
|
)
|
||||||
except openai.BadRequestError as e:
|
except openai.BadRequestError as e:
|
||||||
_handle_openai_bad_request(e)
|
_handle_openai_bad_request(e)
|
||||||
elif self.include_response_headers:
|
elif self.include_response_headers:
|
||||||
raw_response = await self.async_client.with_raw_response.create(**payload)
|
raw_response = await self._async_client.with_raw_response.create(**payload)
|
||||||
response = raw_response.parse()
|
response = raw_response.parse()
|
||||||
generation_info = {"headers": dict(raw_response.headers)}
|
generation_info = {"headers": dict(raw_response.headers)}
|
||||||
else:
|
else:
|
||||||
response = await self.async_client.create(**payload)
|
response = await self._async_client.create(**payload)
|
||||||
return await run_in_executor(
|
return await run_in_executor(
|
||||||
None, self._create_chat_result, response, generation_info
|
None, self._create_chat_result, response, generation_info
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user