mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Add embedding and vectorstore provider info as tags (#8027)
Example: https://smith.langchain.com/public/bcd3714d-abba-4790-81c8-9b5718535867/r The vectorstore implementations aren't super standardized yet, so just adding an optional embeddings property to pass in.
This commit is contained in:
parent
355b7d8b86
commit
c38965fcba
@ -326,6 +326,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -346,6 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
tags=tags,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.retriever,
|
||||
)
|
||||
|
@ -79,6 +79,10 @@ class AnalyticDB(VectorStore):
|
||||
self.engine = create_engine(self.connection_string, **_engine_args)
|
||||
self.create_collection()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._euclidean_relevance_score_fn
|
||||
|
||||
|
@ -61,6 +61,11 @@ class Annoy(VectorStore):
|
||||
self.docstore = docstore
|
||||
self.index_to_docstore_id = index_to_docstore_id
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Accept embedding object directly
|
||||
return None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -46,7 +46,7 @@ class AtlasDB(VectorStore):
|
||||
Args:
|
||||
name (str): The name of your project. If the project already exists,
|
||||
it will be loaded.
|
||||
embedding_function (Optional[Callable]): An optional function used for
|
||||
embedding_function (Optional[Embeddings]): An optional function used for
|
||||
embedding your data. If None, data will be embedded with
|
||||
Nomic's embed model.
|
||||
api_key (str): Your nomic API key
|
||||
@ -86,6 +86,10 @@ class AtlasDB(VectorStore):
|
||||
)
|
||||
self.project._latest_project_state()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding_function
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -73,6 +73,12 @@ class AwaDB(VectorStore):
|
||||
self.table2embeddings[table_name] = embedding
|
||||
self.using_table_name = table_name
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
if self.using_table_name in self.table2embeddings:
|
||||
return self.table2embeddings[self.using_table_name]
|
||||
return None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -191,6 +191,11 @@ class AzureSearch(VectorStore):
|
||||
self.semantic_configuration_name = semantic_configuration_name
|
||||
self.semantic_query_language = semantic_query_language
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Support embedding object directly
|
||||
return None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
@ -31,6 +32,8 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VST = TypeVar("VST", bound="VectorStore")
|
||||
|
||||
|
||||
@ -55,6 +58,14 @@ class VectorStore(ABC):
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
"""Access the query embedding object if available."""
|
||||
logger.debug(
|
||||
f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete by vector ID or other criteria.
|
||||
|
||||
@ -435,8 +446,17 @@ class VectorStore(ABC):
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __get_retriever_tags(self) -> List[str]:
|
||||
"""Get tags for retriever."""
|
||||
tags = [self.__class__.__name__]
|
||||
if self.embeddings:
|
||||
tags.append(self.embeddings.__class__.__name__)
|
||||
return tags
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
|
||||
return VectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
tags = kwargs.pop("tags", None) or []
|
||||
tags.extend(self.__get_retriever_tags())
|
||||
return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class VectorStoreRetriever(BaseRetriever):
|
||||
|
@ -77,6 +77,10 @@ class Cassandra(VectorStore):
|
||||
primary_key_type="TEXT",
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._cosine_relevance_score_fn
|
||||
|
||||
|
@ -121,6 +121,10 @@ class Chroma(VectorStore):
|
||||
)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding_function
|
||||
|
||||
@xor_args(("query_texts", "query_embeddings"))
|
||||
def __query_collection(
|
||||
self,
|
||||
|
@ -212,6 +212,10 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
||||
self.client.command("SET allow_experimental_annoy_index=1")
|
||||
self.client.command(self.schema)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def escape_str(self, value: str) -> str:
|
||||
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
|
||||
|
||||
|
@ -151,6 +151,10 @@ class DeepLake(VectorStore):
|
||||
self._embedding_function = embedding_function
|
||||
self._id_tensor_name = "ids" if "ids" in self.vectorstore.tensors() else "id"
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding_function
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -153,6 +153,10 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embeddings
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -87,6 +87,11 @@ class FAISS(VectorStore):
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Accept embeddings object directly
|
||||
return None
|
||||
|
||||
def __add(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -148,6 +148,10 @@ class Hologres(VectorStore):
|
||||
self.create_vector_extension()
|
||||
self.create_table()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def create_vector_extension(self) -> None:
|
||||
try:
|
||||
self.storage.create_vector_extension()
|
||||
|
@ -51,6 +51,10 @@ class LanceDB(VectorStore):
|
||||
self._id_key = id_key
|
||||
self._text_key = text_key
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embedding
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -82,6 +82,10 @@ class Marqo(VectorStore):
|
||||
|
||||
self._document_batch_size = 1024
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -84,6 +84,10 @@ class MatchingEngine(VectorStore):
|
||||
self.credentials = credentials
|
||||
self.gcs_bucket_name = gcs_bucket_name
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
def _validate_google_libraries_installation(self) -> None:
|
||||
"""Validates that Google libraries that are needed are installed."""
|
||||
try:
|
||||
|
@ -164,6 +164,10 @@ class Milvus(VectorStore):
|
||||
# Initialize the vector store
|
||||
self._init()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_func
|
||||
|
||||
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||
"""Create the connection to the Milvus server."""
|
||||
from pymilvus import MilvusException, connections
|
||||
|
@ -77,6 +77,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
self._text_key = text_key
|
||||
self._embedding_key = embedding_key
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embedding
|
||||
|
||||
@classmethod
|
||||
def from_connection_string(
|
||||
cls,
|
||||
|
@ -115,7 +115,7 @@ class MyScale(VectorStore):
|
||||
) -> None:
|
||||
"""MyScale Wrapper to LangChain
|
||||
|
||||
embedding_function (Embeddings):
|
||||
embedding (Embeddings):
|
||||
config (MyScaleSettings): Configuration to MyScale Client
|
||||
Other keyword arguments will pass into
|
||||
[clickhouse-connect](https://docs.myscale.com/)
|
||||
@ -175,7 +175,7 @@ class MyScale(VectorStore):
|
||||
self.dim = dim
|
||||
self.BS = "\\"
|
||||
self.must_escape = ("\\", "'")
|
||||
self.embedding_function = embedding.embed_query
|
||||
self._embeddings = embedding
|
||||
self.dist_order = "ASC" if self.config.metric in ["cosine", "l2"] else "DESC"
|
||||
|
||||
# Create a connection to myscale
|
||||
@ -189,6 +189,10 @@ class MyScale(VectorStore):
|
||||
self.client.command("SET allow_experimental_object_type=1")
|
||||
self.client.command(schema_)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embeddings
|
||||
|
||||
def escape_str(self, value: str) -> str:
|
||||
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
|
||||
|
||||
@ -238,7 +242,7 @@ class MyScale(VectorStore):
|
||||
column_names = {
|
||||
colmap_["id"]: ids,
|
||||
colmap_["text"]: texts,
|
||||
colmap_["vector"]: map(self.embedding_function, texts),
|
||||
colmap_["vector"]: map(self._embeddings.embed_query, texts),
|
||||
}
|
||||
metadatas = metadatas or [{} for _ in texts]
|
||||
column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
|
||||
@ -269,7 +273,7 @@ class MyScale(VectorStore):
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
texts: Iterable[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
||||
config: Optional[MyScaleSettings] = None,
|
||||
@ -280,8 +284,8 @@ class MyScale(VectorStore):
|
||||
"""Create Myscale wrapper with existing texts
|
||||
|
||||
Args:
|
||||
embedding_function (Embeddings): Function to extract text embedding
|
||||
texts (Iterable[str]): List or tuple of strings to be added
|
||||
embedding (Embeddings): Function to extract text embedding
|
||||
config (MyScaleSettings, Optional): Myscale configuration
|
||||
text_ids (Optional[Iterable], optional): IDs for the texts.
|
||||
Defaults to None.
|
||||
@ -357,7 +361,7 @@ class MyScale(VectorStore):
|
||||
List[Document]: List of Documents
|
||||
"""
|
||||
return self.similarity_search_by_vector(
|
||||
self.embedding_function(query), k, where_str, **kwargs
|
||||
self._embeddings.embed_query(query), k, where_str, **kwargs
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
@ -417,7 +421,7 @@ class MyScale(VectorStore):
|
||||
and cosine distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
q_str = self._build_qstr(self.embedding_function(query), k, where_str)
|
||||
q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
|
||||
try:
|
||||
return [
|
||||
(
|
||||
|
@ -316,6 +316,10 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
self.index_name = index_name
|
||||
self.client = _get_opensearch_client(opensearch_url, **kwargs)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -135,6 +135,10 @@ class PGEmbedding(VectorStore):
|
||||
self.create_tables_if_not_exists()
|
||||
self.create_collection()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string)
|
||||
conn = engine.connect()
|
||||
|
@ -125,6 +125,10 @@ class PGVector(VectorStore):
|
||||
self.create_tables_if_not_exists()
|
||||
self.create_collection()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string)
|
||||
conn = engine.connect()
|
||||
|
@ -62,6 +62,11 @@ class Pinecone(VectorStore):
|
||||
self._namespace = namespace
|
||||
self.distance_strategy = distance_strategy
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Accept this object directly
|
||||
return None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -123,7 +123,7 @@ class Qdrant(VectorStore):
|
||||
"Use `embeddings` only."
|
||||
)
|
||||
|
||||
self.embeddings = embeddings
|
||||
self._embeddings = embeddings
|
||||
self._embeddings_function = embedding_function
|
||||
self.client: qdrant_client.QdrantClient = client
|
||||
self.collection_name = collection_name
|
||||
@ -143,10 +143,14 @@ class Qdrant(VectorStore):
|
||||
"Using `embeddings` as `embedding_function` which is deprecated"
|
||||
)
|
||||
self._embeddings_function = embeddings
|
||||
self.embeddings = None
|
||||
self._embeddings = None
|
||||
|
||||
self.distance_strategy = distance_strategy.upper()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embeddings
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -161,6 +161,11 @@ class Redis(VectorStore):
|
||||
self.distance_metric = distance_metric
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Accept embedding object directly
|
||||
return None
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
if self.relevance_score_fn:
|
||||
return self.relevance_score_fn
|
||||
@ -601,7 +606,9 @@ class Redis(VectorStore):
|
||||
)
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
|
||||
return RedisVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
tags = kwargs.pop("tags", None) or []
|
||||
tags.extend(self.__get_retriever_tags())
|
||||
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class RedisVectorStoreRetriever(VectorStoreRetriever):
|
||||
|
@ -83,6 +83,10 @@ class Rockset(VectorStore):
|
||||
self._text_key = text_key
|
||||
self._embedding_key = embedding_key
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embeddings
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -213,6 +213,10 @@ class SingleStoreDB(VectorStore):
|
||||
)
|
||||
self._create_table()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
|
||||
@ -441,7 +445,9 @@ class SingleStoreDB(VectorStore):
|
||||
return instance
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> SingleStoreDBRetriever:
|
||||
return SingleStoreDBRetriever(vectorstore=self, **kwargs)
|
||||
tags = kwargs.pop("tags", None) or []
|
||||
tags.extend(self.__get_retriever_tags())
|
||||
return SingleStoreDBRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class SingleStoreDBRetriever(VectorStoreRetriever):
|
||||
|
@ -163,6 +163,10 @@ class SKLearnVectorStore(VectorStore):
|
||||
if self._persist_path is not None and os.path.isfile(self._persist_path):
|
||||
self._load()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embedding_function
|
||||
|
||||
def persist(self) -> None:
|
||||
if self._serializer is None:
|
||||
raise SKLearnVectorStoreException(
|
||||
|
@ -207,6 +207,10 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
||||
def escape_str(self, value: str) -> str:
|
||||
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
|
||||
ks = ",".join(column_names)
|
||||
embed_tuple_index = tuple(column_names).index(
|
||||
|
@ -67,6 +67,10 @@ class SupabaseVectorStore(VectorStore):
|
||||
self.table_name = table_name or "documents"
|
||||
self.query_name = query_name or "match_documents"
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embedding
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -51,6 +51,10 @@ class Tair(VectorStore):
|
||||
self.metadata_key = metadata_key
|
||||
self.search_params = search_params
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def create_index_if_not_exist(
|
||||
self,
|
||||
dim: int,
|
||||
|
@ -28,6 +28,10 @@ class Tigris(VectorStore):
|
||||
self._embed_fn = embeddings
|
||||
self._vector_store = TigrisVectorStore(client.get_search(), index_name)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embed_fn
|
||||
|
||||
@property
|
||||
def search_index(self) -> TigrisVectorStore:
|
||||
return self._vector_store
|
||||
|
@ -81,6 +81,10 @@ class Typesense(VectorStore):
|
||||
def _collection(self) -> Collection:
|
||||
return self._typesense_client.collections[self._typesense_collection_name]
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self._embedding
|
||||
|
||||
def _prep_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
@ -61,6 +61,10 @@ class Vectara(VectorStore):
|
||||
adapter = requests.adapters.HTTPAdapter(max_retries=3)
|
||||
self._session.mount("http://", adapter)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return None
|
||||
|
||||
def _get_post_headers(self) -> dict:
|
||||
"""Returns headers that should be attached to each post request."""
|
||||
return {
|
||||
@ -402,7 +406,9 @@ class Vectara(VectorStore):
|
||||
return vectara
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> VectaraRetriever:
|
||||
return VectaraRetriever(vectorstore=self, **kwargs)
|
||||
tags = kwargs.pop("tags", None) or []
|
||||
tags.extend(self.__get_retriever_tags())
|
||||
return VectaraRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class VectaraRetriever(VectorStoreRetriever):
|
||||
|
@ -118,6 +118,10 @@ class Weaviate(VectorStore):
|
||||
if attributes is not None:
|
||||
self._query_attrs.extend(attributes)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return (
|
||||
self.relevance_score_fn
|
||||
|
Loading…
Reference in New Issue
Block a user