community[patch]: LanceDB integration patch update (#20686)

Description : 

- added functionalities - delete, index creation, using existing
connection object etc.
- updated usage 
- Added LaceDB cloud OSS support

make lint_diff , make test checks done
This commit is contained in:
Raghav Dixit 2024-04-24 19:27:43 -04:00 committed by GitHub
parent 9e983c9500
commit 9b7fb381a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 226 additions and 82 deletions

View File

@ -12,6 +12,16 @@
"This notebook shows how to use functionality related to the `LanceDB` vector database based on the Lance data format." "This notebook shows how to use functionality related to the `LanceDB` vector database based on the Lance data format."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "88ac92c0",
"metadata": {},
"outputs": [],
"source": [
"! pip install -U langchain-openai"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -32,7 +42,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 2,
"id": "a0361f5c-e6f4-45f4-b829-11680cf03cec", "id": "a0361f5c-e6f4-45f4-b829-11680cf03cec",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -47,25 +57,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 13,
"id": "aac9563e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.vectorstores import LanceDB"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a3c3999a", "id": "a3c3999a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.document_loaders import TextLoader\n", "from langchain.document_loaders import TextLoader\n",
"from langchain.vectorstores import LanceDB\n",
"from langchain_openai import OpenAIEmbeddings\n",
"from langchain_text_splitters import CharacterTextSplitter\n", "from langchain_text_splitters import CharacterTextSplitter\n",
"\n", "\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n", "loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
@ -75,22 +74,61 @@
"embeddings = OpenAIEmbeddings()" "embeddings = OpenAIEmbeddings()"
] ]
}, },
{
"cell_type": "markdown",
"id": "e9517bb0",
"metadata": {},
"source": [
"##### For LanceDB cloud, you can invoke the vector store as follows :\n",
"\n",
"\n",
"```python\n",
"db_url = \"db://lang_test\" # url of db you created\n",
"api_key = \"xxxxx\" # your API key\n",
"region=\"us-east-1-dev\" # your selected region\n",
"\n",
"vector_store = LanceDB(\n",
" uri=db_url,\n",
" api_key=api_key,\n",
" region=region,\n",
" embedding=embeddings,\n",
" table_name='langchain_test'\n",
" )\n",
"```\n"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 7,
"id": "6e104aee", "id": "6e104aee",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"docsearch = LanceDB.from_documents(documents, embeddings)\n", "docsearch = LanceDB.from_documents(documents, embeddings)\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)" "docs = docsearch.similarity_search(query)"
] ]
}, },
{
"cell_type": "markdown",
"id": "f5e1cdfd",
"metadata": {},
"source": [
"Additionaly, to explore the table you can load it into a df or save it in a csv file: \n",
"```python\n",
"tbl = docsearch.get_table()\n",
"print(\"tbl:\", tbl)\n",
"pd_df = tbl.to_pandas()\n",
"# pd_df.to_csv(\"docsearch.csv\", index=False)\n",
"\n",
"# you can also create a new vector store object using an older connection object:\n",
"vector_store = LanceDB(connection=tbl, embedding=embeddings)\n",
"```"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 9,
"id": "9c608226", "id": "9c608226",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -166,7 +204,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 10,
"id": "a359ed74", "id": "a359ed74",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -267,7 +305,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.8" "version": "3.11.5"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -60,7 +60,7 @@
" * document addition by id (`add_documents` method with `ids` argument)\n", " * document addition by id (`add_documents` method with `ids` argument)\n",
" * delete by id (`delete` method with `ids` argument)\n", " * delete by id (`delete` method with `ids` argument)\n",
"\n", "\n",
"Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `OpenSearchVectorSearch`.\n", "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `LanceDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `OpenSearchVectorSearch`.\n",
" \n", " \n",
"## Caution\n", "## Caution\n",
"\n", "\n",

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import os
import uuid import uuid
import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from langchain_core.documents import Document from langchain_core.documents import Document
@ -8,6 +10,17 @@ from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
def import_lancedb() -> Any:
try:
import lancedb
except ImportError as e:
raise ImportError(
"Could not import pinecone lancedb package. "
"Please install it with `pip install lancedb`."
) from e
return lancedb
class LanceDB(VectorStore): class LanceDB(VectorStore):
"""`LanceDB` vector store. """`LanceDB` vector store.
@ -22,15 +35,15 @@ class LanceDB(VectorStore):
id_key: Key to use for the id in the database. Defaults to ``id``. id_key: Key to use for the id in the database. Defaults to ``id``.
text_key: Key to use for the text in the database. Defaults to ``text``. text_key: Key to use for the text in the database. Defaults to ``text``.
table_name: Name of the table to use. Defaults to ``vectorstore``. table_name: Name of the table to use. Defaults to ``vectorstore``.
api_key: API key to use for LanceDB cloud database.
region: Region to use for LanceDB cloud database.
mode: Mode to use for adding data to the table. Defaults to ``overwrite``.
Example: Example:
.. code-block:: python .. code-block:: python
vectorstore = LanceDB(uri='/lancedb', embedding_function)
db = lancedb.connect('./lancedb')
table = db.open_table('my_table')
vectorstore = LanceDB(table, embedding_function)
vectorstore.add_texts(['text1', 'text2']) vectorstore.add_texts(['text1', 'text2'])
result = vectorstore.similarity_search('text1') result = vectorstore.similarity_search('text1')
""" """
@ -39,38 +52,55 @@ class LanceDB(VectorStore):
self, self,
connection: Optional[Any] = None, connection: Optional[Any] = None,
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
uri: Optional[str] = "/tmp/lancedb",
vector_key: Optional[str] = "vector", vector_key: Optional[str] = "vector",
id_key: Optional[str] = "id", id_key: Optional[str] = "id",
text_key: Optional[str] = "text", text_key: Optional[str] = "text",
table_name: Optional[str] = "vectorstore", table_name: Optional[str] = "vectorstore",
api_key: Optional[str] = None,
region: Optional[str] = None,
mode: Optional[str] = "overwrite",
): ):
"""Initialize with Lance DB vectorstore""" """Initialize with Lance DB vectorstore"""
try: lancedb = import_lancedb()
import lancedb
except ImportError:
raise ImportError(
"Could not import lancedb python package. "
"Please install it with `pip install lancedb`."
)
self.lancedb = lancedb
self._embedding = embedding self._embedding = embedding
self._vector_key = vector_key self._vector_key = vector_key
self._id_key = id_key self._id_key = id_key
self._text_key = text_key self._text_key = text_key
self._table_name = table_name self._table_name = table_name
self.api_key = api_key or os.getenv("LANCE_API_KEY") if api_key != "" else None
self.region = region
self.mode = mode
if isinstance(uri, str) and self.api_key is None:
if uri.startswith("db://"):
raise ValueError("API key is required for LanceDB cloud.")
if self._embedding is None: if self._embedding is None:
raise ValueError("embedding should be provided") raise ValueError("embedding object should be provided")
if connection is not None: if isinstance(connection, lancedb.db.LanceDBConnection):
if not isinstance(connection, lancedb.db.LanceTable):
raise ValueError(
"connection should be an instance of lancedb.db.LanceTable, ",
f"got {type(connection)}",
)
self._connection = connection self._connection = connection
elif isinstance(connection, (str, lancedb.db.LanceTable)):
raise ValueError(
"`connection` has to be a lancedb.db.LanceDBConnection object.\
`lancedb.db.LanceTable` is deprecated."
)
else: else:
self._connection = self._init_table() if self.api_key is None:
self._connection = lancedb.connect(uri)
else:
if isinstance(uri, str):
if uri.startswith("db://"):
self._connection = lancedb.connect(
uri, api_key=self.api_key, region=self.region
)
else:
self._connection = lancedb.connect(uri)
warnings.warn(
"api key provided with local uri.\
The data will be stored locally"
)
@property @property
def embeddings(self) -> Optional[Embeddings]: def embeddings(self) -> Optional[Embeddings]:
@ -88,7 +118,7 @@ class LanceDB(VectorStore):
Args: Args:
texts: Iterable of strings to add to the vectorstore. texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate with the texts. ids: Optional list of ids to associate w ith the texts.
Returns: Returns:
List of ids of the added texts. List of ids of the added texts.
@ -99,20 +129,70 @@ class LanceDB(VectorStore):
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore embeddings = self._embedding.embed_documents(list(texts)) # type: ignore
for idx, text in enumerate(texts): for idx, text in enumerate(texts):
embedding = embeddings[idx] embedding = embeddings[idx]
metadata = metadatas[idx] if metadatas else {} metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
docs.append( docs.append(
{ {
self._vector_key: embedding, self._vector_key: embedding,
self._id_key: ids[idx], self._id_key: ids[idx],
self._text_key: text, self._text_key: text,
**metadata, "metadata": metadata,
} }
) )
self._connection.add(docs)
if self._table_name in self._connection.table_names():
tbl = self._connection.open_table(self._table_name)
if self.api_key is None:
tbl.add(docs, mode=self.mode)
else:
tbl.add(docs)
else:
self._connection.create_table(self._table_name, data=docs)
return ids return ids
def get_table(self, name: Optional[str] = None) -> Any:
if name is not None:
try:
self._connection.open_table(name)
except Exception:
raise ValueError(f"Table {name} not found in the database")
else:
return self._connection.open_table(self._table_name)
def create_index(
self,
col_name: Optional[str] = None,
vector_col: Optional[str] = None,
num_partitions: Optional[int] = 256,
num_sub_vectors: Optional[int] = 96,
index_cache_size: Optional[int] = None,
) -> None:
"""
Create a scalar(for non-vector cols) or a vector index on a table.
Make sure your vector column has enough data before creating an index on it.
Args:
vector_col: Provide if you want to create index on a vector column.
col_name: Provide if you want to create index on a non-vector column.
metric: Provide the metric to use for vector index. Defaults to 'L2'
choice of metrics: 'L2', 'dot', 'cosine'
Returns:
None
"""
if vector_col:
self._connection.create_index(
vector_column_name=vector_col,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
index_cache_size=index_cache_size,
)
elif col_name:
self._connection.create_scalar_index(col_name)
else:
raise ValueError("Provide either vector_col or col_name")
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, name: Optional[str] = None, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return documents most similar to the query """Return documents most similar to the query
@ -124,8 +204,9 @@ class LanceDB(VectorStore):
List of documents most similar to the query. List of documents most similar to the query.
""" """
embedding = self._embedding.embed_query(query) # type: ignore embedding = self._embedding.embed_query(query) # type: ignore
tbl = self.get_table(name)
docs = ( docs = (
self._connection.search(embedding, vector_column_name=self._vector_key) tbl.search(embedding, vector_column_name=self._vector_key)
.limit(k) .limit(k)
.to_arrow() .to_arrow()
) )
@ -155,32 +236,47 @@ class LanceDB(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> LanceDB: ) -> LanceDB:
instance = LanceDB( instance = LanceDB(
connection, connection=connection,
embedding, embedding=embedding,
vector_key, vector_key=vector_key,
id_key, id_key=id_key,
text_key, text_key=text_key,
) )
instance.add_texts(texts, metadatas=metadatas, **kwargs) instance.add_texts(texts, metadatas=metadatas, **kwargs)
return instance return instance
def _init_table(self) -> Any: def delete(
import pyarrow as pa self,
ids: Optional[List[str]] = None,
delete_all: Optional[bool] = None,
filter: Optional[str] = None,
drop_columns: Optional[List[str]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Allows deleting rows by filtering, by ids or drop columns from the table.
schema = pa.schema( Args:
[ filter: Provide a string SQL expression - "{col} {operation} {value}".
pa.field( ids: Provide list of ids to delete from the table.
self._vector_key, drop_columns: Provide list of columns to drop from the table.
pa.list_( delete_all: If True, delete all rows from the table.
pa.float32(), """
len(self.embeddings.embed_query("test")), # type: ignore tbl = self.get_table(name)
), if filter:
), tbl.delete(filter)
pa.field(self._id_key, pa.string()), elif ids:
pa.field(self._text_key, pa.string()), tbl.delete("id in ('{}')".format(",".join(ids)))
] elif drop_columns:
) if self.api_key is not None:
db = self.lancedb.connect("/tmp/lancedb") raise NotImplementedError(
tbl = db.create_table(self._table_name, schema=schema, mode="overwrite") "Column operations currently not supported in LanceDB Cloud."
return tbl )
else:
tbl.drop_columns(drop_columns)
elif delete_all:
tbl.delete("true")
else:
raise ValueError("Provide either filter, ids, drop_columns or delete_all")

