langchain/libs/community/langchain_community/vectorstores/epsilla.py
Bagatur a0c2281540 infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

ROOT_DIR = Path(__file__).parents[1]


def main():
    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)
        cwd = "/".join(path.split("/")[:-1])
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )
        logs = completed.stdout.split("\n")

        to_ignore = {}
        for line in logs:
            match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", line)
            if match:
                path, line_no, error_type = match.groups()
                if (path, line_no) in to_ignore:
                    to_ignore[(path, line_no)].append(error_type)
                else:
                    to_ignore[(path, line_no)] = [error_type]
        print(len(to_ignore))
        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00
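
For illustration, here is a small sketch of what the script's mypy step does with a single diagnostic line. The sample line, file path, and error code are invented; only the regex mirrors the script above.

```python
import re

# Invented example of a mypy diagnostic; real output follows the same
# "path:line: error: message  [error-code]" shape.
sample = 'vectorstores/epsilla.py:42: error: Item "None" has no attribute "embed_query"  [union-attr]'

match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", sample)
if match:
    path, line_no, error_code = match.groups()
    print(path, line_no, error_code)
    # The script would then append "  # type: ignore[union-attr]" to line 42 of that file.
```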

```python
"""Wrapper around Epsilla vector database."""

from __future__ import annotations

import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
    from pyepsilla import vectordb

logger = logging.getLogger()


class Epsilla(VectorStore):
    """
    Wrapper around Epsilla vector database.

    As a prerequisite, you need to install the ``pyepsilla`` package
    and have a running Epsilla vector database (for example, through our docker image).
    See the following documentation for how to run an Epsilla vector database:
    https://epsilla-inc.gitbook.io/epsilladb/quick-start

    Args:
        client (Any): Epsilla client to connect to.
        embeddings (Embeddings): Function used to embed the texts.
        db_path (Optional[str]): The path where the database will be persisted.
            Defaults to "/tmp/langchain-epsilla".
        db_name (Optional[str]): Give a name to the loaded database.
            Defaults to "langchain_store".

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Epsilla
            from pyepsilla import vectordb

            client = vectordb.Client()
            embeddings = OpenAIEmbeddings()
            db_path = "/tmp/vectorstore"
            db_name = "langchain_store"
            epsilla = Epsilla(client, embeddings, db_path, db_name)
    """

    _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
    _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
    _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
    ):
        """Initialize with necessary components."""
        try:
            import pyepsilla
        except ImportError as e:
            raise ImportError(
                "Could not import pyepsilla python package. "
                "Please install pyepsilla package with `pip install pyepsilla`."
            ) from e

        if not isinstance(client, pyepsilla.vectordb.Client):
            raise TypeError(
                f"client should be an instance of pyepsilla.vectordb.Client, "
                f"got {type(client)}"
            )

        self._client: vectordb.Client = client
        self._db_name = db_name
        self._embeddings = embeddings
        self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
        self._client.load_db(db_name=db_name, db_path=db_path)
        self._client.use_db(db_name=db_name)

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embeddings
    def use_collection(self, collection_name: str) -> None:
        """
        Set default collection to use.

        Args:
            collection_name (str): The name of the collection.
        """
        self._collection_name = collection_name

    def clear_data(self, collection_name: str = "") -> None:
        """
        Clear data in a collection.

        Args:
            collection_name (Optional[str]): The name of the collection.
                If not provided, the default collection will be used.
        """
        if not collection_name:
            collection_name = self._collection_name
        self._client.drop_table(collection_name)
    def get(
        self, collection_name: str = "", response_fields: Optional[List[str]] = None
    ) -> List[dict]:
        """Get the collection.

        Args:
            collection_name (Optional[str]): The name of the collection
                to retrieve data from.
                If not provided, the default collection will be used.
            response_fields (Optional[List[str]]): List of field names in the result.
                If not specified, all available fields will be returned.

        Returns:
            A list of the retrieved data.
        """
        if not collection_name:
            collection_name = self._collection_name
        status_code, response = self._client.get(
            table_name=collection_name, response_fields=response_fields
        )
        if status_code != 200:
            logger.error(f"Failed to get records: {response['message']}")
            raise Exception("Error: {}.".format(response["message"]))
        return response["result"]
    def _create_collection(
        self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        if not embeddings:
            raise ValueError("Embeddings list is empty.")

        dim = len(embeddings[0])
        fields: List[dict] = [
            {"name": "id", "dataType": "INT"},
            {"name": "text", "dataType": "STRING"},
            {"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
        ]
        if metadatas is not None:
            field_names = [field["name"] for field in fields]
            for metadata in metadatas:
                for key, value in metadata.items():
                    if key in field_names:
                        continue
                    d_type: str
                    if isinstance(value, str):
                        d_type = "STRING"
                    elif isinstance(value, int):
                        d_type = "INT"
                    elif isinstance(value, float):
                        d_type = "FLOAT"
                    elif isinstance(value, bool):
                        d_type = "BOOL"
                    else:
                        raise ValueError(f"Unsupported data type for {key}.")
                    fields.append({"name": key, "dataType": d_type})
                    field_names.append(key)

        status_code, response = self._client.create_table(
            table_name, table_fields=fields
        )
        if status_code != 200:
            if status_code == 409:
                logger.info(f"Continuing with the existing table {table_name}.")
            else:
                logger.error(
                    f"Failed to create collection {table_name}: {response['message']}"
                )
                raise Exception("Error: {}.".format(response["message"]))
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        collection_name: Optional[str] = "",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> List[str]:
        """
        Embed texts and add them to the database.

        Args:
            texts (Iterable[str]): The texts to embed.
            metadatas (Optional[List[dict]]): Metadata dicts
                attached to each of the texts. Defaults to None.
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            List of ids of the added texts.
        """
        if not collection_name:
            collection_name = self._collection_name
        else:
            self._collection_name = collection_name

        if drop_old:
            self._client.drop_db(db_name=collection_name)

        texts = list(texts)
        try:
            embeddings = self._embeddings.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self._embeddings.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        self._create_collection(
            table_name=collection_name, embeddings=embeddings, metadatas=metadatas
        )

        ids = [hash(uuid.uuid4()) for _ in texts]
        records = []
        for index, id in enumerate(ids):
            record = {
                "id": id,
                "text": texts[index],
                "embeddings": embeddings[index],
            }
            if metadatas is not None:
                metadata = metadatas[index].items()
                for key, value in metadata:
                    record[key] = value
            records.append(record)

        status_code, response = self._client.insert(
            table_name=collection_name, records=records
        )
        if status_code != 200:
            logger.error(
                f"Failed to add records to {collection_name}: {response['message']}"
            )
            raise Exception("Error: {}.".format(response["message"]))
        return [str(id) for id in ids]
    def similarity_search(
        self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
    ) -> List[Document]:
        """
        Return the documents that are semantically most relevant to the query.

        Args:
            query (str): String to query the vectorstore with.
            k (Optional[int]): Number of documents to return. Defaults to 4.
            collection_name (Optional[str]): Collection to use.
                Defaults to "langchain_store" or the one provided before.

        Returns:
            List of documents that are semantically most relevant to the query.
        """
        if not collection_name:
            collection_name = self._collection_name
        query_vector = self._embeddings.embed_query(query)
        status_code, response = self._client.query(
            table_name=collection_name,
            query_field="embeddings",
            query_vector=query_vector,
            limit=k,
        )
        if status_code != 200:
            logger.error(f"Search failed: {response['message']}.")
            raise Exception("Error: {}.".format(response["message"]))

        exclude_keys = ["id", "text", "embeddings"]
        return list(
            map(
                lambda item: Document(
                    page_content=item["text"],
                    metadata={
                        key: item[key] for key in item if key not in exclude_keys
                    },
                ),
                response["result"],
            )
        )
    @classmethod
    def from_texts(
        cls: Type[Epsilla],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from raw documents.

        Args:
            texts (List[str]): List of text data to be inserted.
            embedding (Embeddings): Embedding function.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            metadatas (Optional[List[dict]]): Metadata for each text.
                Defaults to None.
            db_path (Optional[str]): The path where the database will be persisted.
                Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        instance = Epsilla(client, embedding, db_path=db_path, db_name=db_name)
        instance.add_texts(
            texts,
            metadatas=metadatas,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )
        return instance
    @classmethod
    def from_documents(
        cls: Type[Epsilla],
        documents: List[Document],
        embedding: Embeddings,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from a list of documents.

        Args:
            documents (List[Document]): Documents whose page_content and
                metadata will be inserted.
            embedding (Embeddings): Embedding function.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            db_path (Optional[str]): The path where the database will be persisted.
                Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts,
            embedding,
            metadatas=metadatas,
            client=client,
            db_path=db_path,
            db_name=db_name,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )
```
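
For context, a minimal usage sketch of this wrapper follows. It is not part of the module above: it assumes a locally running Epsilla server reachable with pyepsilla's default `Client()` settings, and it uses `OpenAIEmbeddings` from the separate `langchain_openai` package only because the class docstring does; any `Embeddings` implementation and any texts/metadata would work.

```python
# Minimal usage sketch (assumes a running Epsilla server and valid OpenAI credentials).
from langchain_community.vectorstores import Epsilla
from langchain_openai import OpenAIEmbeddings
from pyepsilla import vectordb

client = vectordb.Client()  # connect to the local Epsilla instance

# Create the collection and insert the embedded texts in one call.
store = Epsilla.from_texts(
    texts=[
        "Epsilla is a vector database.",
        "LangChain integrates many vector stores.",
    ],
    embedding=OpenAIEmbeddings(),
    metadatas=[{"source": "example"}, {"source": "example"}],
    client=client,
    db_path="/tmp/vectorstore",
    db_name="langchain_store",
    collection_name="langchain_collection",
)

# Query the collection that from_texts/add_texts just set as the default.
for doc in store.similarity_search("What is Epsilla?", k=2):
    print(doc.page_content, doc.metadata)
```

Note that `add_texts` records the collection name it was given as the new default, so subsequent `similarity_search` calls hit the same collection unless `collection_name` is passed explicitly.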