openai[patch]: Upgrade @root_validators in preparation for pydantic 2 migration (#25491)

* Upgrade @root_validator in openai pkg
* Ran notebooks for all but AzureAI embeddings

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Eugene Yurtsev 2024-09-03 17:42:24 -04:00 committed by GitHub
parent 0207dc1431
commit bc3b851f08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 248 additions and 233 deletions

View File

@ -34,7 +34,7 @@ from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import from_env, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
@ -474,10 +474,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
} }
""" # noqa: E501 """ # noqa: E501
azure_endpoint: Union[str, None] = None azure_endpoint: Optional[str] = Field(
default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
)
"""Your Azure endpoint, including the resource. """Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
Example: `https://example-resource.azure.openai.com/` Example: `https://example-resource.azure.openai.com/`
""" """
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment") deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
@ -486,14 +489,28 @@ class AzureChatOpenAI(BaseChatOpenAI):
If given sets the base client URL to include `/deployments/{azure_deployment}`. If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Note: this means you won't be able to use non-deployment endpoints.
""" """
openai_api_version: str = Field(default="", alias="api_version") openai_api_version: Optional[str] = Field(
alias="api_version",
default_factory=from_env("OPENAI_API_VERSION", default=None),
)
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") # Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
openai_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env(
["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
),
)
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Optional[SecretStr] = None azure_ad_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
)
"""Your Azure Active Directory token. """Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more: For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
""" """
@ -516,7 +533,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
correct cost. correct cost.
""" """
openai_api_type: str = "" openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure")
)
"""Legacy, for openai<1.0.0 support.""" """Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True validate_base_url: bool = True
"""If legacy arg openai_api_base is passed in, try to infer if it is a base_url or """If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
@ -546,7 +565,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -554,45 +573,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
if values["n"] > 1 and values["streaming"]: if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.") raise ValueError("n must be 1 when streaming.")
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = (
values["openai_api_base"]
if "openai_api_base" in values
else os.getenv("OPENAI_API_BASE")
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION"
)
# Check OPENAI_ORGANIZATION for backwards compatibility. # Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = ( values["openai_organization"] = (
values["openai_organization"] values["openai_organization"]
or os.getenv("OPENAI_ORG_ID") or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION") or os.getenv("OPENAI_ORGANIZATION")
) )
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
# For backwards compatibility. Before openai v1, no distinction was made # For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base). # between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"] openai_api_base = values["openai_api_base"]

View File

@ -443,7 +443,7 @@ class BaseChatOpenAI(BaseChatModel):
) )
return values return values
@root_validator(pre=False, skip_on_failure=True) @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:

View File

@ -2,12 +2,11 @@
from __future__ import annotations from __future__ import annotations
import os
from typing import Callable, Dict, Optional, Union from typing import Callable, Dict, Optional, Union
import openai import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import from_env, secret_from_env
from langchain_openai.embeddings.base import OpenAIEmbeddings from langchain_openai.embeddings.base import OpenAIEmbeddings
@ -100,7 +99,9 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
[-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188] [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
""" # noqa: E501 """ # noqa: E501
azure_endpoint: Union[str, None] = None azure_endpoint: Optional[str] = Field(
default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
)
"""Your Azure endpoint, including the resource. """Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
@ -113,9 +114,26 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
If given sets the base client URL to include `/deployments/{azure_deployment}`. If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Note: this means you won't be able to use non-deployment endpoints.
""" """
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") # Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
openai_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env(
["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
),
)
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Optional[SecretStr] = None openai_api_version: Optional[str] = Field(
default_factory=from_env("OPENAI_API_VERSION", default="2023-05-15")
)
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.
Set to "2023-05-15" by default if env variable `OPENAI_API_VERSION` is not set.
"""
azure_ad_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
)
"""Your Azure Active Directory token. """Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
@ -128,52 +146,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
Will be invoked on every request. Will be invoked on every request.
""" """
openai_api_version: Optional[str] = Field(default=None, alias="api_version") openai_api_type: Optional[str] = Field(
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" default_factory=from_env("OPENAI_API_TYPE", default="azure")
)
validate_base_url: bool = True validate_base_url: bool = True
chunk_size: int = 2048 chunk_size: int = 2048
"""Maximum number of texts to embed in each batch""" """Maximum number of texts to embed in each batch"""
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = (
values["openai_api_base"]
if "openai_api_base" in values
else os.getenv("OPENAI_API_BASE")
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION", default="2023-05-15"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
# For backwards compatibility. Before openai v1, no distinction was made # For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base). # between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"] openai_api_base = values["openai_api_base"]

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import warnings import warnings
from typing import ( from typing import (
Any, Any,
@ -22,11 +21,7 @@ import openai
import tiktoken import tiktoken
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import ( from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -185,21 +180,37 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# to support Azure OpenAI Service custom deployment names # to support Azure OpenAI Service custom deployment names
deployment: Optional[str] = model deployment: Optional[str] = model
# TODO: Move to AzureOpenAIEmbeddings. # TODO: Move to AzureOpenAIEmbeddings.
openai_api_version: Optional[str] = Field(default=None, alias="api_version") openai_api_version: Optional[str] = Field(
default_factory=from_env("OPENAI_API_VERSION", default=None),
alias="api_version",
)
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
# to support Azure OpenAI Service custom endpoints # to support Azure OpenAI Service custom endpoints
openai_api_base: Optional[str] = Field(default=None, alias="base_url") openai_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
emulator.""" emulator."""
# to support Azure OpenAI Service custom endpoints # to support Azure OpenAI Service custom endpoints
openai_api_type: Optional[str] = None openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default=None)
)
# to support explicit proxy for OpenAI # to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None openai_proxy: Optional[str] = Field(
default_factory=from_env("OPENAI_PROXY", default=None)
)
embedding_ctx_length: int = 8191 embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once.""" """The maximum number of tokens to embed at once."""
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_organization: Optional[str] = Field(default=None, alias="organization") openai_organization: Optional[str] = Field(
alias="organization",
default_factory=from_env(
["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
allowed_special: Union[Literal["all"], Set[str], None] = None allowed_special: Union[Literal["all"], Set[str], None] = None
disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None
@ -284,33 +295,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default=""
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["openai_api_version"] = get_from_dict_or_env(
values, "openai_api_version", "OPENAI_API_VERSION", default=""
)
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
raise ValueError( raise ValueError(
"If you are using Azure, " "If you are using Azure, "

View File

@ -1,13 +1,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import openai import openai
from langchain_core.language_models import LangSmithParams from langchain_core.language_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import from_env, secret_from_env
from langchain_openai.llms.base import BaseOpenAI from langchain_openai.llms.base import BaseOpenAI
@ -31,7 +30,9 @@ class AzureOpenAI(BaseOpenAI):
openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct") openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
""" """
azure_endpoint: Union[str, None] = None azure_endpoint: Optional[str] = Field(
default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
)
"""Your Azure endpoint, including the resource. """Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
@ -44,11 +45,23 @@ class AzureOpenAI(BaseOpenAI):
If given sets the base client URL to include `/deployments/{azure_deployment}`. If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Note: this means you won't be able to use non-deployment endpoints.
""" """
openai_api_version: str = Field(default="", alias="api_version") openai_api_version: Optional[str] = Field(
alias="api_version",
default_factory=from_env("OPENAI_API_VERSION", default=None),
)
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") # Check OPENAI_KEY for backwards compatibility.
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
azure_ad_token: Optional[SecretStr] = None # other forms of azure credentials.
openai_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env(
["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
),
)
azure_ad_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
)
"""Your Azure Active Directory token. """Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
@ -61,7 +74,9 @@ class AzureOpenAI(BaseOpenAI):
Will be invoked on every request. Will be invoked on every request.
""" """
openai_api_type: str = "" openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure")
)
"""Legacy, for openai<1.0.0 support.""" """Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to """For backwards compatibility. If legacy val openai_api_base is passed in, try to
@ -85,7 +100,7 @@ class AzureOpenAI(BaseOpenAI):
"""Return whether this model can be serialized by Langchain.""" """Return whether this model can be serialized by Langchain."""
return True return True
@root_validator() @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -94,43 +109,6 @@ class AzureOpenAI(BaseOpenAI):
raise ValueError("Cannot stream results when n > 1.") raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1: if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.") raise ValueError("Cannot stream results when best_of > 1.")
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
# For backwards compatibility. Before openai v1, no distinction was made # For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base). # between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"] openai_api_base = values["openai_api_base"]

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import sys import sys
from typing import ( from typing import (
AbstractSet, AbstractSet,
@ -28,12 +27,8 @@ from langchain_core.callbacks import (
from langchain_core.language_models.llms import BaseLLM from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import ( from langchain_core.utils import get_pydantic_field_names
convert_to_secret_str, from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -90,15 +85,26 @@ class BaseOpenAI(BaseLLM):
"""Generates best_of completions server-side and returns the "best".""" """Generates best_of completions server-side and returns the "best"."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url") openai_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
emulator.""" emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization") openai_organization: Optional[str] = Field(
alias="organization",
default_factory=from_env(
["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI # to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None openai_proxy: Optional[str] = Field(
default_factory=from_env("OPENAI_PROXY", default=None)
)
batch_size: int = 20 batch_size: int = 20
"""Batch size to use when passing multiple documents to generate.""" """Batch size to use when passing multiple documents to generate."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
@ -161,7 +167,7 @@ class BaseOpenAI(BaseLLM):
) )
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -171,24 +177,6 @@ class BaseOpenAI(BaseLLM):
if values["streaming"] and values["best_of"] > 1: if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.") raise ValueError("Cannot stream results when best_of > 1.")
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
client_params = { client_params = {
"api_key": ( "api_key": (
values["openai_api_key"].get_secret_value() values["openai_api_key"].get_secret_value()

View File

@ -1,6 +1,7 @@
"""Test Azure OpenAI Chat API wrapper.""" """Test Azure OpenAI Chat API wrapper."""
import os import os
from unittest import mock
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
@ -39,7 +40,7 @@ def test_initialize_more() -> None:
def test_initialize_azure_openai_with_openai_api_base_set() -> None: def test_initialize_azure_openai_with_openai_api_base_set() -> None:
os.environ["OPENAI_API_BASE"] = "https://api.openai.com" with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg] llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
api_key="xyz", # type: ignore[arg-type] api_key="xyz", # type: ignore[arg-type]
azure_endpoint="my-base-url", azure_endpoint="my-base-url",

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type from typing import Tuple, Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -25,3 +25,25 @@ class TestOpenAIStandard(ChatModelUnitTests):
@pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'") @pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'")
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
super().test_bind_tool_pydantic(model) super().test_bind_tool_pydantic(model)
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
    """Return the three dicts used by the init-from-environment test.

    The first dict is keyed by environment-variable names, the second is
    empty, and the third is keyed by model attribute names.
    """
    # NOTE(review): presumably the standard-tests base class reads this as
    # (env vars to set, extra init kwargs, expected attribute values) --
    # confirm against the base class before relying on that ordering.
    env_vars = {
        "AZURE_OPENAI_API_KEY": "api_key",
        "AZURE_OPENAI_ENDPOINT": "https://endpoint.com",
        "AZURE_OPENAI_AD_TOKEN": "token",
        "OPENAI_ORG_ID": "org_id",
        "OPENAI_API_VERSION": "yyyy-mm-dd",
        "OPENAI_API_TYPE": "type",
    }
    init_kwargs: dict = {}
    expected_attrs = {
        "openai_api_key": "api_key",
        "azure_endpoint": "https://endpoint.com",
        "azure_ad_token": "token",
        "openai_organization": "org_id",
        "openai_api_version": "yyyy-mm-dd",
        "openai_api_type": "type",
    }
    return env_vars, init_kwargs, expected_attrs

View File

@ -18,7 +18,7 @@ class TestOpenAIStandard(ChatModelUnitTests):
return ( return (
{ {
"OPENAI_API_KEY": "api_key", "OPENAI_API_KEY": "api_key",
"OPENAI_ORGANIZATION": "org_id", "OPENAI_ORG_ID": "org_id",
"OPENAI_API_BASE": "api_base", "OPENAI_API_BASE": "api_base",
"OPENAI_PROXY": "https://proxy.com", "OPENAI_PROXY": "https://proxy.com",
}, },

View File

@ -1,4 +1,5 @@
import os import os
from unittest import mock
from langchain_openai import AzureOpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings
@ -15,7 +16,7 @@ def test_initialize_azure_openai() -> None:
def test_intialize_azure_openai_with_base_set() -> None: def test_intialize_azure_openai_with_base_set() -> None:
os.environ["OPENAI_API_BASE"] = "https://api.openai.com" with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg] embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg]
model="text-embedding-large", model="text-embedding-large",
api_key="xyz", # type: ignore[arg-type] api_key="xyz", # type: ignore[arg-type]

View File

@ -0,0 +1,38 @@
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests
from langchain_openai import AzureOpenAIEmbeddings
class TestAzureOpenAIStandard(EmbeddingsUnitTests):
    """Standard embeddings unit tests bound to ``AzureOpenAIEmbeddings``."""

    @property
    def embeddings_class(self) -> Type[Embeddings]:
        """Return the embeddings class under test."""
        return AzureOpenAIEmbeddings

    @property
    def embedding_model_params(self) -> dict:
        """Return the keyword arguments supplied for the embeddings model."""
        params = {}
        params["api_key"] = "api_key"
        params["azure_endpoint"] = "https://endpoint.com"
        return params

    @property
    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
        """Return the three dicts used by the init-from-environment test.

        The first dict is keyed by environment-variable names, the second
        is empty, and the third is keyed by model attribute names.
        """
        # NOTE(review): presumably interpreted by the base class as
        # (env vars to set, extra init kwargs, expected attribute values)
        # -- confirm against EmbeddingsUnitTests.
        env_vars = {
            "AZURE_OPENAI_API_KEY": "api_key",
            "AZURE_OPENAI_ENDPOINT": "https://endpoint.com",
            "AZURE_OPENAI_AD_TOKEN": "token",
            "OPENAI_ORG_ID": "org_id",
            "OPENAI_API_VERSION": "yyyy-mm-dd",
            "OPENAI_API_TYPE": "type",
        }
        init_kwargs: dict = {}
        expected_attrs = {
            "openai_api_key": "api_key",
            "azure_endpoint": "https://endpoint.com",
            "azure_ad_token": "token",
            "openai_organization": "org_id",
            "openai_api_version": "yyyy-mm-dd",
            "openai_api_type": "type",
        }
        return env_vars, init_kwargs, expected_attrs

View File

@ -0,0 +1,32 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests
from langchain_openai import OpenAIEmbeddings
class TestOpenAIStandard(EmbeddingsUnitTests):
    """Standard embeddings unit tests bound to ``OpenAIEmbeddings``."""

    @property
    def embeddings_class(self) -> Type[Embeddings]:
        """Return the embeddings class under test."""
        return OpenAIEmbeddings

    @property
    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
        """Return the three dicts used by the init-from-environment test.

        The first dict is keyed by environment-variable names, the second
        is empty, and the third is keyed by model attribute names.
        """
        # NOTE(review): presumably interpreted by the base class as
        # (env vars to set, extra init kwargs, expected attribute values)
        # -- confirm against EmbeddingsUnitTests.
        env_vars = {
            "OPENAI_API_KEY": "api_key",
            "OPENAI_ORG_ID": "org_id",
            "OPENAI_API_BASE": "api_base",
            "OPENAI_PROXY": "https://proxy.com",
        }
        init_kwargs: dict = {}
        expected_attrs = {
            "openai_api_key": "api_key",
            "openai_organization": "org_id",
            "openai_api_base": "api_base",
            "openai_proxy": "https://proxy.com",
        }
        return env_vars, init_kwargs, expected_attrs