diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index 7fc0b0c66aa..9db1df4f1e4 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -3,7 +3,6 @@ from enum import Enum from typing import Any, Dict, List, Optional import numpy as np -from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_core.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -18,6 +17,27 @@ from langchain.schema import RUN_KEY from langchain.utils.math import cosine_similarity +def _embedding_factory() -> Embeddings: + """Create an Embeddings object. + Returns: + Embeddings: The created Embeddings object. + """ + # Here for backwards compatibility. + # Generally, we do not want to be seeing imports from langchain community + # or partner packages in langchain. + try: + from langchain_openai import OpenAIEmbeddings + except ImportError: + try: + from langchain_community.embeddings.openai import OpenAIEmbeddings + except ImportError: + raise ImportError( + "Could not import OpenAIEmbeddings. Please install the " + "OpenAIEmbeddings package using `pip install langchain-openai`." + ) + return OpenAIEmbeddings() + + class EmbeddingDistance(str, Enum): """Embedding Distance Metric. @@ -45,7 +65,7 @@ class _EmbeddingDistanceChainMixin(Chain): for comparing the embeddings. """ - embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings) + embeddings: Embeddings = Field(default_factory=_embedding_factory) distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE) @root_validator(pre=False) @@ -59,7 +79,28 @@ class _EmbeddingDistanceChainMixin(Chain): Dict[str, Any]: The validated values. """ embeddings = values.get("embeddings") - if isinstance(embeddings, OpenAIEmbeddings): + types_ = [] + try: + from langchain_openai import OpenAIEmbeddings + + types_.append(OpenAIEmbeddings) + except ImportError: + pass + + try: + from langchain_community.embeddings.openai import OpenAIEmbeddings + + types_.append(OpenAIEmbeddings) + except ImportError: + pass + + if not types_: + raise ImportError( + "Could not import OpenAIEmbeddings. Please install the " + "OpenAIEmbeddings package using `pip install langchain-openai`." + ) + + if isinstance(embeddings, tuple(types_)): try: import tiktoken # noqa: F401 except ImportError: