openai[major]: switch to pydantic v2

Bagatur committed 2024-09-03 16:33:35 -07:00
parent 9a9ab65030
commit 615f8b0d47
13 changed files with 228 additions and 225 deletions
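The migration applies one pattern over and over, so it is worth sketching once. Every `@root_validator(pre=False, skip_on_failure=True)` that took `cls` and a `values` dict becomes a pydantic v2 `@model_validator(mode="after")` that runs on the constructed instance and returns `Self`. A minimal sketch of the pattern (the model and its fields here are illustrative, not taken from this commit):

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class ExampleModel(BaseModel):
    n: int = 1
    streaming: bool = False

    # pydantic v1 spelling:
    #     @root_validator(pre=False, skip_on_failure=True)
    #     def validate_environment(cls, values: Dict) -> Dict: ...
    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        # The validator now receives a constructed instance, so fields are
        # read as attributes instead of values["..."] lookups.
        if self.n < 1:
            raise ValueError("n must be at least 1.")
        if self.n > 1 and self.streaming:
            raise ValueError("n must be 1 when streaming.")
        return self

The `skip_on_failure=True` flag has no v2 counterpart because "after" validators only run once field validation has already succeeded.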

View File

@@ -31,7 +31,7 @@ from langchain_core.output_parsers.openai_tools import (
     PydanticToolsParser,
 )
 from langchain_core.outputs import ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from pydantic import BaseModel, Field, SecretStr, model_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils import from_env, secret_from_env
@@ -39,6 +39,8 @@ from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import is_basemodel_subclass

 from langchain_openai.chat_models.base import BaseChatOpenAI
+from typing_extensions import Self
+
 logger = logging.getLogger(__name__)
@@ -494,7 +496,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
         default_factory=from_env("OPENAI_API_VERSION", default=None),
     )
     """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
-    # Check OPENAI_KEY for backwards compatibility.
+    # Check OPENAI_API_KEY for backwards compatibility.
     # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
     # other forms of azure credentials.
     openai_api_key: Optional[SecretStr] = Field(
@@ -565,31 +567,31 @@ class AzureChatOpenAI(BaseChatOpenAI):
     def is_lc_serializable(cls) -> bool:
         return True

-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         # Check OPENAI_ORGANIZATION for backwards compatibility.
-        values["openai_organization"] = (
-            values["openai_organization"]
+        self.openai_organization = (
+            self.openai_organization
             or os.getenv("OPENAI_ORG_ID")
             or os.getenv("OPENAI_ORGANIZATION")
         )
         # For backwards compatibility. Before openai v1, no distinction was made
         # between azure_endpoint and base_url (openai_api_base).
-        openai_api_base = values["openai_api_base"]
-        if openai_api_base and values["validate_base_url"]:
+        openai_api_base = self.openai_api_base
+        if openai_api_base and self.validate_base_url:
             if "/openai" not in openai_api_base:
                 raise ValueError(
                     "As of openai>=1.0.0, Azure endpoints should be specified via "
                     "the `azure_endpoint` param not `openai_api_base` "
                     "(or alias `base_url`)."
                 )
-            if values["deployment_name"]:
+            if self.deployment_name:
                 raise ValueError(
                     "As of openai>=1.0.0, if `azure_deployment` (or alias "
                     "`deployment_name`) is specified then "
@@ -603,38 +605,38 @@ class AzureChatOpenAI(BaseChatOpenAI):
                 'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
             )
         client_params = {
-            "api_version": values["openai_api_version"],
-            "azure_endpoint": values["azure_endpoint"],
-            "azure_deployment": values["deployment_name"],
+            "api_version": self.openai_api_version,
+            "azure_endpoint": self.azure_endpoint,
+            "azure_deployment": self.deployment_name,
             "api_key": (
-                values["openai_api_key"].get_secret_value()
-                if values["openai_api_key"]
+                self.openai_api_key.get_secret_value()
+                if self.openai_api_key
                 else None
             ),
             "azure_ad_token": (
-                values["azure_ad_token"].get_secret_value()
-                if values["azure_ad_token"]
+                self.azure_ad_token.get_secret_value()
+                if self.azure_ad_token
                 else None
             ),
-            "azure_ad_token_provider": values["azure_ad_token_provider"],
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "azure_ad_token_provider": self.azure_ad_token_provider,
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
-            values["root_client"] = openai.AzureOpenAI(**client_params, **sync_specific)
-            values["client"] = values["root_client"].chat.completions
-        if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
-            values["root_async_client"] = openai.AsyncAzureOpenAI(
+        if not (self.client or None):
+            sync_specific = {"http_client": self.http_client}
+            self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)
+            self.client = self.root_client.chat.completions
+        if not (self.async_client or None):
+            async_specific = {"http_client": self.http_async_client}
+            self.root_async_client = openai.AsyncAzureOpenAI(
                 **client_params, **async_specific
             )
-            values["async_client"] = values["root_async_client"].chat.completions
-        return values
+            self.async_client = self.root_async_client.chat.completions
+        return self

     def bind_tools(
         self,
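A hedged sketch of the client-initialization shape used in the validator above: the "after" validator assigns back onto `self` and returns `self`, and the `(self.client or None)` truthiness guard mirrors v1's `values.get("client")`. `FakeClient` is a placeholder, not the real openai SDK client:

from typing import Any

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class FakeClient:
    """Stand-in for openai.AzureOpenAI(...).chat.completions."""


class ExampleModel(BaseModel):
    client: Any = None

    @model_validator(mode="after")
    def init_client(self) -> Self:
        # Only build a client if the caller did not inject one.
        if not (self.client or None):
            # Plain attribute assignment is safe here: validate_assignment
            # is off by default, so no re-validation loop is triggered.
            self.client = FakeClient()
        return self


m = ExampleModel()
assert isinstance(m.client, FakeClient)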

View File

@@ -73,15 +73,11 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from pydantic import BaseModel, Field, model_validator, SecretStr
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
-from langchain_core.utils import (
-    convert_to_secret_str,
-    get_from_dict_or_env,
-    get_pydantic_field_names,
-)
+from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
 from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
@@ -91,7 +87,10 @@ from langchain_core.utils.pydantic import (
     TypeBaseModel,
     is_basemodel_subclass,
 )
-from langchain_core.utils.utils import build_extra_kwargs
+from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from pydantic import ConfigDict
+from typing_extensions import Self

 logger = logging.getLogger(__name__)
@@ -361,15 +360,19 @@ class BaseChatOpenAI(BaseChatModel):
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
-    openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    openai_api_key: Optional[SecretStr] = Field(
+        alias="api_key",
+        default_factory=secret_from_env("OPENAI_API_KEY", default=None),
+    )
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
     openai_api_base: Optional[str] = Field(default=None, alias="base_url")
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     openai_organization: Optional[str] = Field(default=None, alias="organization")
     """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = Field(
+        default_factory=from_env("OPENAI_PROXY", default=None)
+    )
     request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
     )
@@ -428,13 +431,11 @@ class BaseChatOpenAI(BaseChatModel):
     include_response_headers: bool = False
     """Whether to include response headers in the output message response_metadata."""

-    class Config:
-        """Configuration for this pydantic object."""
-
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)

-    @root_validator(pre=True)
-    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -443,56 +444,49 @@ class BaseChatOpenAI(BaseChatModel):
         )
         return values

-    @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
-        values["openai_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "openai_api_key", "OPENAI_API_KEY")
-        )

         # Check OPENAI_ORGANIZATION for backwards compatibility.
-        values["openai_organization"] = (
-            values["openai_organization"]
+        self.openai_organization = (
+            self.openai_organization
             or os.getenv("OPENAI_ORG_ID")
             or os.getenv("OPENAI_ORGANIZATION")
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        self.openai_api_base = self.openai_api_base or os.getenv(
             "OPENAI_API_BASE"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
-            values, "openai_proxy", "OPENAI_PROXY", default=""
-        )
         client_params = {
             "api_key": (
-                values["openai_api_key"].get_secret_value()
-                if values["openai_api_key"]
+                self.openai_api_key.get_secret_value()
+                if self.openai_api_key
                 else None
             ),
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if values["openai_proxy"] and (
-            values["http_client"] or values["http_async_client"]
+        if self.openai_proxy and (
+            self.http_client or self.http_async_client
         ):
-            openai_proxy = values["openai_proxy"]
-            http_client = values["http_client"]
-            http_async_client = values["http_async_client"]
+            openai_proxy = self.openai_proxy
+            http_client = self.http_client
+            http_async_client = self.http_async_client
             raise ValueError(
                 "Cannot specify 'openai_proxy' if one of "
                 "'http_client'/'http_async_client' is already specified. Received:\n"
                 f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
             )
-        if not values.get("client"):
-            if values["openai_proxy"] and not values["http_client"]:
+        if not (self.client or None):
+            if self.openai_proxy and not self.http_client:
                 try:
                     import httpx
                 except ImportError as e:
@@ -500,12 +494,12 @@ class BaseChatOpenAI(BaseChatModel):
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
-            sync_specific = {"http_client": values["http_client"]}
-            values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
-            values["client"] = values["root_client"].chat.completions
-        if not values.get("async_client"):
-            if values["openai_proxy"] and not values["http_async_client"]:
+                self.http_client = httpx.Client(proxy=self.openai_proxy)
+            sync_specific = {"http_client": self.http_client}
+            self.root_client = openai.OpenAI(**client_params, **sync_specific)
+            self.client = self.root_client.chat.completions
+        if not (self.async_client or None):
+            if self.openai_proxy and not self.http_async_client:
                 try:
                     import httpx
                 except ImportError as e:
@@ -513,15 +507,15 @@ class BaseChatOpenAI(BaseChatModel):
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_async_client"] = httpx.AsyncClient(
-                    proxy=values["openai_proxy"]
+                self.http_async_client = httpx.AsyncClient(
+                    proxy=self.openai_proxy
                 )
-            async_specific = {"http_client": values["http_async_client"]}
-            values["root_async_client"] = openai.AsyncOpenAI(
+            async_specific = {"http_client": self.http_async_client}
+            self.root_async_client = openai.AsyncOpenAI(
                 **client_params, **async_specific
             )
-            values["async_client"] = values["root_async_client"].chat.completions
-        return values
+            self.async_client = self.root_async_client.chat.completions
+        return self

     @property
     def _default_params(self) -> Dict[str, Any]:

View File

@@ -5,10 +5,12 @@ from __future__ import annotations
 from typing import Callable, Dict, Optional, Union

 import openai
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from pydantic import Field, SecretStr, root_validator, model_validator
 from langchain_core.utils import from_env, secret_from_env

 from langchain_openai.embeddings.base import OpenAIEmbeddings
+from typing_extensions import Self
+

 class AzureOpenAIEmbeddings(OpenAIEmbeddings):
@@ -153,21 +155,21 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
     chunk_size: int = 2048
     """Maximum number of texts to embed in each batch"""

-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
         # For backwards compatibility. Before openai v1, no distinction was made
         # between azure_endpoint and base_url (openai_api_base).
-        openai_api_base = values["openai_api_base"]
-        if openai_api_base and values["validate_base_url"]:
+        openai_api_base = self.openai_api_base
+        if openai_api_base and self.validate_base_url:
             if "/openai" not in openai_api_base:
-                values["openai_api_base"] += "/openai"
+                self.openai_api_base += "/openai"
                 raise ValueError(
                     "As of openai>=1.0.0, Azure endpoints should be specified via "
                     "the `azure_endpoint` param not `openai_api_base` "
                     "(or alias `base_url`). "
                 )
-            if values["deployment"]:
+            if self.deployment:
                 raise ValueError(
                     "As of openai>=1.0.0, if `deployment` (or alias "
                     "`azure_deployment`) is specified then "
@@ -176,38 +178,38 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
                     "and `azure_endpoint`."
                 )
         client_params = {
-            "api_version": values["openai_api_version"],
-            "azure_endpoint": values["azure_endpoint"],
-            "azure_deployment": values["deployment"],
+            "api_version": self.openai_api_version,
+            "azure_endpoint": self.azure_endpoint,
+            "azure_deployment": self.deployment,
             "api_key": (
-                values["openai_api_key"].get_secret_value()
-                if values["openai_api_key"]
+                self.openai_api_key.get_secret_value()
+                if self.openai_api_key
                 else None
             ),
             "azure_ad_token": (
-                values["azure_ad_token"].get_secret_value()
-                if values["azure_ad_token"]
+                self.azure_ad_token.get_secret_value()
+                if self.azure_ad_token
                 else None
             ),
-            "azure_ad_token_provider": values["azure_ad_token_provider"],
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "azure_ad_token_provider": self.azure_ad_token_provider,
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.AzureOpenAI(
+        if not (self.client or None):
+            sync_specific = {"http_client": self.http_client}
+            self.client = openai.AzureOpenAI(
                 **client_params, **sync_specific
             ).embeddings
-        if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncAzureOpenAI(
+        if not (self.async_client or None):
+            async_specific = {"http_client": self.http_async_client}
+            self.async_client = openai.AsyncAzureOpenAI(
                 **client_params, **async_specific
             ).embeddings
-        return values
+        return self

     @property
     def _llm_type(self) -> str:

View File

@@ -20,8 +20,12 @@ from typing import (
 import openai
 import tiktoken
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from pydantic import BaseModel, Field, SecretStr, root_validator, model_validator
 from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
+from pydantic import ConfigDict
+from typing_extensions import Self

 logger = logging.getLogger(__name__)
@@ -263,14 +267,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Whether to check the token length of inputs and automatically split inputs
     longer than embedding_ctx_length."""

-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = "forbid"
-        allow_population_by_field_name = True
+    model_config = ConfigDict(extra="forbid", populate_by_name=True)

-    @root_validator(pre=True)
-    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -295,41 +296,41 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         values["model_kwargs"] = extra
         return values

-    @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
+        if self.openai_api_type in ("azure", "azure_ad", "azuread"):
             raise ValueError(
                 "If you are using Azure, "
                 "please use the `AzureOpenAIEmbeddings` class."
             )
         client_params = {
             "api_key": (
-                values["openai_api_key"].get_secret_value()
-                if values["openai_api_key"]
+                self.openai_api_key.get_secret_value()
+                if self.openai_api_key
                 else None
             ),
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if values["openai_proxy"] and (
-            values["http_client"] or values["http_async_client"]
+        if self.openai_proxy and (
+            self.http_client or self.http_async_client
         ):
-            openai_proxy = values["openai_proxy"]
-            http_client = values["http_client"]
-            http_async_client = values["http_async_client"]
+            openai_proxy = self.openai_proxy
+            http_client = self.http_client
+            http_async_client = self.http_async_client
             raise ValueError(
                 "Cannot specify 'openai_proxy' if one of "
                 "'http_client'/'http_async_client' is already specified. Received:\n"
                 f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
             )
-        if not values.get("client"):
-            if values["openai_proxy"] and not values["http_client"]:
+        if not (self.client or None):
+            if self.openai_proxy and not self.http_client:
                 try:
                     import httpx
                 except ImportError as e:
@@ -337,13 +338,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
-            sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.OpenAI(
+                self.http_client = httpx.Client(proxy=self.openai_proxy)
+            sync_specific = {"http_client": self.http_client}
+            self.client = openai.OpenAI(
                 **client_params, **sync_specific
             ).embeddings
-        if not values.get("async_client"):
-            if values["openai_proxy"] and not values["http_async_client"]:
+        if not (self.async_client or None):
+            if self.openai_proxy and not self.http_async_client:
                 try:
                     import httpx
                 except ImportError as e:
@@ -351,14 +352,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_async_client"] = httpx.AsyncClient(
-                    proxy=values["openai_proxy"]
+                self.http_async_client = httpx.AsyncClient(
+                    proxy=self.openai_proxy
                 )
-            async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncOpenAI(
+            async_specific = {"http_client": self.http_async_client}
+            self.async_client = openai.AsyncOpenAI(
                 **client_params, **async_specific
             ).embeddings
-        return values
+        return self

     @property
     def _invocation_params(self) -> Dict[str, Any]:

View File

@@ -5,10 +5,12 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
 import openai
 from langchain_core.language_models import LangSmithParams
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from pydantic import Field, SecretStr, root_validator, model_validator
 from langchain_core.utils import from_env, secret_from_env

 from langchain_openai.llms.base import BaseOpenAI
+from typing_extensions import Self

 logger = logging.getLogger(__name__)
@@ -100,29 +102,29 @@ class AzureOpenAI(BaseOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True

-    @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["streaming"] and values["n"] > 1:
+        if self.streaming and self.n > 1:
             raise ValueError("Cannot stream results when n > 1.")
-        if values["streaming"] and values["best_of"] > 1:
+        if self.streaming and self.best_of > 1:
             raise ValueError("Cannot stream results when best_of > 1.")
         # For backwards compatibility. Before openai v1, no distinction was made
         # between azure_endpoint and base_url (openai_api_base).
-        openai_api_base = values["openai_api_base"]
-        if openai_api_base and values["validate_base_url"]:
+        openai_api_base = self.openai_api_base
+        if openai_api_base and self.validate_base_url:
             if "/openai" not in openai_api_base:
-                values["openai_api_base"] = (
-                    values["openai_api_base"].rstrip("/") + "/openai"
+                self.openai_api_base = (
+                    self.openai_api_base.rstrip("/") + "/openai"
                 )
                 raise ValueError(
                     "As of openai>=1.0.0, Azure endpoints should be specified via "
                     "the `azure_endpoint` param not `openai_api_base` "
                     "(or alias `base_url`)."
                 )
-            if values["deployment_name"]:
+            if self.deployment_name:
                 raise ValueError(
                     "As of openai>=1.0.0, if `deployment_name` (or alias "
                     "`azure_deployment`) is specified then "
@@ -130,37 +132,37 @@ class AzureOpenAI(BaseOpenAI):
                     "Instead use `deployment_name` (or alias `azure_deployment`) "
                     "and `azure_endpoint`."
                 )
-            values["deployment_name"] = None
+            self.deployment_name = None
         client_params = {
-            "api_version": values["openai_api_version"],
-            "azure_endpoint": values["azure_endpoint"],
-            "azure_deployment": values["deployment_name"],
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
+            "api_version": self.openai_api_version,
+            "azure_endpoint": self.azure_endpoint,
+            "azure_deployment": self.deployment_name,
+            "api_key": self.openai_api_key.get_secret_value()
+            if self.openai_api_key
             else None,
-            "azure_ad_token": values["azure_ad_token"].get_secret_value()
-            if values["azure_ad_token"]
+            "azure_ad_token": self.azure_ad_token.get_secret_value()
+            if self.azure_ad_token
             else None,
-            "azure_ad_token_provider": values["azure_ad_token_provider"],
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "azure_ad_token_provider": self.azure_ad_token_provider,
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.AzureOpenAI(
+        if not (self.client or None):
+            sync_specific = {"http_client": self.http_client}
+            self.client = openai.AzureOpenAI(
                 **client_params, **sync_specific
             ).completions
-        if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncAzureOpenAI(
+        if not (self.async_client or None):
+            async_specific = {"http_client": self.http_async_client}
+            self.async_client = openai.AsyncAzureOpenAI(
                 **client_params, **async_specific
            ).completions
-        return values
+        return self

     @property
     def _identifying_params(self) -> Mapping[str, Any]:

View File

@@ -26,9 +26,13 @@ from langchain_core.callbacks import (
 )
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from pydantic import Field, SecretStr, root_validator, model_validator
 from langchain_core.utils import get_pydantic_field_names
 from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from pydantic import ConfigDict
+from typing_extensions import Self

 logger = logging.getLogger(__name__)
@@ -152,13 +156,11 @@ class BaseOpenAI(BaseLLM):
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""

-    class Config:
-        """Configuration for this pydantic object."""
-
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)

-    @root_validator(pre=True)
-    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -167,41 +169,41 @@ class BaseOpenAI(BaseLLM):
         )
         return values

-    @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["streaming"] and values["n"] > 1:
+        if self.streaming and self.n > 1:
             raise ValueError("Cannot stream results when n > 1.")
-        if values["streaming"] and values["best_of"] > 1:
+        if self.streaming and self.best_of > 1:
             raise ValueError("Cannot stream results when best_of > 1.")
         client_params = {
             "api_key": (
-                values["openai_api_key"].get_secret_value()
-                if values["openai_api_key"]
+                self.openai_api_key.get_secret_value()
+                if self.openai_api_key
                 else None
             ),
-            "organization": values["openai_organization"],
-            "base_url": values["openai_api_base"],
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "organization": self.openai_organization,
+            "base_url": self.openai_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
-        if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.OpenAI(
+        if not (self.client or None):
+            sync_specific = {"http_client": self.http_client}
+            self.client = openai.OpenAI(
                 **client_params, **sync_specific
             ).completions
-        if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncOpenAI(
+        if not (self.async_client or None):
+            async_specific = {"http_client": self.http_async_client}
+            self.async_client = openai.AsyncOpenAI(
                 **client_params, **async_specific
             ).completions
-        return values
+        return self

     @property
     def _default_params(self) -> Dict[str, Any]:

View File

@@ -20,7 +20,7 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field
 from langchain_standard_tests.integration_tests.chat_models import (
     _validate_tool_call_message,
 )

View File

@@ -17,7 +17,7 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel

 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (

View File

@@ -6,7 +6,7 @@ from uuid import UUID
 from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain_core.messages import BaseMessage
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel


 class BaseFakeCallbackHandler(BaseModel):

View File

@@ -9,7 +9,7 @@ def test_loads_openai_llm() -> None:
     llm_string = dumps(llm)
     llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
-    assert llm2 == llm
+    assert llm2.dict() == llm.dict()
     llm_string_2 = dumps(llm2)
     assert llm_string_2 == llm_string
     assert isinstance(llm2, OpenAI)
@@ -20,7 +20,7 @@ def test_load_openai_llm() -> None:
     llm_obj = dumpd(llm)
     llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
-    assert llm2 == llm
+    assert llm2.dict() == llm.dict()
     assert dumpd(llm2) == llm_obj
     assert isinstance(llm2, OpenAI)
@@ -30,7 +30,7 @@ def test_loads_openai_chat() -> None:
     llm_string = dumps(llm)
     llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
-    assert llm2 == llm
+    assert llm2.dict() == llm.dict()
     llm_string_2 = dumps(llm2)
     assert llm_string_2 == llm_string
     assert isinstance(llm2, ChatOpenAI)
@@ -41,6 +41,6 @@ def test_load_openai_chat() -> None:
     llm_obj = dumpd(llm)
     llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
-    assert llm2 == llm
+    assert llm2.dict() == llm.dict()
     assert dumpd(llm2) == llm_obj
     assert isinstance(llm2, ChatOpenAI)
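The switch from `llm2 == llm` to comparing `.dict()` output is presumably forced by v2 equality semantics: these models carry live OpenAI client objects in fields marked `exclude=True`, which drop out of `.dict()` but still participate in v2's field-by-field `__eq__`, so two separately constructed but equivalent models no longer compare equal directly. A toy illustration (`Wrapper` is hypothetical):

from pydantic import BaseModel, Field


class Wrapper(BaseModel):
    x: int
    # Excluded from serialization, like the client/root_client fields.
    client: object = Field(default=None, exclude=True)


a = Wrapper(x=1, client=object())
b = Wrapper(x=1, client=object())
assert a != b                  # distinct client objects break __eq__ in v2
assert a.dict() == b.dict()    # excluded fields are omitted from dict()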

View File

@@ -2,7 +2,7 @@ from typing import Type, cast
 import pytest
 from langchain_core.load import dumpd
-from langchain_core.pydantic_v1 import SecretStr
+from pydantic import SecretStr
 from pytest import CaptureFixture, MonkeyPatch

 from langchain_openai import (

View File

@@ -6,7 +6,7 @@ from unittest import mock
 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 from langchain_core.runnables import RunnableBinding
 from langchain_core.tools import tool

View File

@@ -5,7 +5,7 @@ from unittest import mock
 import pytest
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import SecretStr
+from pydantic import SecretStr
 from langchain_standard_tests.base import BaseStandardTests