From 2d3137ce20cc32f342af106ee3bf6154ea87c6b4 Mon Sep 17 00:00:00 2001 From: Dev 2049 Date: Mon, 22 May 2023 15:35:53 -0700 Subject: [PATCH] rename --- langchain/cache.py | 4 ++-- langchain/chains/hyde/base.py | 8 ++++---- langchain/chains/router/embedding_router.py | 4 ++-- langchain/document_transformers.py | 6 +++--- langchain/embeddings/aleph_alpha.py | 4 ++-- langchain/embeddings/base.py | 6 +++++- langchain/embeddings/cohere.py | 4 ++-- langchain/embeddings/fake.py | 4 ++-- langchain/embeddings/google_palm.py | 4 ++-- langchain/embeddings/huggingface.py | 6 +++--- langchain/embeddings/huggingface_hub.py | 4 ++-- langchain/embeddings/jina.py | 4 ++-- langchain/embeddings/llamacpp.py | 4 ++-- langchain/embeddings/openai.py | 4 ++-- langchain/embeddings/sagemaker_endpoint.py | 4 ++-- langchain/embeddings/self_hosted.py | 4 ++-- langchain/embeddings/tensorflow_hub.py | 4 ++-- langchain/indexes/vectorstore.py | 4 ++-- .../prompts/example_selector/semantic_similarity.py | 6 +++--- .../document_compressors/embeddings_filter.py | 4 ++-- langchain/retrievers/knn.py | 8 ++++---- langchain/retrievers/milvus.py | 4 ++-- langchain/retrievers/pinecone_hybrid_search.py | 6 +++--- langchain/retrievers/svm.py | 8 ++++---- langchain/retrievers/zilliz.py | 4 ++-- langchain/vectorstores/analyticdb.py | 8 ++++---- langchain/vectorstores/annoy.py | 10 +++++----- langchain/vectorstores/atlas.py | 12 ++++++------ langchain/vectorstores/base.py | 10 +++++----- langchain/vectorstores/chroma.py | 12 ++++++------ langchain/vectorstores/deeplake.py | 8 ++++---- langchain/vectorstores/docarray/base.py | 4 ++-- langchain/vectorstores/docarray/hnsw.py | 10 +++++----- langchain/vectorstores/docarray/in_memory.py | 10 +++++----- langchain/vectorstores/elastic_vector_search.py | 8 ++++---- langchain/vectorstores/faiss.py | 10 +++++----- langchain/vectorstores/lancedb.py | 6 +++--- langchain/vectorstores/milvus.py | 10 +++++----- langchain/vectorstores/myscale.py | 8 ++++---- 
langchain/vectorstores/opensearch_vector_search.py | 6 +++--- langchain/vectorstores/pgvector.py | 8 ++++---- langchain/vectorstores/pinecone.py | 6 +++--- langchain/vectorstores/qdrant.py | 8 ++++---- langchain/vectorstores/redis.py | 8 ++++---- langchain/vectorstores/supabase.py | 10 +++++----- langchain/vectorstores/tair.py | 10 +++++----- langchain/vectorstores/weaviate.py | 6 +++--- langchain/vectorstores/zilliz.py | 6 +++--- .../vectorstores/fake_embeddings.py | 4 ++-- tests/integration_tests/vectorstores/test_qdrant.py | 6 +++--- tests/unit_tests/chains/test_hyde.py | 4 ++-- .../retrievers/test_time_weighted_retriever.py | 6 +++--- 52 files changed, 170 insertions(+), 166 deletions(-) diff --git a/langchain/cache.py b/langchain/cache.py index 5b2cf2c0e41..4b579220081 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -14,7 +14,7 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import Generation from langchain.vectorstores.redis import Redis as RedisVectorstore @@ -178,7 +178,7 @@ class RedisSemanticCache(BaseCache): # TODO - implement a TTL policy in Redis def __init__( - self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2 + self, redis_url: str, embedding: EmbeddingModel, score_threshold: float = 0.2 ): """Initialize by passing in the `init` GPTCache func diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index 7764c85474f..6936cd24d65 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -14,16 +14,16 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel -class 
HypotheticalDocumentEmbedder(Chain, Embeddings): +class HypotheticalDocumentEmbedder(Chain, EmbeddingModel): """Generate hypothetical document for query, and then embed that. Based on https://arxiv.org/abs/2212.10496 """ - base_embeddings: Embeddings + base_embeddings: EmbeddingModel llm_chain: LLMChain class Config: @@ -71,7 +71,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): def from_llm( cls, llm: BaseLanguageModel, - base_embeddings: Embeddings, + base_embeddings: EmbeddingModel, prompt_key: str, **kwargs: Any, ) -> HypotheticalDocumentEmbedder: diff --git a/langchain/chains/router/embedding_router.py b/langchain/chains/router/embedding_router.py index 57ad90d33d5..16b05b6a396 100644 --- a/langchain/chains/router/embedding_router.py +++ b/langchain/chains/router/embedding_router.py @@ -7,7 +7,7 @@ from pydantic import Extra from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.router.base import RouterChain from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore @@ -45,7 +45,7 @@ class EmbeddingRouterChain(RouterChain): cls, names_and_descriptions: Sequence[Tuple[str, Sequence[str]]], vectorstore_cls: Type[VectorStore], - embeddings: Embeddings, + embeddings: EmbeddingModel, **kwargs: Any, ) -> EmbeddingRouterChain: """Convenience constructor.""" diff --git a/langchain/document_transformers.py b/langchain/document_transformers.py index 7f17cb68985..eeed4e419aa 100644 --- a/langchain/document_transformers.py +++ b/langchain/document_transformers.py @@ -4,7 +4,7 @@ from typing import Any, Callable, List, Sequence import numpy as np from pydantic import BaseModel, Field -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.math_utils import cosine_similarity from langchain.schema import 
BaseDocumentTransformer, Document @@ -50,7 +50,7 @@ def _filter_similar_embeddings( def _get_embeddings_from_stateful_docs( - embeddings: Embeddings, documents: Sequence[_DocumentWithState] + embeddings: EmbeddingModel, documents: Sequence[_DocumentWithState] ) -> List[List[float]]: if len(documents) and "embedded_doc" in documents[0].state: embedded_documents = [doc.state["embedded_doc"] for doc in documents] @@ -66,7 +66,7 @@ def _get_embeddings_from_stateful_docs( class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): """Filter that drops redundant documents by comparing their embeddings.""" - embeddings: Embeddings + embeddings: EmbeddingModel """Embeddings to use for embedding document contents.""" similarity_fn: Callable = cosine_similarity """Similarity function for comparing documents. Function expected to take as input diff --git a/langchain/embeddings/aleph_alpha.py b/langchain/embeddings/aleph_alpha.py index f6ca5008ed4..543bdcd4798 100644 --- a/langchain/embeddings/aleph_alpha.py +++ b/langchain/embeddings/aleph_alpha.py @@ -2,11 +2,11 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env -class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings): +class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, EmbeddingModel): """ Wrapper for Aleph Alpha's Asymmetric Embeddings AA provides you with an endpoint to embed a document and a query. 
diff --git a/langchain/embeddings/base.py b/langchain/embeddings/base.py index 4a56cd6acb8..a1617c47b41 100644 --- a/langchain/embeddings/base.py +++ b/langchain/embeddings/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import List -class Embeddings(ABC): +class EmbeddingModel(ABC): """Interface for embedding models.""" @abstractmethod @@ -13,3 +13,7 @@ class Embeddings(ABC): @abstractmethod def embed_query(self, text: str) -> List[float]: """Embed query text.""" + + +# For backwards compatibility. +Embeddings = EmbeddingModel diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index a107dc02cc0..8cd020b2015 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env -class CohereEmbeddings(BaseModel, Embeddings): +class CohereEmbeddings(BaseModel, EmbeddingModel): """Wrapper around Cohere embedding models. 
To use, you should have the ``cohere`` python package installed, and the diff --git a/langchain/embeddings/fake.py b/langchain/embeddings/fake.py index 9328f927e26..3a284e803b2 100644 --- a/langchain/embeddings/fake.py +++ b/langchain/embeddings/fake.py @@ -3,10 +3,10 @@ from typing import List import numpy as np from pydantic import BaseModel -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel -class FakeEmbeddings(Embeddings, BaseModel): +class FakeEmbeddings(EmbeddingModel, BaseModel): size: int def _get_embedding(self) -> List[float]: diff --git a/langchain/embeddings/google_palm.py b/langchain/embeddings/google_palm.py index 5be7e736f3c..ff8002e1a2a 100644 --- a/langchain/embeddings/google_palm.py +++ b/langchain/embeddings/google_palm.py @@ -13,7 +13,7 @@ from tenacity import ( wait_exponential, ) -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def embed_with_retry( return _embed_with_retry(*args, **kwargs) -class GooglePalmEmbeddings(BaseModel, Embeddings): +class GooglePalmEmbeddings(BaseModel, EmbeddingModel): client: Any google_api_key: Optional[str] model_name: str = "models/embedding-gecko-001" diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 04e0c76e6fa..77141d38989 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, Field -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" @@ -13,7 +13,7 @@ DEFAULT_QUERY_INSTRUCTION = ( ) -class HuggingFaceEmbeddings(BaseModel, Embeddings): +class 
HuggingFaceEmbeddings(BaseModel, EmbeddingModel): """Wrapper around sentence_transformers embedding models. To use, you should have the ``sentence_transformers`` python package installed. @@ -87,7 +87,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): return embedding.tolist() -class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): +class HuggingFaceInstructEmbeddings(BaseModel, EmbeddingModel): """Wrapper around sentence_transformers embedding models. To use, you should have the ``sentence_transformers`` diff --git a/langchain/embeddings/huggingface_hub.py b/langchain/embeddings/huggingface_hub.py index 6273ac26051..944ea6360db 100644 --- a/langchain/embeddings/huggingface_hub.py +++ b/langchain/embeddings/huggingface_hub.py @@ -3,14 +3,14 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2" VALID_TASKS = ("feature-extraction",) -class HuggingFaceHubEmbeddings(BaseModel, Embeddings): +class HuggingFaceHubEmbeddings(BaseModel, EmbeddingModel): """Wrapper around HuggingFaceHub embedding models. 
To use, you should have the ``huggingface_hub`` python package installed, and the diff --git a/langchain/embeddings/jina.py b/langchain/embeddings/jina.py index d980ee3a8ce..b373a1b7fd1 100644 --- a/langchain/embeddings/jina.py +++ b/langchain/embeddings/jina.py @@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional import requests from pydantic import BaseModel, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env -class JinaEmbeddings(BaseModel, Embeddings): +class JinaEmbeddings(BaseModel, EmbeddingModel): client: Any #: :meta private: model_name: str = "ViT-B-32::openai" diff --git a/langchain/embeddings/llamacpp.py b/langchain/embeddings/llamacpp.py index 0c11731e98d..6aae08df4e8 100644 --- a/langchain/embeddings/llamacpp.py +++ b/langchain/embeddings/llamacpp.py @@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, Field, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel -class LlamaCppEmbeddings(BaseModel, Embeddings): +class LlamaCppEmbeddings(BaseModel, EmbeddingModel): """Wrapper around llama.cpp embedding models. 
To use, you should have the llama-cpp-python library installed, and provide the diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index f0e2215ec84..f7d14283955 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -25,7 +25,7 @@ from tenacity import ( wait_exponential, ) -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -64,7 +64,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: return _embed_with_retry(**kwargs) -class OpenAIEmbeddings(BaseModel, Embeddings): +class OpenAIEmbeddings(BaseModel, EmbeddingModel): """Wrapper around OpenAI embedding models. To use, you should have the ``openai`` python package installed, and the diff --git a/langchain/embeddings/sagemaker_endpoint.py b/langchain/embeddings/sagemaker_endpoint.py index 25ba961df58..3139eedd622 100644 --- a/langchain/embeddings/sagemaker_endpoint.py +++ b/langchain/embeddings/sagemaker_endpoint.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.llms.sagemaker_endpoint import ContentHandlerBase @@ -11,7 +11,7 @@ class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]) """Content handler for LLM class.""" -class SagemakerEndpointEmbeddings(BaseModel, Embeddings): +class SagemakerEndpointEmbeddings(BaseModel, EmbeddingModel): """Wrapper around custom Sagemaker Inference Endpoints. 
To use, you must supply the endpoint name from your deployed diff --git a/langchain/embeddings/self_hosted.py b/langchain/embeddings/self_hosted.py index c010d5d500a..dd9aff5abb8 100644 --- a/langchain/embeddings/self_hosted.py +++ b/langchain/embeddings/self_hosted.py @@ -3,7 +3,7 @@ from typing import Any, Callable, List from pydantic import Extra -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.llms import SelfHostedPipeline @@ -16,7 +16,7 @@ def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[floa return pipeline(*args, **kwargs) -class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings): +class SelfHostedEmbeddings(SelfHostedPipeline, EmbeddingModel): """Runs custom embedding models on self-hosted remote hardware. Supported hardware includes auto-launched instances on AWS, GCP, Azure, diff --git a/langchain/embeddings/tensorflow_hub.py b/langchain/embeddings/tensorflow_hub.py index 1e699ecbd31..92bb4358d5f 100644 --- a/langchain/embeddings/tensorflow_hub.py +++ b/langchain/embeddings/tensorflow_hub.py @@ -3,12 +3,12 @@ from typing import Any, List from pydantic import BaseModel, Extra -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" -class TensorflowHubEmbeddings(BaseModel, Embeddings): +class TensorflowHubEmbeddings(BaseModel, EmbeddingModel): """Wrapper around tensorflow_hub embedding models. To use, you should have the ``tensorflow_text`` python package installed. 
diff --git a/langchain/indexes/vectorstore.py b/langchain/indexes/vectorstore.py index f07d01a676d..709d3371bc1 100644 --- a/langchain/indexes/vectorstore.py +++ b/langchain/indexes/vectorstore.py @@ -6,7 +6,7 @@ from langchain.base_language import BaseLanguageModel from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA from langchain.document_loaders.base import BaseLoader -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms.openai import OpenAI from langchain.schema import Document @@ -55,7 +55,7 @@ class VectorstoreIndexCreator(BaseModel): """Logic for creating indexes.""" vectorstore_cls: Type[VectorStore] = Chroma - embedding: Embeddings = Field(default_factory=OpenAIEmbeddings) + embedding: EmbeddingModel = Field(default_factory=OpenAIEmbeddings) text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) vectorstore_kwargs: dict = Field(default_factory=dict) diff --git a/langchain/prompts/example_selector/semantic_similarity.py b/langchain/prompts/example_selector/semantic_similarity.py index 0d66c13673f..b4f8c06a562 100644 --- a/langchain/prompts/example_selector/semantic_similarity.py +++ b/langchain/prompts/example_selector/semantic_similarity.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Type from pydantic import BaseModel, Extra -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.vectorstores.base import VectorStore @@ -64,7 +64,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): def from_examples( cls, examples: List[dict], - embeddings: Embeddings, + embeddings: EmbeddingModel, vectorstore_cls: Type[VectorStore], k: int = 4, input_keys: 
Optional[List[str]] = None, @@ -130,7 +130,7 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector): def from_examples( cls, examples: List[dict], - embeddings: Embeddings, + embeddings: EmbeddingModel, vectorstore_cls: Type[VectorStore], k: int = 4, input_keys: Optional[List[str]] = None, diff --git a/langchain/retrievers/document_compressors/embeddings_filter.py b/langchain/retrievers/document_compressors/embeddings_filter.py index 543380189d8..ddc4e391d9a 100644 --- a/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/langchain/retrievers/document_compressors/embeddings_filter.py @@ -8,7 +8,7 @@ from langchain.document_transformers import ( _get_embeddings_from_stateful_docs, get_stateful_documents, ) -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.math_utils import cosine_similarity from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -17,7 +17,7 @@ from langchain.schema import Document class EmbeddingsFilter(BaseDocumentCompressor): - embeddings: Embeddings + embeddings: EmbeddingModel """Embeddings to use for embedding document contents and queries.""" similarity_fn: Callable = cosine_similarity """Similarity function for comparing documents. 
Function expected to take as input diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py index d6204723c63..4e57549d5c7 100644 --- a/langchain/retrievers/knn.py +++ b/langchain/retrievers/knn.py @@ -10,17 +10,17 @@ from typing import Any, List, Optional import numpy as np from pydantic import BaseModel -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever, Document -def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: +def create_index(contexts: List[str], embeddings: EmbeddingModel) -> np.ndarray: with concurrent.futures.ThreadPoolExecutor() as executor: return np.array(list(executor.map(embeddings.embed_query, contexts))) class KNNRetriever(BaseRetriever, BaseModel): - embeddings: Embeddings + embeddings: EmbeddingModel index: Any texts: List[str] k: int = 4 @@ -34,7 +34,7 @@ class KNNRetriever(BaseRetriever, BaseModel): @classmethod def from_texts( - cls, texts: List[str], embeddings: Embeddings, **kwargs: Any + cls, texts: List[str], embeddings: EmbeddingModel, **kwargs: Any ) -> KNNRetriever: index = create_index(texts, embeddings) return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) diff --git a/langchain/retrievers/milvus.py b/langchain/retrievers/milvus.py index 915d61d9897..b6eaee10ddd 100644 --- a/langchain/retrievers/milvus.py +++ b/langchain/retrievers/milvus.py @@ -1,7 +1,7 @@ """Milvus Retriever""" from typing import Any, Dict, List, Optional -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever, Document from langchain.vectorstores.milvus import Milvus @@ -11,7 +11,7 @@ from langchain.vectorstores.milvus import Milvus class MilvusRetreiver(BaseRetriever): def __init__( self, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, collection_name: str = "LangChainCollection", connection_args: 
Optional[Dict[str, Any]] = None, consistency_level: str = "Session", diff --git a/langchain/retrievers/pinecone_hybrid_search.py b/langchain/retrievers/pinecone_hybrid_search.py index bd04a296a7a..0c60f8e792d 100644 --- a/langchain/retrievers/pinecone_hybrid_search.py +++ b/langchain/retrievers/pinecone_hybrid_search.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever, Document @@ -15,7 +15,7 @@ def hash_text(text: str) -> str: def create_index( contexts: List[str], index: Any, - embeddings: Embeddings, + embeddings: EmbeddingModel, sparse_encoder: Any, ids: Optional[List[str]] = None, metadatas: Optional[List[dict]] = None, @@ -74,7 +74,7 @@ def create_index( class PineconeHybridSearchRetriever(BaseRetriever, BaseModel): - embeddings: Embeddings + embeddings: EmbeddingModel sparse_encoder: Any index: Any top_k: int = 4 diff --git a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py index 694ac97893a..67bc4b1cb04 100644 --- a/langchain/retrievers/svm.py +++ b/langchain/retrievers/svm.py @@ -10,17 +10,17 @@ from typing import Any, List, Optional import numpy as np from pydantic import BaseModel -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever, Document -def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: +def create_index(contexts: List[str], embeddings: EmbeddingModel) -> np.ndarray: with concurrent.futures.ThreadPoolExecutor() as executor: return np.array(list(executor.map(embeddings.embed_query, contexts))) class SVMRetriever(BaseRetriever, BaseModel): - embeddings: Embeddings + embeddings: EmbeddingModel index: Any texts: List[str] k: int = 4 @@ -34,7 +34,7 @@ class SVMRetriever(BaseRetriever, BaseModel): @classmethod def 
from_texts( - cls, texts: List[str], embeddings: Embeddings, **kwargs: Any + cls, texts: List[str], embeddings: EmbeddingModel, **kwargs: Any ) -> SVMRetriever: index = create_index(texts, embeddings) return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) diff --git a/langchain/retrievers/zilliz.py b/langchain/retrievers/zilliz.py index 6b39a3a022e..6b72d6d874f 100644 --- a/langchain/retrievers/zilliz.py +++ b/langchain/retrievers/zilliz.py @@ -1,7 +1,7 @@ """Zilliz Retriever""" from typing import Any, Dict, List, Optional -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever, Document from langchain.vectorstores.zilliz import Zilliz @@ -11,7 +11,7 @@ from langchain.vectorstores.zilliz import Zilliz class ZillizRetreiver(BaseRetriever): def __init__( self, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, collection_name: str = "LangChainCollection", connection_args: Optional[Dict[str, Any]] = None, consistency_level: str = "Session", diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py index bb32549b450..c418f21633b 100644 --- a/langchain/vectorstores/analyticdb.py +++ b/langchain/vectorstores/analyticdb.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, relationship from sqlalchemy.sql.expression import func from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -126,7 +126,7 @@ class AnalyticDB(VectorStore): def __init__( self, connection_string: str, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, pre_delete_collection: bool = False, @@ -343,7 +343,7 @@ class 
AnalyticDB(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, ids: Optional[List[str]] = None, @@ -390,7 +390,7 @@ class AnalyticDB(VectorStore): def from_documents( cls, documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, diff --git a/langchain/vectorstores/annoy.py b/langchain/vectorstores/annoy.py index 538f75c229e..81e26b59b9b 100644 --- a/langchain/vectorstores/annoy.py +++ b/langchain/vectorstores/annoy.py @@ -13,7 +13,7 @@ import numpy as np from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -282,7 +282,7 @@ class Annoy(VectorStore): cls, texts: List[str], embeddings: List[List[float]], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, metric: str = DEFAULT_METRIC, trees: int = 100, @@ -319,7 +319,7 @@ class Annoy(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, metric: str = DEFAULT_METRIC, trees: int = 100, @@ -360,7 +360,7 @@ class Annoy(VectorStore): def from_embeddings( cls, text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, metric: str = DEFAULT_METRIC, trees: int = 100, @@ -424,7 +424,7 @@ class Annoy(VectorStore): def load_local( cls, folder_path: str, - embeddings: Embeddings, + embeddings: 
EmbeddingModel, ) -> Annoy: """Load Annoy index, docstore, and index_to_docstore_id to disk. diff --git a/langchain/vectorstores/atlas.py b/langchain/vectorstores/atlas.py index 6166a101373..d8fe8584928 100644 --- a/langchain/vectorstores/atlas.py +++ b/langchain/vectorstores/atlas.py @@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Optional, Type import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ class AtlasDB(VectorStore): def __init__( self, name: str, - embedding_function: Optional[Embeddings] = None, + embedding_function: Optional[EmbeddingModel] = None, api_key: Optional[str] = None, description: str = "A description for your project", is_public: bool = True, @@ -212,7 +212,7 @@ class AtlasDB(VectorStore): def from_texts( cls: Type[AtlasDB], texts: List[str], - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, name: Optional[str] = None, @@ -229,7 +229,7 @@ class AtlasDB(VectorStore): texts (List[str]): The list of texts to ingest. name (str): Name of the project to create. api_key (str): Your nomic API key, - embedding (Optional[Embeddings]): Embedding function. Defaults to None. + embedding (Optional[EmbeddingModel]): Embedding function. Defaults to None. metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. ids (Optional[List[str]]): Optional list of document IDs. 
If None, ids will be auto created @@ -272,7 +272,7 @@ class AtlasDB(VectorStore): def from_documents( cls: Type[AtlasDB], documents: List[Document], - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, ids: Optional[List[str]] = None, name: Optional[str] = None, api_key: Optional[str] = None, @@ -289,7 +289,7 @@ class AtlasDB(VectorStore): name (str): Name of the collection to create. api_key (str): Your nomic API key, documents (List[Document]): List of documents to add to the vectorstore. - embedding (Optional[Embeddings]): Embedding function. Defaults to None. + embedding (Optional[EmbeddingModel]): Embedding function. Defaults to None. ids (Optional[List[str]]): Optional list of document IDs. If None, ids will be auto created description (str): A description for your project. diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 6b689cd14e3..605ee471097 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar from pydantic import BaseModel, Field, root_validator from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import BaseRetriever VST = TypeVar("VST", bound="VectorStore") @@ -298,7 +298,7 @@ class VectorStore(ABC): def from_documents( cls: Type[VST], documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, **kwargs: Any, ) -> VST: """Return VectorStore initialized from documents and embeddings.""" @@ -310,7 +310,7 @@ class VectorStore(ABC): async def afrom_documents( cls: Type[VST], documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, **kwargs: Any, ) -> VST: """Return VectorStore initialized from documents and embeddings.""" @@ -323,7 +323,7 @@ class VectorStore(ABC): def from_texts( cls: Type[VST], texts: 
List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> VST: @@ -333,7 +333,7 @@ class VectorStore(ABC): async def afrom_texts( cls: Type[VST], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> VST: diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index f83c8c22c55..0dcca1b71c9 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import xor_args from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -58,7 +58,7 @@ class Chroma(VectorStore): def __init__( self, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - embedding_function: Optional[Embeddings] = None, + embedding_function: Optional[EmbeddingModel] = None, persist_directory: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, collection_metadata: Optional[Dict] = None, @@ -354,7 +354,7 @@ class Chroma(VectorStore): def from_texts( cls: Type[Chroma], texts: List[str], - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, @@ -372,7 +372,7 @@ class Chroma(VectorStore): texts (List[str]): List of texts to add to the collection. collection_name (str): Name of the collection to create. persist_directory (Optional[str]): Directory to persist the collection. - embedding (Optional[Embeddings]): Embedding function. Defaults to None. 
+ embedding (Optional[EmbeddingModel]): Embedding function. Defaults to None. metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. ids (Optional[List[str]]): List of document IDs. Defaults to None. client_settings (Optional[chromadb.config.Settings]): Chroma client settings @@ -394,7 +394,7 @@ class Chroma(VectorStore): def from_documents( cls: Type[Chroma], documents: List[Document], - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, @@ -412,7 +412,7 @@ class Chroma(VectorStore): persist_directory (Optional[str]): Directory to persist the collection. ids (Optional[List[str]]): List of document IDs. Defaults to None. documents (List[Document]): List of documents to add to the vectorstore. - embedding (Optional[Embeddings]): Embedding function. Defaults to None. + embedding (Optional[EmbeddingModel]): Embedding function. Defaults to None. client_settings (Optional[chromadb.config.Settings]): Chroma client settings Returns: Chroma: Chroma vectorstore. 
diff --git a/langchain/vectorstores/deeplake.py b/langchain/vectorstores/deeplake.py index 01ed62df132..e3d8ce8208d 100644 --- a/langchain/vectorstores/deeplake.py +++ b/langchain/vectorstores/deeplake.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tupl import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -96,7 +96,7 @@ class DeepLake(VectorStore): self, dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH, token: Optional[str] = None, - embedding_function: Optional[Embeddings] = None, + embedding_function: Optional[EmbeddingModel] = None, read_only: Optional[bool] = False, ingestion_batch_size: int = 1024, num_workers: int = 0, @@ -494,7 +494,7 @@ class DeepLake(VectorStore): def from_texts( cls, texts: List[str], - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH, @@ -522,7 +522,7 @@ class DeepLake(VectorStore): save the dataset, but keeps it in memory instead. Should be used only for testing as it does not persist. documents (List[Document]): List of documents to add. - embedding (Optional[Embeddings]): Embedding function. Defaults to None. + embedding (Optional[EmbeddingModel]): Embedding function. Defaults to None. metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. ids (Optional[List[str]]): List of document IDs. Defaults to None. 
diff --git a/langchain/vectorstores/docarray/base.py b/langchain/vectorstores/docarray/base.py index d7b2f3c9ac2..6b24e004f8f 100644 --- a/langchain/vectorstores/docarray/base.py +++ b/langchain/vectorstores/docarray/base.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type import numpy as np from pydantic import Field -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.schema import Document from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -36,7 +36,7 @@ class DocArrayIndex(VectorStore, ABC): def __init__( self, doc_index: "BaseDocIndex", - embedding: Embeddings, + embedding: EmbeddingModel, ): """Initialize a vector store from DocArray's DocIndex.""" self.doc_index = doc_index diff --git a/langchain/vectorstores/docarray/hnsw.py b/langchain/vectorstores/docarray/hnsw.py index 9e334c3c47b..9d7658431ec 100644 --- a/langchain/vectorstores/docarray/hnsw.py +++ b/langchain/vectorstores/docarray/hnsw.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, List, Literal, Optional -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.docarray.base import ( DocArrayIndex, _check_docarray_import, @@ -20,7 +20,7 @@ class DocArrayHnswSearch(DocArrayIndex): @classmethod def from_params( cls, - embedding: Embeddings, + embedding: EmbeddingModel, work_dir: str, n_dim: int, dist_metric: Literal["cosine", "ip", "l2"] = "cosine", @@ -36,7 +36,7 @@ class DocArrayHnswSearch(DocArrayIndex): """Initialize DocArrayHnswSearch store. Args: - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. work_dir (str): path to the location where all the data will be stored. n_dim (int): dimension of an embedding. 
dist_metric (str): Distance metric for DocArrayHnswSearch can be one of: @@ -78,7 +78,7 @@ class DocArrayHnswSearch(DocArrayIndex): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, work_dir: Optional[str] = None, n_dim: Optional[int] = None, @@ -89,7 +89,7 @@ class DocArrayHnswSearch(DocArrayIndex): Args: texts (List[str]): Text data. - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. metadatas (Optional[List[dict]]): Metadata for each text if it exists. Defaults to None. work_dir (str): path to the location where all the data will be stored. diff --git a/langchain/vectorstores/docarray/in_memory.py b/langchain/vectorstores/docarray/in_memory.py index 8ab664859eb..c53f5ea1c74 100644 --- a/langchain/vectorstores/docarray/in_memory.py +++ b/langchain/vectorstores/docarray/in_memory.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Literal, Optional -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.docarray.base import ( DocArrayIndex, _check_docarray_import, @@ -20,7 +20,7 @@ class DocArrayInMemorySearch(DocArrayIndex): @classmethod def from_params( cls, - embedding: Embeddings, + embedding: EmbeddingModel, metric: Literal[ "cosine_sim", "euclidian_dist", "sgeuclidean_dist" ] = "cosine_sim", @@ -29,7 +29,7 @@ class DocArrayInMemorySearch(DocArrayIndex): """Initialize DocArrayInMemorySearch store. Args: - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. metric (str): metric for exact nearest-neighbor search. Can be one of: "cosine_sim", "euclidean_dist" and "sqeuclidean_dist". Defaults to "cosine_sim". 
@@ -46,7 +46,7 @@ class DocArrayInMemorySearch(DocArrayIndex): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[Dict[Any, Any]]] = None, **kwargs: Any, ) -> DocArrayInMemorySearch: @@ -54,7 +54,7 @@ class DocArrayInMemorySearch(DocArrayIndex): Args: texts (List[str]): Text data. - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. metadatas (Optional[List[Dict[Any, Any]]]): Metadata for each text if it exists. Defaults to None. metric (str): metric for exact nearest-neighbor search. diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 6663e79d2b4..9e7d6acff42 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -6,7 +6,7 @@ from abc import ABC from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_env from langchain.vectorstores.base import VectorStore @@ -106,7 +106,7 @@ class ElasticVectorSearch(VectorStore, ABC): Args: elasticsearch_url (str): The URL for the Elasticsearch instance. index_name (str): The name of the Elasticsearch index for the embeddings. - embedding (Embeddings): An object that provides the ability to embed text. + embedding (EmbeddingModel): An object that provides the ability to embed text. 
It should be an instance of a class that subclasses the Embeddings abstract base class, such as OpenAIEmbeddings() @@ -118,7 +118,7 @@ class ElasticVectorSearch(VectorStore, ABC): self, elasticsearch_url: str, index_name: str, - embedding: Embeddings, + embedding: EmbeddingModel, *, ssl_verify: Optional[Dict[str, Any]] = None, ): @@ -244,7 +244,7 @@ class ElasticVectorSearch(VectorStore, ABC): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, elasticsearch_url: Optional[str] = None, index_name: Optional[str] = None, diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 143979168c8..221b3973948 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -13,7 +13,7 @@ import numpy as np from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -364,7 +364,7 @@ class FAISS(VectorStore): cls, texts: List[str], embeddings: List[List[float]], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, normalize_L2: bool = False, **kwargs: Any, @@ -396,7 +396,7 @@ class FAISS(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> FAISS: @@ -430,7 +430,7 @@ class FAISS(VectorStore): def from_embeddings( cls, text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> FAISS: @@ -486,7 +486,7 @@ class FAISS(VectorStore): @classmethod def load_local( - cls, 
folder_path: str, embeddings: Embeddings, index_name: str = "index" + cls, folder_path: str, embeddings: EmbeddingModel, index_name: str = "index" ) -> FAISS: """Load FAISS index, docstore, and index_to_docstore_id to disk. diff --git a/langchain/vectorstores/lancedb.py b/langchain/vectorstores/lancedb.py index eec6d4e05ee..0f66a3f9a37 100644 --- a/langchain/vectorstores/lancedb.py +++ b/langchain/vectorstores/lancedb.py @@ -5,7 +5,7 @@ import uuid from typing import Any, Iterable, List, Optional from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore @@ -27,7 +27,7 @@ class LanceDB(VectorStore): def __init__( self, connection: Any, - embedding: Embeddings, + embedding: EmbeddingModel, vector_key: Optional[str] = "vector", id_key: Optional[str] = "id", text_key: Optional[str] = "text", @@ -113,7 +113,7 @@ class LanceDB(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, connection: Any = None, vector_key: Optional[str] = "vector", diff --git a/langchain/vectorstores/milvus.py b/langchain/vectorstores/milvus.py index 1f051605e3d..1de5916bdd3 100644 --- a/langchain/vectorstores/milvus.py +++ b/langchain/vectorstores/milvus.py @@ -8,7 +8,7 @@ from uuid import uuid4 import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -28,7 +28,7 @@ class Milvus(VectorStore): def __init__( self, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, collection_name: str = "LangChainCollection", connection_args: Optional[dict[str, Any]] = None, consistency_level: str = "Session", @@ -77,7 
+77,7 @@ class Milvus(VectorStore): server_name (str): If use tls, need to write the common name. Args: - embedding_function (Embeddings): Function used to embed the text. + embedding_function (EmbeddingModel): Function used to embed the text. collection_name (str): Which Milvus collection to use. Defaults to "LangChainCollection". connection_args (Optional[dict[str, any]]): The arguments for connection to @@ -754,7 +754,7 @@ class Milvus(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, @@ -768,7 +768,7 @@ class Milvus(VectorStore): Args: texts (List[str]): Text data. - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. metadatas (Optional[List[dict]]): Metadata for each text if it exists. Defaults to None. collection_name (str, optional): Collection name to use. 
Defaults to diff --git a/langchain/vectorstores/myscale.py b/langchain/vectorstores/myscale.py index 3ae8d275dbf..ea3dd19fd8b 100644 --- a/langchain/vectorstores/myscale.py +++ b/langchain/vectorstores/myscale.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from pydantic import BaseSettings from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore logger = logging.getLogger() @@ -98,7 +98,7 @@ class MyScale(VectorStore): def __init__( self, - embedding: Embeddings, + embedding: EmbeddingModel, config: Optional[MyScaleSettings] = None, **kwargs: Any, ) -> None: @@ -259,7 +259,7 @@ class MyScale(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[Dict[Any, Any]]] = None, config: Optional[MyScaleSettings] = None, text_ids: Optional[Iterable[str]] = None, @@ -269,7 +269,7 @@ class MyScale(VectorStore): """Create Myscale wrapper with existing texts Args: - embedding_function (Embeddings): Function to extract text embedding + embedding_function (EmbeddingModel): Function to extract text embedding texts (Iterable[str]): List or tuple of strings to be added config (MyScaleSettings, Optional): Myscale configuration text_ids (Optional[Iterable], optional): IDs for the texts. 
diff --git a/langchain/vectorstores/opensearch_vector_search.py b/langchain/vectorstores/opensearch_vector_search.py index 624d62c5715..4da12eaefab 100644 --- a/langchain/vectorstores/opensearch_vector_search.py +++ b/langchain/vectorstores/opensearch_vector_search.py @@ -5,7 +5,7 @@ import uuid from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -295,7 +295,7 @@ class OpenSearchVectorSearch(VectorStore): self, opensearch_url: str, index_name: str, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, **kwargs: Any, ): """Initialize with necessary components.""" @@ -494,7 +494,7 @@ class OpenSearchVectorSearch(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, bulk_size: int = 500, **kwargs: Any, diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index c042cc2f652..39f8c1448c2 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -12,7 +12,7 @@ from sqlalchemy.dialects.postgresql import JSON, UUID from sqlalchemy.orm import Session, declarative_base, relationship from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -122,7 +122,7 @@ class PGVector(VectorStore): def __init__( self, connection_string: str, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = 
DEFAULT_DISTANCE_STRATEGY, @@ -363,7 +363,7 @@ class PGVector(VectorStore): def from_texts( cls: Type[PGVector], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, @@ -412,7 +412,7 @@ class PGVector(VectorStore): def from_documents( cls: Type[PGVector], documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, diff --git a/langchain/vectorstores/pinecone.py b/langchain/vectorstores/pinecone.py index f9a6fe9b426..953ff4610fc 100644 --- a/langchain/vectorstores/pinecone.py +++ b/langchain/vectorstores/pinecone.py @@ -6,7 +6,7 @@ import uuid from typing import Any, Callable, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore logger = logging.getLogger(__name__) @@ -161,7 +161,7 @@ class Pinecone(VectorStore): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, batch_size: int = 32, @@ -247,7 +247,7 @@ class Pinecone(VectorStore): def from_existing_index( cls, index_name: str, - embedding: Embeddings, + embedding: EmbeddingModel, text_key: str = "text", namespace: Optional[str] = None, ) -> Pinecone: diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 5af38b9694d..b01ac5ee6e1 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -21,7 +21,7 @@ from typing import ( import numpy as np from langchain.docstore.document import Document -from 
langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -55,7 +55,7 @@ class Qdrant(VectorStore): self, client: Any, collection_name: str, - embeddings: Optional[Embeddings] = None, + embeddings: Optional[EmbeddingModel] = None, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, embedding_function: Optional[Callable] = None, # deprecated @@ -99,7 +99,7 @@ class Qdrant(VectorStore): "Pass `Embeddings` instance to `embeddings` instead." ) - if not isinstance(embeddings, Embeddings): + if not isinstance(embeddings, EmbeddingModel): warnings.warn( "`embeddings` should be an instance of `Embeddings`." "Using `embeddings` as `embedding_function` which is deprecated" @@ -292,7 +292,7 @@ class Qdrant(VectorStore): def from_texts( cls: Type[Qdrant], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, location: Optional[str] = None, url: Optional[str] = None, diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 9c4ec0c4154..bc71f1684ab 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -22,7 +22,7 @@ import numpy as np from pydantic import BaseModel, root_validator from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore, VectorStoreRetriever @@ -361,7 +361,7 @@ class Redis(VectorStore): def from_texts_return_keys( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, index_name: Optional[str] = None, content_key: str = "content", @@ -421,7 +421,7 @@ class Redis(VectorStore): def from_texts( cls: 
Type[Redis], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, index_name: Optional[str] = None, content_key: str = "content", @@ -502,7 +502,7 @@ class Redis(VectorStore): @classmethod def from_existing_index( cls, - embedding: Embeddings, + embedding: EmbeddingModel, index_name: str, content_key: str = "content", metadata_key: str = "metadata", diff --git a/langchain/vectorstores/supabase.py b/langchain/vectorstores/supabase.py index d6d5b0275b0..f5653d7e935 100644 --- a/langchain/vectorstores/supabase.py +++ b/langchain/vectorstores/supabase.py @@ -15,7 +15,7 @@ from typing import ( import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -41,14 +41,14 @@ class SupabaseVectorStore(VectorStore): # This is the embedding function. Don't confuse with the embedding vectors. 
# We should perhaps rename the underlying Embedding base class to EmbeddingFunction # or something - _embedding: Embeddings + _embedding: EmbeddingModel table_name: str query_name: str def __init__( self, client: supabase.client.Client, - embedding: Embeddings, + embedding: EmbeddingModel, table_name: str, query_name: Union[str, None] = None, ) -> None: @@ -62,7 +62,7 @@ class SupabaseVectorStore(VectorStore): ) self._client = client - self._embedding: Embeddings = embedding + self._embedding: EmbeddingModel = embedding self.table_name = table_name or "documents" self.query_name = query_name or "match_documents" @@ -81,7 +81,7 @@ class SupabaseVectorStore(VectorStore): def from_texts( cls: Type["SupabaseVectorStore"], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, client: Optional[supabase.client.Client] = None, table_name: Optional[str] = "documents", diff --git a/langchain/vectorstores/tair.py b/langchain/vectorstores/tair.py index 75a98aadf34..5b340056ba2 100644 --- a/langchain/vectorstores/tair.py +++ b/langchain/vectorstores/tair.py @@ -7,7 +7,7 @@ import uuid from typing import Any, Iterable, List, Optional, Type from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -21,7 +21,7 @@ def _uuid_key() -> str: class Tair(VectorStore): def __init__( self, - embedding_function: Embeddings, + embedding_function: EmbeddingModel, url: str, index_name: str, content_key: str = "content", @@ -140,7 +140,7 @@ class Tair(VectorStore): def from_texts( cls: Type[Tair], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, index_name: str = "langchain", content_key: str = "content", @@ -208,7 +208,7 @@ class Tair(VectorStore): def from_documents( cls, 
documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, index_name: str = "langchain", content_key: str = "content", @@ -262,7 +262,7 @@ class Tair(VectorStore): @classmethod def from_existing_index( cls, - embedding: Embeddings, + embedding: EmbeddingModel, index_name: str = "langchain", content_key: str = "content", metadata_key: str = "metadata", diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index a7398ae7744..53c291bd450 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -8,7 +8,7 @@ from uuid import uuid4 import numpy as np from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -89,7 +89,7 @@ class Weaviate(VectorStore): client: Any, index_name: str, text_key: str, - embedding: Optional[Embeddings] = None, + embedding: Optional[EmbeddingModel] = None, attributes: Optional[List[str]] = None, relevance_score_fn: Optional[ Callable[[float], float] @@ -360,7 +360,7 @@ class Weaviate(VectorStore): def from_texts( cls: Type[Weaviate], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> Weaviate: diff --git a/langchain/vectorstores/zilliz.py b/langchain/vectorstores/zilliz.py index 13d165d6f7b..72819db2e46 100644 --- a/langchain/vectorstores/zilliz.py +++ b/langchain/vectorstores/zilliz.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from typing import Any, List, Optional -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores.milvus import Milvus logger = logging.getLogger(__name__) @@ 
-59,7 +59,7 @@ class Zilliz(Milvus): def from_texts( cls, texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", connection_args: dict[str, Any] = {}, @@ -73,7 +73,7 @@ class Zilliz(Milvus): Args: texts (List[str]): Text data. - embedding (Embeddings): Embedding function. + embedding (EmbeddingModel): Embedding function. metadatas (Optional[List[dict]]): Metadata for each text if it exists. Defaults to None. collection_name (str, optional): Collection name to use. Defaults to diff --git a/tests/integration_tests/vectorstores/fake_embeddings.py b/tests/integration_tests/vectorstores/fake_embeddings.py index 17a81e0493c..2a147cffd6b 100644 --- a/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/tests/integration_tests/vectorstores/fake_embeddings.py @@ -1,12 +1,12 @@ """Fake Embedding class for testing purposes.""" from typing import List -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel fake_texts = ["foo", "bar", "baz"] -class FakeEmbeddings(Embeddings): +class FakeEmbeddings(EmbeddingModel): """Fake embeddings functionality for testing.""" def embed_documents(self, texts: List[str]) -> List[List[float]]: diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 8362951c6c8..77818dbc56d 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -4,7 +4,7 @@ from typing import Callable, Optional import pytest from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings @@ -139,7 +139,7 @@ def test_qdrant_max_marginal_relevance_search( ], ) def 
test_qdrant_embedding_interface( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] + embeddings: Optional[EmbeddingModel], embedding_function: Optional[Callable] ) -> None: from qdrant_client import QdrantClient @@ -162,7 +162,7 @@ def test_qdrant_embedding_interface( ], ) def test_qdrant_embedding_interface_raises( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] + embeddings: Optional[EmbeddingModel], embedding_function: Optional[Callable] ) -> None: from qdrant_client import QdrantClient diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index dd2ade83c18..da5c96ff46e 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -9,12 +9,12 @@ from langchain.callbacks.manager import ( ) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.prompts import PROMPT_MAP -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.llms.base import BaseLLM from langchain.schema import Generation, LLMResult -class FakeEmbeddings(Embeddings): +class FakeEmbeddings(EmbeddingModel): """Fake embedding class for tests.""" def embed_documents(self, texts: List[str]) -> List[List[float]]: diff --git a/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/tests/unit_tests/retrievers/test_time_weighted_retriever.py index d021ed93cb4..833cc50d9cc 100644 --- a/tests/unit_tests/retrievers/test_time_weighted_retriever.py +++ b/tests/unit_tests/retrievers/test_time_weighted_retriever.py @@ -5,7 +5,7 @@ from typing import Any, Iterable, List, Optional, Tuple, Type import pytest -from langchain.embeddings.base import Embeddings +from langchain.embeddings.base import EmbeddingModel from langchain.retrievers.time_weighted_retriever import ( TimeWeightedVectorStoreRetriever, _get_hours_passed, @@ -67,7 +67,7 @@ class MockVectorStore(VectorStore): def 
from_documents( cls: Type["MockVectorStore"], documents: List[Document], - embedding: Embeddings, + embedding: EmbeddingModel, **kwargs: Any, ) -> "MockVectorStore": """Return VectorStore initialized from documents and embeddings.""" @@ -79,7 +79,7 @@ class MockVectorStore(VectorStore): def from_texts( cls: Type["MockVectorStore"], texts: List[str], - embedding: Embeddings, + embedding: EmbeddingModel, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> "MockVectorStore":