update azure embedding docs (#13091)

This commit is contained in:
Bagatur
2023-11-08 13:39:31 -08:00
committed by GitHub
parent 9fdfac22c2
commit 1703f132c6
3 changed files with 105 additions and 14 deletions

View File

@@ -21,7 +21,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
Example: `https://example-resource.azure.openai.com/`
"""
azure_deployment: Optional[str] = None
deployment: Optional[str] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
@@ -104,15 +104,15 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["azure_deployment"]:
if values["deployment"]:
warnings.warn(
"As of openai>=1.0.0, if `azure_deployment` (or alias "
"As of openai>=1.0.0, if `deployment` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `azure_deployment` (or alias `azure_deployment`) "
"Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["azure_deployment"] not in values["openai_api_base"]:
if values["deployment"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
@@ -122,13 +122,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["azure_deployment"]
"/deployments/" + values["deployment"]
)
values["azure_deployment"] = None
values["deployment"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["azure_deployment"],
"azure_deployment": values["deployment"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],

View File

@@ -17,6 +17,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import numpy as np
@@ -182,7 +183,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
async_client: Any = None #: :meta private:
model: str = "text-embedding-ada-002"
# to support Azure OpenAI Service custom deployment names
deployment: str = model
deployment: Optional[str] = model
# TODO: Move to AzureOpenAIEmbeddings.
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
@@ -546,7 +547,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment)
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
@@ -563,7 +565,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.