mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
add Epsilla vectorstore (#9239)
[Epsilla](https://github.com/epsilla-cloud/vectordb) vectordb is an open-source vector database that leverages the advanced academic parallel graph traversal techniques for vector indexing. This PR adds basic integration with [pyepsilla](https://github.com/epsilla-cloud/epsilla-python-client)(Epsilla vectordb python client) as a vectorstore. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -42,6 +42,7 @@ from langchain.vectorstores.elastic_vector_search import (
|
||||
ElasticVectorSearch,
|
||||
)
|
||||
from langchain.vectorstores.elasticsearch import ElasticsearchStore
|
||||
from langchain.vectorstores.epsilla import Epsilla
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.vectorstores.hologres import Hologres
|
||||
from langchain.vectorstores.lancedb import LanceDB
|
||||
@@ -93,6 +94,7 @@ __all__ = [
|
||||
"ElasticVectorSearch",
|
||||
"ElasticKnnSearch",
|
||||
"ElasticsearchStore",
|
||||
"Epsilla",
|
||||
"FAISS",
|
||||
"PGEmbedding",
|
||||
"Hologres",
|
||||
|
375
libs/langchain/langchain/vectorstores/epsilla.py
Normal file
375
libs/langchain/langchain/vectorstores/epsilla.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""Wrapper around Epsilla vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyepsilla import vectordb
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class Epsilla(VectorStore):
|
||||
"""
|
||||
Wrapper around Epsilla vector database.
|
||||
|
||||
As a prerequisite, you need to install ``pyepsilla`` package
|
||||
and have a running Epsilla vector database (for example, through our docker image)
|
||||
See the following documentation for how to run an Epsilla vector database:
|
||||
https://epsilla-inc.gitbook.io/epsilladb/quick-start
|
||||
|
||||
Args:
|
||||
client (Any): Epsilla client to connect to.
|
||||
embeddings (Embeddings): Function used to embed the texts.
|
||||
db_path (Optional[str]): The path where the database will be persisted.
|
||||
Defaults to "/tmp/langchain-epsilla".
|
||||
db_name (Optional[str]): Give a name to the loaded database.
|
||||
Defaults to "langchain_store".
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Epsilla
|
||||
from pyepsilla import vectordb
|
||||
|
||||
client = vectordb.Client()
|
||||
embeddings = OpenAIEmbeddings()
|
||||
db_path = "/tmp/vectorstore"
|
||||
db_name = "langchain_store"
|
||||
epsilla = Epsilla(client, embeddings, db_path, db_name)
|
||||
"""
|
||||
|
||||
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
|
||||
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
|
||||
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
embeddings: Embeddings,
|
||||
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
|
||||
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
try:
|
||||
import pyepsilla
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import pyepsilla python package. "
|
||||
"Please install pyepsilla package with `pip install pyepsilla`."
|
||||
) from e
|
||||
|
||||
if not isinstance(client, pyepsilla.vectordb.Client):
|
||||
raise TypeError(
|
||||
f"client should be an instance of pyepsilla.vectordb.Client, "
|
||||
f"got {type(client)}"
|
||||
)
|
||||
|
||||
self._client: vectordb.Client = client
|
||||
self._db_name = db_name
|
||||
self._embeddings = embeddings
|
||||
self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
|
||||
self._client.load_db(db_name=db_name, db_path=db_path)
|
||||
self._client.use_db(db_name=db_name)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embeddings
|
||||
|
||||
def use_collection(self, collection_name: str) -> None:
|
||||
"""
|
||||
Set default collection to use.
|
||||
|
||||
Args:
|
||||
collection_name (str): The name of the collection.
|
||||
"""
|
||||
self._collection_name = collection_name
|
||||
|
||||
def clear_data(self, collection_name: str = "") -> None:
|
||||
"""
|
||||
Clear data in a collection.
|
||||
|
||||
Args:
|
||||
collection_name (Optional[str]): The name of the collection.
|
||||
If not provided, the default collection will be used.
|
||||
"""
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
self._client.drop_table(collection_name)
|
||||
|
||||
def get(
|
||||
self, collection_name: str = "", response_fields: Optional[List[str]] = None
|
||||
) -> List[dict]:
|
||||
"""Get the collection.
|
||||
|
||||
Args:
|
||||
collection_name (Optional[str]): The name of the collection
|
||||
to retrieve data from.
|
||||
If not provided, the default collection will be used.
|
||||
response_fields (Optional[List[str]]): List of field names in the result.
|
||||
If not specified, all available fields will be responded.
|
||||
|
||||
Returns:
|
||||
A list of the retrieved data.
|
||||
"""
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
status_code, response = self._client.get(
|
||||
table_name=collection_name, response_fields=response_fields
|
||||
)
|
||||
if status_code != 200:
|
||||
logger.error(f"Failed to get records: {response['message']}")
|
||||
raise Exception("Error: {}.".format(response["message"]))
|
||||
return response["result"]
|
||||
|
||||
def _create_collection(
|
||||
self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
if not embeddings:
|
||||
raise ValueError("Embeddings list is empty.")
|
||||
|
||||
dim = len(embeddings[0])
|
||||
fields: List[dict] = [
|
||||
{"name": "id", "dataType": "INT"},
|
||||
{"name": "text", "dataType": "STRING"},
|
||||
{"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
|
||||
]
|
||||
if metadatas is not None:
|
||||
field_names = [field["name"] for field in fields]
|
||||
for metadata in metadatas:
|
||||
for key, value in metadata.items():
|
||||
if key in field_names:
|
||||
continue
|
||||
d_type: str
|
||||
if isinstance(value, str):
|
||||
d_type = "STRING"
|
||||
elif isinstance(value, int):
|
||||
d_type = "INT"
|
||||
elif isinstance(value, float):
|
||||
d_type = "FLOAT"
|
||||
elif isinstance(value, bool):
|
||||
d_type = "BOOL"
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type for {key}.")
|
||||
fields.append({"name": key, "dataType": d_type})
|
||||
field_names.append(key)
|
||||
|
||||
status_code, response = self._client.create_table(
|
||||
table_name, table_fields=fields
|
||||
)
|
||||
if status_code != 200:
|
||||
if status_code == 409:
|
||||
logger.info(f"Continuing with the existing table {table_name}.")
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create collection {table_name}: {response['message']}"
|
||||
)
|
||||
raise Exception("Error: {}.".format(response["message"]))
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: Optional[str] = "",
|
||||
drop_old: Optional[bool] = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Embed texts and add them to the database.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): The texts to embed.
|
||||
metadatas (Optional[List[dict]]): Metadata dicts
|
||||
attached to each of the texts. Defaults to None.
|
||||
collection_name (Optional[str]): Which collection to use.
|
||||
Defaults to "langchain_collection".
|
||||
If provided, default collection name will be set as well.
|
||||
drop_old (Optional[bool]): Whether to drop the previous collection
|
||||
and create a new one. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List of ids of the added texts.
|
||||
"""
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
else:
|
||||
self._collection_name = collection_name
|
||||
|
||||
if drop_old:
|
||||
self._client.drop_db(db_name=collection_name)
|
||||
|
||||
texts = list(texts)
|
||||
try:
|
||||
embeddings = self._embeddings.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self._embeddings.embed_query(x) for x in texts]
|
||||
|
||||
if len(embeddings) == 0:
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
self._create_collection(
|
||||
table_name=collection_name, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
|
||||
ids = [hash(uuid.uuid4()) for _ in texts]
|
||||
records = []
|
||||
for index, id in enumerate(ids):
|
||||
record = {
|
||||
"id": id,
|
||||
"text": texts[index],
|
||||
"embeddings": embeddings[index],
|
||||
}
|
||||
if metadatas is not None:
|
||||
metadata = metadatas[index].items()
|
||||
for key, value in metadata:
|
||||
record[key] = value
|
||||
records.append(record)
|
||||
|
||||
status_code, response = self._client.insert(
|
||||
table_name=collection_name, records=records
|
||||
)
|
||||
if status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to add records to {collection_name}: {response['message']}"
|
||||
)
|
||||
raise Exception("Error: {}.".format(response["message"]))
|
||||
return [str(id) for id in ids]
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Return the documents that are semantically most relevant to the query.
|
||||
|
||||
Args:
|
||||
query (str): String to query the vectorstore with.
|
||||
k (Optional[int]): Number of documents to return. Defaults to 4.
|
||||
collection_name (Optional[str]): Collection to use.
|
||||
Defaults to "langchain_store" or the one provided before.
|
||||
Returns:
|
||||
List of documents that are semantically most relevant to the query
|
||||
"""
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
query_vector = self._embeddings.embed_query(query)
|
||||
status_code, response = self._client.query(
|
||||
table_name=collection_name,
|
||||
query_field="embeddings",
|
||||
query_vector=query_vector,
|
||||
limit=k,
|
||||
)
|
||||
if status_code != 200:
|
||||
logger.error(f"Search failed: {response['message']}.")
|
||||
raise Exception("Error: {}.".format(response["message"]))
|
||||
|
||||
exclude_keys = ["id", "text", "embeddings"]
|
||||
return list(
|
||||
map(
|
||||
lambda item: Document(
|
||||
page_content=item["text"],
|
||||
metadata={
|
||||
key: item[key] for key in item if key not in exclude_keys
|
||||
},
|
||||
),
|
||||
response["result"],
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Epsilla],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
client: Any = None,
|
||||
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
|
||||
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
|
||||
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
|
||||
drop_old: Optional[bool] = False,
|
||||
**kwargs: Any,
|
||||
) -> Epsilla:
|
||||
"""Create an Epsilla vectorstore from raw documents.
|
||||
|
||||
Args:
|
||||
texts (List[str]): List of text data to be inserted.
|
||||
embeddings (Embeddings): Embedding function.
|
||||
client (pyepsilla.vectordb.Client): Epsilla client to connect to.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text.
|
||||
Defaults to None.
|
||||
db_path (Optional[str]): The path where the database will be persisted.
|
||||
Defaults to "/tmp/langchain-epsilla".
|
||||
db_name (Optional[str]): Give a name to the loaded database.
|
||||
Defaults to "langchain_store".
|
||||
collection_name (Optional[str]): Which collection to use.
|
||||
Defaults to "langchain_collection".
|
||||
If provided, default collection name will be set as well.
|
||||
drop_old (Optional[bool]): Whether to drop the previous collection
|
||||
and create a new one. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Epsilla: Epsilla vector store.
|
||||
"""
|
||||
instance = Epsilla(client, embedding, db_path=db_path, db_name=db_name)
|
||||
instance.add_texts(
|
||||
texts,
|
||||
metadatas=metadatas,
|
||||
collection_name=collection_name,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls: Type[Epsilla],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
client: Any = None,
|
||||
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
|
||||
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
|
||||
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
|
||||
drop_old: Optional[bool] = False,
|
||||
**kwargs: Any,
|
||||
) -> Epsilla:
|
||||
"""Create an Epsilla vectorstore from a list of documents.
|
||||
|
||||
Args:
|
||||
texts (List[str]): List of text data to be inserted.
|
||||
embeddings (Embeddings): Embedding function.
|
||||
client (pyepsilla.vectordb.Client): Epsilla client to connect to.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text.
|
||||
Defaults to None.
|
||||
db_path (Optional[str]): The path where the database will be persisted.
|
||||
Defaults to "/tmp/langchain-epsilla".
|
||||
db_name (Optional[str]): Give a name to the loaded database.
|
||||
Defaults to "langchain_store".
|
||||
collection_name (Optional[str]): Which collection to use.
|
||||
Defaults to "langchain_collection".
|
||||
If provided, default collection name will be set as well.
|
||||
drop_old (Optional[bool]): Whether to drop the previous collection
|
||||
and create a new one. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Epsilla: Epsilla vector store.
|
||||
"""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
|
||||
return cls.from_texts(
|
||||
texts,
|
||||
embedding,
|
||||
metadatas=metadatas,
|
||||
client=client,
|
||||
db_path=db_path,
|
||||
db_name=db_name,
|
||||
collection_name=collection_name,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
@@ -0,0 +1,31 @@
|
||||
"""Test Epsilla functionality."""
|
||||
from pyepsilla import vectordb
|
||||
|
||||
from langchain.vectorstores import Epsilla
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
FakeEmbeddings,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
|
||||
def _test_from_texts() -> Epsilla:
|
||||
embeddings = FakeEmbeddings()
|
||||
client = vectordb.Client()
|
||||
return Epsilla.from_texts(fake_texts, embeddings, client)
|
||||
|
||||
|
||||
def test_epsilla() -> None:
|
||||
instance = _test_from_texts()
|
||||
search = instance.similarity_search(query="bar", k=1)
|
||||
result_texts = [doc.page_content for doc in search]
|
||||
assert "bar" in result_texts
|
||||
|
||||
|
||||
def test_epsilla_add_texts() -> None:
|
||||
embeddings = FakeEmbeddings()
|
||||
client = vectordb.Client()
|
||||
db = Epsilla(client, embeddings)
|
||||
db.add_texts(fake_texts)
|
||||
search = db.similarity_search(query="foo", k=1)
|
||||
result_texts = [doc.page_content for doc in search]
|
||||
assert "foo" in result_texts
|
Reference in New Issue
Block a user