Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-06 05:08:20 +00:00)
fix: separate model and deployment for OpenAIEmbeddings (#3076)
Separated the `deployment` setting from `model` to properly support Azure OpenAI embeddings, where the two names differ. Also removed the deprecated `document_model_name` and `query_model_name` attributes.
parent 4adfd790f0
commit 6e48107734
@@ -92,7 +92,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"

             from langchain.embeddings.openai import OpenAIEmbeddings
-            embeddings = OpenAIEmbeddings(model="your-embeddings-deployment-name")
+            embeddings = OpenAIEmbeddings(
+                deployment="your-embeddings-deployment-name",
+                model="your-embeddings-model-name"
+            )
             text = "This is a test query."
             query_result = embeddings.embed_query(text)

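The updated docstring example separates the two names: `deployment` is the custom name given to the deployment in the Azure portal, while `model` is the underlying OpenAI model it serves. A minimal end-to-end sketch, assuming the legacy 0.x `openai` SDK that langchain wrapped at the time; the resource URL, API version, and deployment name below are placeholders:

    import os

    import openai

    from langchain.embeddings.openai import OpenAIEmbeddings

    # Azure OpenAI routing for the legacy 0.x openai SDK (all values are placeholders).
    openai.api_type = "azure"
    openai.api_base = "https://your-resource-name.openai.azure.com/"
    openai.api_version = "2023-03-15-preview"
    os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"

    embeddings = OpenAIEmbeddings(
        deployment="your-embeddings-deployment-name",  # Azure portal deployment name
        model="text-embedding-ada-002",                # underlying OpenAI model
    )
    query_result = embeddings.embed_query("This is a test query.")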
@@ -100,12 +103,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):

     client: Any  #: :meta private:
     model: str = "text-embedding-ada-002"
-
-    # TODO: deprecate these two in favor of model
-    # https://community.openai.com/t/api-update-engines-models/18597
-    # https://github.com/openai/openai-python/issues/132
-    document_model_name: str = "text-embedding-ada-002"
-    query_model_name: str = "text-embedding-ada-002"
+    deployment: str = model  # to support Azure OpenAI Service custom deployment names
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
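The `deployment: str = model` default works because of ordinary class-body scoping: when the `deployment` line is evaluated, `model` is already bound in the class namespace, so `deployment` inherits the same string and only diverges when a caller passes an explicit deployment name. A stripped-down sketch of the pattern (class name illustrative, pydantic v1 style as used here):

    from pydantic import BaseModel


    class EmbeddingsConfig(BaseModel):
        model: str = "text-embedding-ada-002"
        # Evaluated at class-definition time: `model` is the plain string
        # bound just above, so `deployment` defaults to the same value.
        deployment: str = model


    assert EmbeddingsConfig().deployment == "text-embedding-ada-002"

    # An explicit Azure deployment name overrides the default; `model` is untouched.
    cfg = EmbeddingsConfig(deployment="my-ada-deployment")
    assert cfg.model == "text-embedding-ada-002"
    assert cfg.deployment == "my-ada-deployment"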
@@ -121,51 +119,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):

         extra = Extra.forbid

-    # TODO: deprecate this
-    @root_validator(pre=True)
-    def get_model_names(cls, values: Dict) -> Dict:
-        # model_name is for first generation, and model is for second generation.
-        # Both are not allowed together.
-        if "model_name" in values and "model" in values:
-            raise ValueError(
-                "Both `model_name` and `model` were provided, "
-                "but only one should be."
-            )
-
-        """Get model names from just old model name."""
-        if "model_name" in values:
-            if "document_model_name" in values:
-                raise ValueError(
-                    "Both `model_name` and `document_model_name` were provided, "
-                    "but only one should be."
-                )
-            if "query_model_name" in values:
-                raise ValueError(
-                    "Both `model_name` and `query_model_name` were provided, "
-                    "but only one should be."
-                )
-            model_name = values.pop("model_name")
-            values["document_model_name"] = f"text-search-{model_name}-doc-001"
-            values["query_model_name"] = f"text-search-{model_name}-query-001"
-
-        # Set document/query model names from model parameter.
-        if "model" in values:
-            if "document_model_name" in values:
-                raise ValueError(
-                    "Both `model` and `document_model_name` were provided, "
-                    "but only one should be."
-                )
-            if "query_model_name" in values:
-                raise ValueError(
-                    "Both `model` and `query_model_name` were provided, "
-                    "but only one should be."
-                )
-            model = values.get("model")
-            values["document_model_name"] = model
-            values["query_model_name"] = model
-
-        return values
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
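The deleted `get_model_names` validator existed only to upgrade the first-generation `model_name`/`document_model_name`/`query_model_name` fields to the second-generation `model`, so it becomes dead weight once those fields are gone. Its mutual-exclusion checks follow a common pydantic v1 pattern; a minimal sketch (field and class names illustrative):

    from typing import Dict, Optional

    from pydantic import BaseModel, root_validator


    class ModelSpec(BaseModel):
        model: Optional[str] = None

        @root_validator(pre=True)
        def upgrade_legacy_name(cls, values: Dict) -> Dict:
            # pre=True runs before field parsing, so the raw keyword
            # arguments can be inspected for conflicting spellings.
            if "model_name" in values and "model" in values:
                raise ValueError(
                    "Both `model_name` and `model` were provided, "
                    "but only one should be."
                )
            if "model_name" in values:
                # Translate the legacy key to the current one.
                values["model"] = values.pop("model_name")
            return values


    assert ModelSpec(model_name="ada").model == "ada"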
@@ -203,7 +156,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):

         tokens = []
         indices = []
-        encoding = tiktoken.model.encoding_for_model(self.document_model_name)
+        encoding = tiktoken.model.encoding_for_model(self.model)
         for i, text in enumerate(texts):
             # replace newlines, which can negatively affect performance.
             text = text.replace("\n", " ")
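Pointing `encoding_for_model` at `self.model` instead of `self.document_model_name` matters for Azure in particular: tiktoken looks tokenizers up by OpenAI model name, and a custom deployment name would not resolve. A quick sketch of the lookup (requires the `tiktoken` package; shown via the public `tiktoken.encoding_for_model` helper):

    import tiktoken

    # Tokenizers are keyed by model name, never by deployment name.
    encoding = tiktoken.encoding_for_model("text-embedding-ada-002")

    # Mirror the diff: strip newlines, then count tokens against the
    # 8191-token embedding_ctx_length budget.
    text = "This is a test query.".replace("\n", " ")
    tokens = encoding.encode(text)
    print(len(tokens))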
@@ -222,7 +175,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             response = embed_with_retry(
                 self,
                 input=tokens[i : i + _chunk_size],
-                engine=self.document_model_name,
+                engine=self.deployment,
             )
             batched_embeddings += [r["embedding"] for r in response["data"]]

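Note that only the `engine` argument changes here; the slice-based batching is untouched. The `tokens[i : i + _chunk_size]` idiom clips safely at the end of the list, so no padding or bounds checks are needed; a self-contained sketch (`embed_batch` is a stand-in for the real `embed_with_retry` call):

    from typing import Callable, List


    def embed_in_batches(
        items: List[str],
        chunk_size: int,
        embed_batch: Callable[[List[str]], List[List[float]]],
    ) -> List[List[float]]:
        # Embed `items` one slice at a time, mirroring the loop in the diff.
        results: List[List[float]] = []
        for i in range(0, len(items), chunk_size):
            # Python slicing clips at the list end, so the last batch may be short.
            results += embed_batch(items[i : i + chunk_size])
        return results


    # Toy stand-in: each "embedding" is just the text length.
    out = embed_in_batches(["a", "bb", "ccc"], 2, lambda b: [[float(len(t))] for t in b])
    assert out == [[1.0], [2.0], [3.0]]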
@@ -272,7 +225,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """
         # handle batches of large input text
         if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
+            return self._get_len_safe_embeddings(texts, engine=self.deployment)
         else:
             results = []
             _chunk_size = chunk_size or self.chunk_size
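The `embedding_ctx_length > 0` guard selects between the length-safe path, which windows token sequences so every request fits the model's context, and a plain per-text path. A minimal sketch of that windowing under the 8191-token default (helper name is illustrative):

    from typing import List


    def split_by_ctx_length(tokens: List[int], ctx_length: int = 8191) -> List[List[int]]:
        # Window a token sequence so every piece fits the embedding context.
        return [tokens[j : j + ctx_length] for j in range(0, len(tokens), ctx_length)]


    chunks = split_by_ctx_length(list(range(20000)))
    assert all(len(c) <= 8191 for c in chunks)
    assert sum(len(c) for c in chunks) == 20000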
@@ -280,7 +233,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 response = embed_with_retry(
                     self,
                     input=texts[i : i + _chunk_size],
-                    engine=self.document_model_name,
+                    engine=self.deployment,
                 )
                 results += [r["embedding"] for r in response["data"]]
             return results
@@ -294,5 +247,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             Embedding for the text.
         """
-        embedding = self._embedding_func(text, engine=self.query_model_name)
+        embedding = self._embedding_func(text, engine=self.deployment)
         return embedding