mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
update azure embedding docs (#13091)
This commit is contained in:
@@ -21,7 +21,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
|
||||
Example: `https://example-resource.azure.openai.com/`
|
||||
"""
|
||||
azure_deployment: Optional[str] = None
|
||||
deployment: Optional[str] = Field(default=None, alias="azure_deployment")
|
||||
"""A model deployment.
|
||||
|
||||
If given sets the base client URL to include `/deployments/{azure_deployment}`.
|
||||
@@ -104,15 +104,15 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
f"(or alias `base_url`). Updating `openai_api_base` from "
|
||||
f"{openai_api_base} to {values['openai_api_base']}."
|
||||
)
|
||||
if values["azure_deployment"]:
|
||||
if values["deployment"]:
|
||||
warnings.warn(
|
||||
"As of openai>=1.0.0, if `azure_deployment` (or alias "
|
||||
"As of openai>=1.0.0, if `deployment` (or alias "
|
||||
"`azure_deployment`) is specified then "
|
||||
"`openai_api_base` (or alias `base_url`) should not be. "
|
||||
"Instead use `azure_deployment` (or alias `azure_deployment`) "
|
||||
"Instead use `deployment` (or alias `azure_deployment`) "
|
||||
"and `azure_endpoint`."
|
||||
)
|
||||
if values["azure_deployment"] not in values["openai_api_base"]:
|
||||
if values["deployment"] not in values["openai_api_base"]:
|
||||
warnings.warn(
|
||||
"As of openai>=1.0.0, if `openai_api_base` "
|
||||
"(or alias `base_url`) is specified it is expected to be "
|
||||
@@ -122,13 +122,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
f"{values['openai_api_base']}."
|
||||
)
|
||||
values["openai_api_base"] += (
|
||||
"/deployments/" + values["azure_deployment"]
|
||||
"/deployments/" + values["deployment"]
|
||||
)
|
||||
values["azure_deployment"] = None
|
||||
values["deployment"] = None
|
||||
client_params = {
|
||||
"api_version": values["openai_api_version"],
|
||||
"azure_endpoint": values["azure_endpoint"],
|
||||
"azure_deployment": values["azure_deployment"],
|
||||
"azure_deployment": values["deployment"],
|
||||
"api_key": values["openai_api_key"],
|
||||
"azure_ad_token": values["azure_ad_token"],
|
||||
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
||||
|
@@ -17,6 +17,7 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
@@ -182,7 +183,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
async_client: Any = None #: :meta private:
|
||||
model: str = "text-embedding-ada-002"
|
||||
# to support Azure OpenAI Service custom deployment names
|
||||
deployment: str = model
|
||||
deployment: Optional[str] = model
|
||||
# TODO: Move to AzureOpenAIEmbeddings.
|
||||
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
||||
@@ -546,7 +547,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
||||
engine = cast(str, self.deployment)
|
||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
@@ -563,7 +565,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
|
||||
engine = cast(str, self.deployment)
|
||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
Reference in New Issue
Block a user