Azure OpenAI Embeddings (#13039)

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Erick Friis 2023-11-08 12:37:17 -08:00 committed by GitHub
parent 37561d8986
commit f15f8e01cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 492 additions and 156 deletions

View File

@ -2,12 +2,15 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import warnings
from typing import Any, Dict, Union from typing import Any, Dict, Union
from langchain.chat_models.openai import ChatOpenAI, _is_openai_v1 from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema import ChatResult from langchain.schema import ChatResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from langchain.utils.openai import is_openai_v1
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,48 +54,82 @@ class AzureChatOpenAI(ChatOpenAI):
in, even if not explicitly saved on this class. in, even if not explicitly saved on this class.
""" """
deployment_name: str = Field(default="", alias="azure_deployment") azure_endpoint: Union[str, None] = None
model_version: str = "" """Your Azure endpoint, including the resource.
openai_api_type: str = ""
openai_api_base: str = Field(default="", alias="azure_endpoint") Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
openai_api_version: str = Field(default="", alias="api_version") openai_api_version: str = Field(default="", alias="api_version")
openai_api_key: str = Field(default="", alias="api_key") """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_organization: str = Field(default="", alias="organization") openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
openai_proxy: str = "" """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
""" # noqa: E501
azure_ad_token_provider: Union[str, None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
"""
model_version: str = ""
"""Legacy, for openai<1.0.0 support."""
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env( if values["n"] < 1:
values, raise ValueError("n must be at least 1.")
"openai_api_key", if values["n"] > 1 and values["streaming"]:
"OPENAI_API_KEY", raise ValueError("n must be 1 when streaming.")
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
) )
values["openai_api_base"] = get_from_dict_or_env( values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values, "OPENAI_API_BASE"
"openai_api_base",
"OPENAI_API_BASE",
) )
values["openai_api_version"] = get_from_dict_or_env( values["openai_api_version"] = values["openai_api_version"] or os.getenv(
values, "OPENAI_API_VERSION"
"openai_api_version",
"OPENAI_API_VERSION",
) )
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
)
values["openai_api_type"] = get_from_dict_or_env( values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure" values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
) )
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
try: try:
import openai import openai
@ -101,37 +138,69 @@ class AzureChatOpenAI(ChatOpenAI):
"Could not import openai python package. " "Could not import openai python package. "
"Please install it with `pip install openai`." "Please install it with `pip install openai`."
) )
if _is_openai_v1(): if is_openai_v1():
values["client"] = openai.AzureOpenAI( # For backwards compatibility. Before openai v1, no distinction was made
azure_endpoint=values["openai_api_base"], # between azure_endpoint and base_url (openai_api_base).
api_key=values["openai_api_key"], openai_api_base = values["openai_api_base"]
timeout=values["request_timeout"], if openai_api_base and values["validate_base_url"]:
max_retries=values["max_retries"], if "/openai" not in openai_api_base:
organization=values["openai_organization"], values["openai_api_base"] = (
api_version=values["openai_api_version"], values["openai_api_base"].rstrip("/") + "/openai"
azure_deployment=values["deployment_name"], )
).chat.completions warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
values["async_client"] = openai.AsyncAzureOpenAI( values["async_client"] = openai.AsyncAzureOpenAI(
azure_endpoint=values["openai_api_base"], **client_params
api_key=values["openai_api_key"],
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
api_version=values["openai_api_version"],
azure_deployment=values["deployment_name"],
).chat.completions ).chat.completions
else: else:
values["client"] = openai.ChatCompletion values["client"] = openai.ChatCompletion
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values return values
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
if _is_openai_v1(): if is_openai_v1():
return super()._default_params return super()._default_params
else: else:
return { return {
@ -147,7 +216,7 @@ class AzureChatOpenAI(ChatOpenAI):
@property @property
def _client_params(self) -> Dict[str, Any]: def _client_params(self) -> Dict[str, Any]:
"""Get the config params used for the openai client.""" """Get the config params used for the openai client."""
if _is_openai_v1(): if is_openai_v1():
return super()._client_params return super()._client_params
else: else:
return { return {

View File

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import sys import sys
from importlib.metadata import version
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -20,8 +20,6 @@ from typing import (
Union, Union,
) )
from packaging.version import Version, parse
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -51,6 +49,7 @@ from langchain.utils import (
get_from_dict_or_env, get_from_dict_or_env,
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain.utils.openai import is_openai_v1
if TYPE_CHECKING: if TYPE_CHECKING:
import httpx import httpx
@ -98,7 +97,7 @@ async def acompletion_with_retry(
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Use tenacity to retry the async completion call.""" """Use tenacity to retry the async completion call."""
if _is_openai_v1(): if is_openai_v1():
return await llm.async_client.create(**kwargs) return await llm.async_client.create(**kwargs)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@ -140,11 +139,6 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) return default_class(content=content)
def _is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version >= Version("1.0.0")
class ChatOpenAI(BaseChatModel): class ChatOpenAI(BaseChatModel):
"""`OpenAI` Chat large language models API. """`OpenAI` Chat large language models API.
@ -169,13 +163,13 @@ class ChatOpenAI(BaseChatModel):
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}
if self.openai_organization != "": if self.openai_organization:
attributes["openai_organization"] = self.openai_organization attributes["openai_organization"] = self.openai_organization
if self.openai_api_base != "": if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base attributes["openai_api_base"] = self.openai_api_base
if self.openai_proxy != "": if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy attributes["openai_proxy"] = self.openai_proxy
return attributes return attributes
@ -197,10 +191,12 @@ class ChatOpenAI(BaseChatModel):
# Check for classes that derive from this class (as some of them # Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str) # may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key") openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Base URL path for API requests, """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
leave blank if not using a proxy or service emulator."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url") openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization") openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI # to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None openai_proxy: Optional[str] = None
request_timeout: Union[float, Tuple[float, float], httpx.Timeout, None] = Field( request_timeout: Union[float, Tuple[float, float], httpx.Timeout, None] = Field(
@ -225,6 +221,11 @@ class ChatOpenAI(BaseChatModel):
when using one of the many model providers that expose an OpenAI-like when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here.""" when tiktoken is called, you can specify a model name to use here."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[httpx.Client, None] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -260,20 +261,22 @@ class ChatOpenAI(BaseChatModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["openai_api_key"] = get_from_dict_or_env( values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
values["openai_organization"] = get_from_dict_or_env( # Check OPENAI_ORGANIZATION for backwards compatibility.
values, values["openai_organization"] = (
"openai_organization", values["openai_organization"]
"OPENAI_ORGANIZATION", or os.getenv("OPENAI_ORG_ID")
default="", or os.getenv("OPENAI_ORGANIZATION")
) )
values["openai_api_base"] = get_from_dict_or_env( values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values, "OPENAI_API_BASE"
"openai_api_base",
"OPENAI_API_BASE",
default="",
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values,
@ -285,32 +288,28 @@ class ChatOpenAI(BaseChatModel):
import openai import openai
except ImportError: except ImportError:
raise ValueError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
"Please install it with `pip install openai`." "Please install it with `pip install openai`."
) )
if _is_openai_v1(): if is_openai_v1():
values["client"] = openai.OpenAI( client_params = {
api_key=values["openai_api_key"], "api_key": values["openai_api_key"],
timeout=values["request_timeout"], "organization": values["openai_organization"],
max_retries=values["max_retries"], "base_url": values["openai_api_base"],
organization=values["openai_organization"], "timeout": values["request_timeout"],
base_url=values["openai_api_base"] or None, "max_retries": values["max_retries"],
).chat.completions "default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.OpenAI(**client_params).chat.completions
values["async_client"] = openai.AsyncOpenAI( values["async_client"] = openai.AsyncOpenAI(
api_key=values["openai_api_key"], **client_params
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
base_url=values["openai_api_base"] or None,
).chat.completions ).chat.completions
else: else:
values["client"] = openai.ChatCompletion values["client"] = openai.ChatCompletion
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values return values
@property @property
@ -331,7 +330,7 @@ class ChatOpenAI(BaseChatModel):
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any: ) -> Any:
"""Use tenacity to retry the completion call.""" """Use tenacity to retry the completion call."""
if _is_openai_v1(): if is_openai_v1():
return self.client.create(**kwargs) return self.client.create(**kwargs)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager) retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@ -510,7 +509,7 @@ class ChatOpenAI(BaseChatModel):
openai_creds: Dict[str, Any] = { openai_creds: Dict[str, Any] = {
"model": self.model_name, "model": self.model_name,
} }
if not _is_openai_v1(): if not is_openai_v1():
openai_creds.update( openai_creds.update(
{ {
"api_key": self.openai_api_key, "api_key": self.openai_api_key,

View File

@ -19,6 +19,7 @@ from langchain.embeddings.aleph_alpha import (
AlephAlphaSymmetricSemanticEmbedding, AlephAlphaSymmetricSemanticEmbedding,
) )
from langchain.embeddings.awa import AwaEmbeddings from langchain.embeddings.awa import AwaEmbeddings
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.embeddings.bedrock import BedrockEmbeddings from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cache import CacheBackedEmbeddings from langchain.embeddings.cache import CacheBackedEmbeddings
@ -72,6 +73,7 @@ logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
"OpenAIEmbeddings", "OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CacheBackedEmbeddings", "CacheBackedEmbeddings",
"ClarifaiEmbeddings", "ClarifaiEmbeddings",
"CohereEmbeddings", "CohereEmbeddings",

View File

@ -0,0 +1,149 @@
"""Azure OpenAI embeddings wrapper."""
from __future__ import annotations
import os
import warnings
from typing import Dict, Optional, Union
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.pydantic_v1 import Field, root_validator
from langchain.utils import get_from_dict_or_env
from langchain.utils.openai import is_openai_v1
class AzureOpenAIEmbeddings(OpenAIEmbeddings):
    """`Azure OpenAI` Embeddings API."""

    azure_endpoint: Union[str, None] = None
    """Your Azure endpoint, including the resource.

        Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.

        Example: `https://example-resource.azure.openai.com/`
    """
    azure_deployment: Optional[str] = None
    """A model deployment.

        If given sets the base client URL to include `/deployments/{azure_deployment}`.
        Note: this means you won't be able to use non-deployment endpoints.
    """
    openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
    azure_ad_token: Union[str, None] = None
    """Your Azure Active Directory token.

        Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.

        For more:
        https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
    """  # noqa: E501
    # NOTE(review): despite the `str` annotation, this is documented (and passed to
    # the openai client) as a zero-arg callable returning a token. The annotation
    # is kept as-is for compatibility, but `Optional[Callable[[], str]]` looks
    # like the intended type -- confirm before changing.
    azure_ad_token_provider: Union[str, None] = None
    """A function that returns an Azure Active Directory token.

        Will be invoked on every request.
    """
    openai_api_version: Optional[str] = Field(default=None, alias="api_version")
    """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
    # When True (the default), a legacy `openai_api_base` value is rewritten below
    # into the openai>=1 `azure_endpoint`/`azure_deployment` URL shape.
    validate_base_url: bool = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        # Check OPENAI_API_KEY for backwards compatibility.
        # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
        # other forms of azure credentials.
        values["openai_api_key"] = (
            values["openai_api_key"]
            or os.getenv("AZURE_OPENAI_API_KEY")
            or os.getenv("OPENAI_API_KEY")
        )
        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
            "OPENAI_API_BASE"
        )
        values["openai_api_version"] = values["openai_api_version"] or os.getenv(
            "OPENAI_API_VERSION", default="2023-05-15"
        )
        values["openai_api_type"] = get_from_dict_or_env(
            values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
        )
        # Check OPENAI_ORGANIZATION for backwards compatibility.
        values["openai_organization"] = (
            values["openai_organization"]
            or os.getenv("OPENAI_ORG_ID")
            or os.getenv("OPENAI_ORGANIZATION")
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )
        values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
            "AZURE_OPENAI_ENDPOINT"
        )
        values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
            "AZURE_OPENAI_AD_TOKEN"
        )
        try:
            import openai
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if is_openai_v1():
            # For backwards compatibility. Before openai v1, no distinction was made
            # between azure_endpoint and base_url (openai_api_base).
            openai_api_base = values["openai_api_base"]
            if openai_api_base and values["validate_base_url"]:
                if "/openai" not in openai_api_base:
                    # rstrip matches the chat wrapper and avoids a double
                    # "//openai" segment when the base ends with "/".
                    values["openai_api_base"] = (
                        values["openai_api_base"].rstrip("/") + "/openai"
                    )
                    warnings.warn(
                        "As of openai>=1.0.0, Azure endpoints should be specified via "
                        f"the `azure_endpoint` param not `openai_api_base` "
                        f"(or alias `base_url`). Updating `openai_api_base` from "
                        f"{openai_api_base} to {values['openai_api_base']}."
                    )
                if values["azure_deployment"]:
                    warnings.warn(
                        "As of openai>=1.0.0, if `azure_deployment` is specified "
                        "then `openai_api_base` (or alias `base_url`) should not "
                        "be. Instead use `azure_deployment` and `azure_endpoint`."
                    )
                    if values["azure_deployment"] not in values["openai_api_base"]:
                        warnings.warn(
                            "As of openai>=1.0.0, if `openai_api_base` "
                            "(or alias `base_url`) is specified it is expected to be "
                            "of the form "
                            "https://example-resource.azure.openai.com/openai/deployments/example-deployment. "  # noqa: E501
                            f"Updating {openai_api_base} to "
                            f"{values['openai_api_base']}."
                        )
                        values["openai_api_base"] += (
                            "/deployments/" + values["azure_deployment"]
                        )
                    # The deployment is now baked into the URL; clear it so it
                    # is not also passed as `azure_deployment` to the client.
                    values["azure_deployment"] = None
            client_params = {
                "api_version": values["openai_api_version"],
                "azure_endpoint": values["azure_endpoint"],
                "azure_deployment": values["azure_deployment"],
                "api_key": values["openai_api_key"],
                "azure_ad_token": values["azure_ad_token"],
                "azure_ad_token_provider": values["azure_ad_token_provider"],
                "organization": values["openai_organization"],
                "base_url": values["openai_api_base"],
                "timeout": values["request_timeout"],
                "max_retries": values["max_retries"],
                "default_headers": values["default_headers"],
                "default_query": values["default_query"],
                "http_client": values["http_client"],
            }
            values["client"] = openai.AzureOpenAI(**client_params).embeddings
            values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings
        else:
            values["client"] = openai.Embedding
        return values

    @property
    def _llm_type(self) -> str:
        # NOTE(review): this identifier mirrors the chat wrapper's; for an
        # embeddings class something like "azure-openai-embeddings" looks more
        # accurate, but the value is kept as-is since external consumers may
        # key on it -- confirm before renaming.
        return "azure-openai-chat"

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import warnings import warnings
from importlib.metadata import version from importlib.metadata import version
from typing import ( from typing import (
@ -10,6 +11,7 @@ from typing import (
Dict, Dict,
List, List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -157,6 +159,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
.. code-block:: python .. code-block:: python
import os import os
os.environ["OPENAI_API_TYPE"] = "azure" os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/" os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key" os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
@ -178,23 +181,30 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client: Any = None #: :meta private: client: Any = None #: :meta private:
async_client: Any = None #: :meta private: async_client: Any = None #: :meta private:
model: str = "text-embedding-ada-002" model: str = "text-embedding-ada-002"
deployment: str = model # to support Azure OpenAI Service custom deployment names # to support Azure OpenAI Service custom deployment names
openai_api_version: Optional[str] = None deployment: str = model
# TODO: Move to AzureOpenAIEmbeddings.
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
# to support Azure OpenAI Service custom endpoints # to support Azure OpenAI Service custom endpoints
openai_api_base: Optional[str] = None openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
# to support Azure OpenAI Service custom endpoints # to support Azure OpenAI Service custom endpoints
openai_api_type: Optional[str] = None openai_api_type: Optional[str] = None
# to support explicit proxy for OpenAI # to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191 embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once.""" """The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_organization: Optional[str] = None """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
allowed_special: Union[Literal["all"], Set[str]] = set() allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
chunk_size: int = 1000 chunk_size: int = 1000
"""Maximum number of texts to embed in each batch""" """Maximum number of texts to embed in each batch"""
max_retries: int = 6 max_retries: int = 2
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float], httpx.Timeout]] = Field( request_timeout: Optional[Union[float, Tuple[float, float], httpx.Timeout]] = Field(
default=None, alias="timeout" default=None, alias="timeout"
@ -218,11 +228,17 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
skip_empty: bool = False skip_empty: bool = False
"""Whether to skip empty strings when embedding or raise an error. """Whether to skip empty strings when embedding or raise an error.
Defaults to not skipping.""" Defaults to not skipping."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[httpx.Client, None] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid extra = Extra.forbid
allow_population_by_field_name = True
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@ -250,17 +266,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator(pre=True) @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env( values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
values["openai_api_base"] = get_from_dict_or_env( values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values, "OPENAI_API_BASE"
"openai_api_base",
"OPENAI_API_BASE",
default="",
) )
values["openai_api_type"] = get_from_dict_or_env( values["openai_api_type"] = get_from_dict_or_env(
values, values,
@ -275,61 +288,61 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
default="", default="",
) )
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
default_api_version = "2022-12-01" default_api_version = "2023-05-15"
# Azure OpenAI embedding models allow a maximum of 16 texts # Azure OpenAI embedding models allow a maximum of 16 texts
# at a time in each batch # at a time in each batch
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
default_chunk_size = 16 values["chunk_size"] = max(values["chunk_size"], 16)
else: else:
default_api_version = "" default_api_version = ""
default_chunk_size = 1000
values["openai_api_version"] = get_from_dict_or_env( values["openai_api_version"] = get_from_dict_or_env(
values, values,
"openai_api_version", "openai_api_version",
"OPENAI_API_VERSION", "OPENAI_API_VERSION",
default=default_api_version, default=default_api_version,
) )
values["openai_organization"] = get_from_dict_or_env( # Check OPENAI_ORGANIZATION for backwards compatibility.
values, values["openai_organization"] = (
"openai_organization", values["openai_organization"]
"OPENAI_ORGANIZATION", or os.getenv("OPENAI_ORG_ID")
default="", or os.getenv("OPENAI_ORGANIZATION")
) )
if "chunk_size" not in values:
values["chunk_size"] = default_chunk_size
try: try:
import openai import openai
if _is_openai_v1():
values["client"] = openai.OpenAI(
api_key=values.get("openai_api_key"),
timeout=values.get("request_timeout"),
max_retries=values.get("max_retries"),
organization=values.get("openai_organization"),
base_url=values.get("openai_api_base") or None,
).embeddings
values["async_client"] = openai.AsyncOpenAI(
api_key=values.get("openai_api_key"),
timeout=values.get("request_timeout"),
max_retries=values.get("max_retries"),
organization=values.get("openai_organization"),
base_url=values.get("openai_api_base") or None,
).embeddings
else:
values["client"] = openai.Embedding
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
"Please install it with `pip install openai`." "Please install it with `pip install openai`."
) )
else:
if _is_openai_v1():
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
warnings.warn(
"If you have openai>=1.0.0 installed and are using Azure, "
"please use the `AzureOpenAIEmbeddings` class."
)
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.OpenAI(**client_params).embeddings
values["async_client"] = openai.AsyncOpenAI(**client_params).embeddings
else:
values["client"] = openai.Embedding
return values return values
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> Dict[str, Any]:
openai_args: Dict[str, Any] = ( if _is_openai_v1():
{"model": self.model, **self.model_kwargs} openai_args: Dict = {"model": self.model, **self.model_kwargs}
if _is_openai_v1() else:
else { openai_args = {
"model": self.model, "model": self.model,
"request_timeout": self.request_timeout, "request_timeout": self.request_timeout,
"headers": self.headers, "headers": self.headers,
@ -340,22 +353,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"api_version": self.openai_api_version, "api_version": self.openai_api_version,
**self.model_kwargs, **self.model_kwargs,
} }
) if self.openai_api_type in ("azure", "azure_ad", "azuread"):
if self.openai_api_type in ("azure", "azure_ad", "azuread"): openai_args["engine"] = self.deployment
openai_args["engine"] = self.deployment # TODO: Look into proxy with openai v1.
if self.openai_proxy: if self.openai_proxy:
try: try:
import openai import openai
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
"Please install it with `pip install openai`." "Please install it with `pip install openai`."
) )
openai.proxy = { openai.proxy = {
"http": self.openai_proxy, "http": self.openai_proxy,
"https": self.openai_proxy, "https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501 } # type: ignore[assignment] # noqa: E501
return openai_args return openai_args
# please refer to # please refer to

View File

@ -0,0 +1,10 @@
from __future__ import annotations
from importlib.metadata import version
from packaging.version import Version, parse
def is_openai_v1() -> bool:
    """Return ``True`` when the installed ``openai`` package is 1.0.0 or newer.

    Callers use this to branch between the legacy (0.x) module-level API and
    the rewritten 1.x client API, which are not compatible with each other.
    """
    installed = parse(version("openai"))
    return not installed < Version("1.0.0")

View File

@ -0,0 +1,93 @@
"""Test openai embeddings."""
import os
from typing import Any
import numpy as np
import pytest
from langchain.embeddings import AzureOpenAIEmbeddings
def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
    """Build an ``AzureOpenAIEmbeddings`` using the env-configured API version.

    Extra keyword arguments are forwarded verbatim to the constructor.
    """
    api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "")
    return AzureOpenAIEmbeddings(openai_api_version=api_version, **kwargs)
def test_azure_openai_embedding_documents() -> None:
    """Embedding one document yields one 1536-dimensional vector."""
    docs = ["foo bar"]
    vectors = _get_embeddings().embed_documents(docs)
    assert len(vectors) == 1
    assert len(vectors[0]) == 1536
def test_azure_openai_embedding_documents_multiple() -> None:
    """Embed several documents with a small chunk size to force batched requests."""
    docs = ["foo bar", "bar foo", "foo"]
    embedder = _get_embeddings(chunk_size=2)
    embedder.embedding_ctx_length = 8191
    vectors = embedder.embed_documents(docs)
    assert len(vectors) == 3
    # Every document should come back as a 1536-dimensional vector.
    for vec in vectors:
        assert len(vec) == 1536
@pytest.mark.asyncio
async def test_azure_openai_embedding_documents_async_multiple() -> None:
    """Async variant: embed several documents with batched (chunked) requests."""
    docs = ["foo bar", "bar foo", "foo"]
    embedder = _get_embeddings(chunk_size=2)
    embedder.embedding_ctx_length = 8191
    vectors = await embedder.aembed_documents(docs)
    assert len(vectors) == 3
    # Every document should come back as a 1536-dimensional vector.
    for vec in vectors:
        assert len(vec) == 1536
def test_azure_openai_embedding_query() -> None:
    """A single query string embeds to a 1536-dimensional vector."""
    vector = _get_embeddings().embed_query("foo bar")
    assert len(vector) == 1536
@pytest.mark.asyncio
async def test_azure_openai_embedding_async_query() -> None:
    """Async variant: a single query string embeds to a 1536-dimensional vector."""
    vector = await _get_embeddings().aembed_query("foo bar")
    assert len(vector) == 1536
@pytest.mark.skip(reason="Unblock scheduled testing. TODO: fix.")
def test_azure_openai_embedding_with_empty_string() -> None:
    """Test openai embeddings with empty string.

    Embeds an empty string alongside a normal one and compares the empty-string
    vector against a reference embedding fetched directly from the OpenAI API.
    """
    # NOTE(review): ``openai.Embedding.create`` is the legacy (openai<1.0)
    # module-level API, while the rest of this file exercises the v1-style
    # ``AzureOpenAIEmbeddings`` path — presumably why this test is skipped;
    # confirm and port the reference call before un-skipping.
    import openai

    document = ["", "abc"]
    embedding = _get_embeddings()
    output = embedding.embed_documents(document)
    assert len(output) == 2
    assert len(output[0]) == 1536
    # Reference vector for the empty string, fetched via the raw OpenAI API.
    expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
        "data"
    ][0]["embedding"]
    # allclose rather than equality: floats from separate API calls may differ slightly.
    assert np.allclose(output[0], expected_output)
    assert len(output[1]) == 1536
def test_embed_documents_normalized() -> None:
    """Document embeddings come back unit-normalized (L2 norm of 1)."""
    vectors = _get_embeddings().embed_documents(["foo walked to the market"])
    assert np.isclose(np.linalg.norm(vectors[0]), 1.0)
def test_embed_query_normalized() -> None:
    """Query embeddings come back unit-normalized (L2 norm of 1)."""
    vector = _get_embeddings().embed_query("foo walked to the market")
    assert np.isclose(np.linalg.norm(vector), 1.0)

View File

@ -2,6 +2,7 @@ from langchain.embeddings import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"OpenAIEmbeddings", "OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CacheBackedEmbeddings", "CacheBackedEmbeddings",
"ClarifaiEmbeddings", "ClarifaiEmbeddings",
"CohereEmbeddings", "CohereEmbeddings",