diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 864e7758f37..205d028c49d 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -22,9 +22,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ client: Any #: :meta private: - model_name: str = "babbage" - """Model name to use.""" - + document_model_name: str = "text-embedding-ada-002" + query_model_name: str = "text-embedding-ada-002" openai_api_key: Optional[str] = None class Config: @@ -32,6 +31,26 @@ class OpenAIEmbeddings(BaseModel, Embeddings): extra = Extra.forbid + # TODO: deprecate this + @root_validator(pre=True) + def get_model_names(cls, values: Dict) -> Dict: + """Get model names from just old model name.""" + if "model_name" in values: + if "document_model_name" in values: + raise ValueError( + "Both `model_name` and `document_model_name` were provided, " + "but only one should be." + ) + if "query_model_name" in values: + raise ValueError( + "Both `model_name` and `query_model_name` were provided, " + "but only one should be." + ) + model_name = values.pop("model_name") + values["document_model_name"] = f"text-search-{model_name}-doc-001" + values["query_model_name"] = f"text-search-{model_name}-query-001" + return values + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -66,7 +85,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ responses = [ - self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001") + self._embedding_func(text, engine=self.document_model_name) for text in texts ] return responses @@ -80,7 +99,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - embedding = self._embedding_func( - text, engine=f"text-search-{self.model_name}-query-001" - ) + embedding = self._embedding_func(text, engine=self.query_model_name) return embedding