View File

@ -1,30 +1,39 @@
from typing import Any
import pytest import pytest
from langchain_community.vectorstores import LanceDB from langchain_community.vectorstores import LanceDB
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
def import_lancedb() -> Any:
try:
import lancedb
except ImportError as e:
raise ImportError(
"Could not import pinecone lancedb package. "
"Please install it with `pip install lancedb`."
) from e
return lancedb
@pytest.mark.requires("lancedb") @pytest.mark.requires("lancedb")
def test_lancedb_with_connection() -> None: def test_lancedb_with_connection() -> None:
import lancedb lancedb = import_lancedb()
embeddings = FakeEmbeddings() embeddings = FakeEmbeddings()
db = lancedb.connect("/tmp/lancedb") db = lancedb.connect("/tmp/lancedb_connection")
texts = ["text 1", "text 2", "item 3"] texts = ["text 1", "text 2", "item 3"]
vectors = embeddings.embed_documents(texts) store = LanceDB(connection=db, embedding=embeddings)
table = db.create_table( store.add_texts(texts)
"my_table",
data=[
{"vector": vectors[idx], "id": text, "text": text}
for idx, text in enumerate(texts)
],
mode="overwrite",
)
store = LanceDB(table, embeddings)
result = store.similarity_search("text 1") result = store.similarity_search("text 1")
result_texts = [doc.page_content for doc in result] result_texts = [doc.page_content for doc in result]
assert "text 1" in result_texts assert "text 1" in result_texts
store.delete(filter="text = 'text 1'")
assert store.get_table().count_rows() == 2
@pytest.mark.requires("lancedb") @pytest.mark.requires("lancedb")
def test_lancedb_without_connection() -> None: def test_lancedb_without_connection() -> None:

View File

@ -67,6 +67,7 @@ def test_compatible_vectorstore_documentation() -> None:
"FAISS", "FAISS",
"HanaDB", "HanaDB",
"InMemoryVectorStore", "InMemoryVectorStore",
"LanceDB",
"Milvus", "Milvus",
"MomentoVectorIndex", "MomentoVectorIndex",
"MyScale", "MyScale",