mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
fix: separate model and deployment for OpenAIEmbeddings (#3076)
Separated the deployment from model to support Azure OpenAI Embeddings properly. Also removed the deprecated document_model_name and query_model_name attributes.
This commit is contained in:
parent
4adfd790f0
commit
6e48107734
@ -92,7 +92,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
|
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
|
||||||
|
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
embeddings = OpenAIEmbeddings(model="your-embeddings-deployment-name")
|
embeddings = OpenAIEmbeddings(
|
||||||
|
deployment="your-embeddings-deployment-name",
|
||||||
|
model="your-embeddings-model-name"
|
||||||
|
)
|
||||||
text = "This is a test query."
|
text = "This is a test query."
|
||||||
query_result = embeddings.embed_query(text)
|
query_result = embeddings.embed_query(text)
|
||||||
|
|
||||||
@ -100,12 +103,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
model: str = "text-embedding-ada-002"
|
model: str = "text-embedding-ada-002"
|
||||||
|
deployment: str = model # to support Azure OpenAI Service custom deployment names
|
||||||
# TODO: deprecate these two in favor of model
|
|
||||||
# https://community.openai.com/t/api-update-engines-models/18597
|
|
||||||
# https://github.com/openai/openai-python/issues/132
|
|
||||||
document_model_name: str = "text-embedding-ada-002"
|
|
||||||
query_model_name: str = "text-embedding-ada-002"
|
|
||||||
embedding_ctx_length: int = 8191
|
embedding_ctx_length: int = 8191
|
||||||
openai_api_key: Optional[str] = None
|
openai_api_key: Optional[str] = None
|
||||||
openai_organization: Optional[str] = None
|
openai_organization: Optional[str] = None
|
||||||
@ -121,51 +119,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
# TODO: deprecate this
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def get_model_names(cls, values: Dict) -> Dict:
|
|
||||||
# model_name is for first generation, and model is for second generation.
|
|
||||||
# Both are not allowed together.
|
|
||||||
if "model_name" in values and "model" in values:
|
|
||||||
raise ValueError(
|
|
||||||
"Both `model_name` and `model` were provided, "
|
|
||||||
"but only one should be."
|
|
||||||
)
|
|
||||||
|
|
||||||
"""Get model names from just old model name."""
|
|
||||||
if "model_name" in values:
|
|
||||||
if "document_model_name" in values:
|
|
||||||
raise ValueError(
|
|
||||||
"Both `model_name` and `document_model_name` were provided, "
|
|
||||||
"but only one should be."
|
|
||||||
)
|
|
||||||
if "query_model_name" in values:
|
|
||||||
raise ValueError(
|
|
||||||
"Both `model_name` and `query_model_name` were provided, "
|
|
||||||
"but only one should be."
|
|
||||||
)
|
|
||||||
model_name = values.pop("model_name")
|
|
||||||
values["document_model_name"] = f"text-search-{model_name}-doc-001"
|
|
||||||
values["query_model_name"] = f"text-search-{model_name}-query-001"
|
|
||||||
|
|
||||||
# Set document/query model names from model parameter.
|
|
||||||
if "model" in values:
|
|
||||||
if "document_model_name" in values:
|
|
||||||
raise ValueError(
|
|
||||||
"Both `model` and `document_model_name` were provided, "
|
|
||||||
"but only one should be."
|
|
||||||
)
|
|
||||||
if "query_model_name" in values:
|
|
||||||
raise ValueError(
|
|
||||||
"Both `model` and `query_model_name` were provided, "
|
|
||||||
"but only one should be."
|
|
||||||
)
|
|
||||||
model = values.get("model")
|
|
||||||
values["document_model_name"] = model
|
|
||||||
values["query_model_name"] = model
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
@ -203,7 +156,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
indices = []
|
indices = []
|
||||||
encoding = tiktoken.model.encoding_for_model(self.document_model_name)
|
encoding = tiktoken.model.encoding_for_model(self.model)
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
# replace newlines, which can negatively affect performance.
|
# replace newlines, which can negatively affect performance.
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
@ -222,7 +175,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
response = embed_with_retry(
|
response = embed_with_retry(
|
||||||
self,
|
self,
|
||||||
input=tokens[i : i + _chunk_size],
|
input=tokens[i : i + _chunk_size],
|
||||||
engine=self.document_model_name,
|
engine=self.deployment,
|
||||||
)
|
)
|
||||||
batched_embeddings += [r["embedding"] for r in response["data"]]
|
batched_embeddings += [r["embedding"] for r in response["data"]]
|
||||||
|
|
||||||
@ -272,7 +225,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
# handle batches of large input text
|
# handle batches of large input text
|
||||||
if self.embedding_ctx_length > 0:
|
if self.embedding_ctx_length > 0:
|
||||||
return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
|
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
||||||
else:
|
else:
|
||||||
results = []
|
results = []
|
||||||
_chunk_size = chunk_size or self.chunk_size
|
_chunk_size = chunk_size or self.chunk_size
|
||||||
@ -280,7 +233,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
response = embed_with_retry(
|
response = embed_with_retry(
|
||||||
self,
|
self,
|
||||||
input=texts[i : i + _chunk_size],
|
input=texts[i : i + _chunk_size],
|
||||||
engine=self.document_model_name,
|
engine=self.deployment,
|
||||||
)
|
)
|
||||||
results += [r["embedding"] for r in response["data"]]
|
results += [r["embedding"] for r in response["data"]]
|
||||||
return results
|
return results
|
||||||
@ -294,5 +247,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embedding for the text.
|
Embedding for the text.
|
||||||
"""
|
"""
|
||||||
embedding = self._embedding_func(text, engine=self.query_model_name)
|
embedding = self._embedding_func(text, engine=self.deployment)
|
||||||
return embedding
|
return embedding
|
||||||
|
Loading…
Reference in New Issue
Block a user