mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
langchain[patch]: embedddings distance move import of openai embeddings into local scope (#21148)
This commit is contained in:
parent
8b4b75e543
commit
b879184595
@ -3,7 +3,6 @@ from enum import Enum
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
@ -18,6 +17,27 @@ from langchain.schema import RUN_KEY
|
|||||||
from langchain.utils.math import cosine_similarity
|
from langchain.utils.math import cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
|
def _embedding_factory() -> Embeddings:
|
||||||
|
"""Create an Embeddings object.
|
||||||
|
Returns:
|
||||||
|
Embeddings: The created Embeddings object.
|
||||||
|
"""
|
||||||
|
# Here for backwards compatibility.
|
||||||
|
# Generally, we do not want to be seeing imports from langchain community
|
||||||
|
# or partner packages in langchain.
|
||||||
|
try:
|
||||||
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import OpenAIEmbeddings. Please install the "
|
||||||
|
"OpenAIEmbeddings package using `pip install langchain-openai`."
|
||||||
|
)
|
||||||
|
return OpenAIEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingDistance(str, Enum):
|
class EmbeddingDistance(str, Enum):
|
||||||
"""Embedding Distance Metric.
|
"""Embedding Distance Metric.
|
||||||
|
|
||||||
@ -45,7 +65,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|||||||
for comparing the embeddings.
|
for comparing the embeddings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
||||||
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
||||||
|
|
||||||
@root_validator(pre=False)
|
@root_validator(pre=False)
|
||||||
@ -59,7 +79,28 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|||||||
Dict[str, Any]: The validated values.
|
Dict[str, Any]: The validated values.
|
||||||
"""
|
"""
|
||||||
embeddings = values.get("embeddings")
|
embeddings = values.get("embeddings")
|
||||||
if isinstance(embeddings, OpenAIEmbeddings):
|
types_ = []
|
||||||
|
try:
|
||||||
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
types_.append(OpenAIEmbeddings)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
types_.append(OpenAIEmbeddings)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not types_:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import OpenAIEmbeddings. Please install the "
|
||||||
|
"OpenAIEmbeddings package using `pip install langchain-openai`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(embeddings, tuple(types_)):
|
||||||
try:
|
try:
|
||||||
import tiktoken # noqa: F401
|
import tiktoken # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
Loading…
Reference in New Issue
Block a user