diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 74765c6592a..6326f0f63a8 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -92,7 +92,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
 
             from langchain.embeddings.openai import OpenAIEmbeddings
-            embeddings = OpenAIEmbeddings(model="your-embeddings-deployment-name")
+            embeddings = OpenAIEmbeddings(
+                deployment="your-embeddings-deployment-name",
+                model="your-embeddings-model-name"
+            )
             text = "This is a test query."
             query_result = embeddings.embed_query(text)
 
@@ -100,12 +103,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
 
     client: Any  #: :meta private:
     model: str = "text-embedding-ada-002"
-
-    # TODO: deprecate these two in favor of model
-    # https://community.openai.com/t/api-update-engines-models/18597
-    # https://github.com/openai/openai-python/issues/132
-    document_model_name: str = "text-embedding-ada-002"
-    query_model_name: str = "text-embedding-ada-002"
+    deployment: str = model  # to support Azure OpenAI Service custom deployment names
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
@@ -121,51 +119,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
 
         extra = Extra.forbid
 
-    # TODO: deprecate this
-    @root_validator(pre=True)
-    def get_model_names(cls, values: Dict) -> Dict:
-        # model_name is for first generation, and model is for second generation.
-        # Both are not allowed together.
-        if "model_name" in values and "model" in values:
-            raise ValueError(
-                "Both `model_name` and `model` were provided, "
-                "but only one should be."
-            )
-
-        """Get model names from just old model name."""
-        if "model_name" in values:
-            if "document_model_name" in values:
-                raise ValueError(
-                    "Both `model_name` and `document_model_name` were provided, "
-                    "but only one should be."
-                )
-            if "query_model_name" in values:
-                raise ValueError(
-                    "Both `model_name` and `query_model_name` were provided, "
-                    "but only one should be."
-                )
-            model_name = values.pop("model_name")
-            values["document_model_name"] = f"text-search-{model_name}-doc-001"
-            values["query_model_name"] = f"text-search-{model_name}-query-001"
-
-        # Set document/query model names from model parameter.
-        if "model" in values:
-            if "document_model_name" in values:
-                raise ValueError(
-                    "Both `model` and `document_model_name` were provided, "
-                    "but only one should be."
-                )
-            if "query_model_name" in values:
-                raise ValueError(
-                    "Both `model` and `query_model_name` were provided, "
-                    "but only one should be."
-                )
-            model = values.get("model")
-            values["document_model_name"] = model
-            values["query_model_name"] = model
-
-        return values
-
+    @root_validator(pre=True)
+    def default_deployment_to_model(cls, values: Dict) -> Dict:
+        """Use `model` as the deployment name unless one is given explicitly."""
+        # The field default `deployment: str = model` is evaluated once, in the
+        # class body, so on its own a custom `model` would still be sent to the
+        # default "text-embedding-ada-002" engine.
+        if "model" in values and "deployment" not in values:
+            values["deployment"] = values["model"]
+        return values
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -203,7 +166,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
 
             tokens = []
             indices = []
-            encoding = tiktoken.model.encoding_for_model(self.document_model_name)
+            encoding = tiktoken.model.encoding_for_model(self.model)
             for i, text in enumerate(texts):
                 # replace newlines, which can negatively affect performance.
                 text = text.replace("\n", " ")
@@ -222,7 +185,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 response = embed_with_retry(
                     self,
                     input=tokens[i : i + _chunk_size],
-                    engine=self.document_model_name,
+                    engine=self.deployment,
                 )
                 batched_embeddings += [r["embedding"] for r in response["data"]]
 
@@ -272,7 +235,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """
         # handle batches of large input text
         if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
+            return self._get_len_safe_embeddings(texts, engine=self.deployment)
         else:
             results = []
             _chunk_size = chunk_size or self.chunk_size
@@ -280,7 +243,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 response = embed_with_retry(
                     self,
                     input=texts[i : i + _chunk_size],
-                    engine=self.document_model_name,
+                    engine=self.deployment,
                 )
                 results += [r["embedding"] for r in response["data"]]
             return results
@@ -294,5 +257,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             Embedding for the text.
         """
-        embedding = self._embedding_func(text, engine=self.query_model_name)
+        embedding = self._embedding_func(text, engine=self.deployment)
         return embedding