mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
langchain[patch]: embedddings distance move import of openai embeddings into local scope (#21148)
This commit is contained in:
parent
8b4b75e543
commit
b879184595
@ -3,7 +3,6 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
@ -18,6 +17,27 @@ from langchain.schema import RUN_KEY
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
def _embedding_factory() -> Embeddings:
|
||||
"""Create an Embeddings object.
|
||||
Returns:
|
||||
Embeddings: The created Embeddings object.
|
||||
"""
|
||||
# Here for backwards compatibility.
|
||||
# Generally, we do not want to be seeing imports from langchain community
|
||||
# or partner packages in langchain.
|
||||
try:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
except ImportError:
|
||||
try:
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import OpenAIEmbeddings. Please install the "
|
||||
"OpenAIEmbeddings package using `pip install langchain-openai`."
|
||||
)
|
||||
return OpenAIEmbeddings()
|
||||
|
||||
|
||||
class EmbeddingDistance(str, Enum):
|
||||
"""Embedding Distance Metric.
|
||||
|
||||
@ -45,7 +65,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
for comparing the embeddings.
|
||||
"""
|
||||
|
||||
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
||||
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
||||
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
||||
|
||||
@root_validator(pre=False)
|
||||
@ -59,7 +79,28 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
Dict[str, Any]: The validated values.
|
||||
"""
|
||||
embeddings = values.get("embeddings")
|
||||
if isinstance(embeddings, OpenAIEmbeddings):
|
||||
types_ = []
|
||||
try:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
types_.append(OpenAIEmbeddings)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
types_.append(OpenAIEmbeddings)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not types_:
|
||||
raise ImportError(
|
||||
"Could not import OpenAIEmbeddings. Please install the "
|
||||
"OpenAIEmbeddings package using `pip install langchain-openai`."
|
||||
)
|
||||
|
||||
if isinstance(embeddings, tuple(types_)):
|
||||
try:
|
||||
import tiktoken # noqa: F401
|
||||
except ImportError:
|
||||
|
Loading…
Reference in New Issue
Block a user