community[minor]: Add Baidu VectorDB as vector store (#17997)

Co-authored-by: fengjialin <fengjialin@MacBook-Pro.local>
This commit is contained in:
fengjial 2024-03-16 03:01:58 +08:00 committed by GitHub
parent 190887c5cd
commit c922ea36cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 559 additions and 0 deletions

View File

@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"source": [
"# Baidu VectorDB\n",
"\n",
">[Baidu VectorDB](https://cloud.baidu.com/product/vdb.html) is a robust, enterprise-level distributed database service, meticulously developed and fully managed by Baidu Intelligent Cloud. It stands out for its exceptional ability to store, retrieve, and analyze multi-dimensional vector data. At its core, VectorDB operates on Baidu's proprietary \"Mochow\" vector database kernel, which ensures high performance, availability, and security, alongside remarkable scalability and user-friendliness.\n",
"\n",
">This database service supports a diverse range of index types and similarity calculation methods, catering to various use cases. A standout feature of VectorDB is its capacity to manage an immense vector scale of up to 10 billion, while maintaining impressive query performance, supporting millions of queries per second (QPS) with millisecond-level query latency.\n",
"\n",
"This notebook shows how to use functionality related to the Baidu VectorDB. \n",
"\n",
"To run, you should have a [Database instance.](https://cloud.baidu.com/doc/VDB/s/hlrsoazuf)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip3 install pymochow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_community.document_loaders import TextLoader\n",
"from langchain_community.embeddings.fake import FakeEmbeddings\n",
"from langchain_community.vectorstores import BaiduVectorDB\n",
"from langchain_community.vectorstores.baiduvectordb import ConnectionParams"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"embeddings = FakeEmbeddings(size=128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"conn_params = ConnectionParams(\n",
" endpoint=\"http://192.168.xx.xx:xxxx\", account=\"root\", api_key=\"****\"\n",
")\n",
"\n",
"vector_db = BaiduVectorDB.from_documents(\n",
" docs, embeddings, connection_params=conn_params, drop=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = vector_db.similarity_search(query)\n",
"docs[0].page_content"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vector_db = BaiduVectorDB(embeddings, conn_params)\n",
"vector_db.add_texts([\"Ankush went to Princeton\"])\n",
"query = \"Where did Ankush go to college?\"\n",
"docs = vector_db.max_marginal_relevance_search(query)\n",
"docs[0].page_content"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -33,6 +33,7 @@ _module_lookup = {
"AwaDB": "langchain_community.vectorstores.awadb",
"AzureCosmosDBVectorSearch": "langchain_community.vectorstores.azure_cosmos_db",
"AzureSearch": "langchain_community.vectorstores.azuresearch",
"BaiduVectorDB": "langchain_community.vectorstores.baiduvectordb",
"BESVectorStore": "langchain_community.vectorstores.baiducloud_vector_search",
"Bagel": "langchain_community.vectorstores.bageldb",
"BigQueryVectorSearch": "langchain_community.vectorstores.bigquery_vector_search",

View File

@ -0,0 +1,436 @@
"""Wrapper around the Baidu vector database."""
from __future__ import annotations
import json
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
class ConnectionParams:
"""Baidu VectorDB Connection params.
See the following documentation for details:
https://cloud.baidu.com/doc/VDB/s/6lrsob0wy
Attribute:
endpoint (str) : The access address of the vector database server
that the client needs to connect to.
api_key (str): API key for client to access the vector database server,
which is used for authentication.
account (str) : Account for client to access the vector database server.
connection_timeout_in_mills (int) : Request Timeout.
"""
def __init__(
self,
endpoint: str,
api_key: str,
account: str = "root",
connection_timeout_in_mills: int = 50 * 1000,
):
self.endpoint = endpoint
self.api_key = api_key
self.account = account
self.connection_timeout_in_mills = connection_timeout_in_mills
class TableParams:
"""Baidu VectorDB table params.
See the following documentation for details:
https://cloud.baidu.com/doc/VDB/s/mlrsob0p6
"""
def __init__(
self,
dimension: int,
replication: int = 3,
partition: int = 1,
index_type: str = "HNSW",
metric_type: str = "L2",
params: Optional[Dict] = None,
):
self.dimension = dimension
self.replication = replication
self.partition = partition
self.index_type = index_type
self.metric_type = metric_type
self.params = params
class BaiduVectorDB(VectorStore):
"""Baidu VectorDB as a vector store.
In order to use this you need to have a database instance.
See the following documentation for details:
https://cloud.baidu.com/doc/VDB/index.html
"""
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
index_vector: str = "vector_idx"
def __init__(
self,
embedding: Embeddings,
connection_params: ConnectionParams,
table_params: TableParams = TableParams(128),
database_name: str = "LangChainDatabase",
table_name: str = "LangChainTable",
drop_old: Optional[bool] = False,
):
pymochow = guard_import("pymochow")
configuration = guard_import("pymochow.configuration")
auth = guard_import("pymochow.auth.bce_credentials")
self.mochowtable = guard_import("pymochow.model.table")
self.mochowenum = guard_import("pymochow.model.enum")
self.embedding_func = embedding
self.table_params = table_params
config = configuration.Configuration(
credentials=auth.BceCredentials(
connection_params.account, connection_params.api_key
),
endpoint=connection_params.endpoint,
connection_timeout_in_mills=connection_params.connection_timeout_in_mills,
)
self.vdb_client = pymochow.MochowClient(config)
db_list = self.vdb_client.list_databases()
db_exist: bool = False
for db in db_list:
if database_name == db.database_name:
db_exist = True
break
if db_exist:
self.database = self.vdb_client.database(database_name)
else:
self.database = self.vdb_client.create_database(database_name)
try:
self.table = self.database.describe_table(table_name)
if drop_old:
self.database.drop_table(table_name)
self._create_table(table_name)
except pymochow.exception.ServerError:
self._create_table(table_name)
def _create_table(self, table_name: str) -> None:
schema = guard_import("pymochow.model.schema")
index_type = None
for k, v in self.mochowenum.IndexType.__members__.items():
if k == self.table_params.index_type:
index_type = v
if index_type is None:
raise ValueError("unsupported index_type")
metric_type = None
for k, v in self.mochowenum.MetricType.__members__.items():
if k == self.table_params.metric_type:
metric_type = v
if metric_type is None:
raise ValueError("unsupported metric_type")
if self.table_params.params is None:
params = schema.HNSWParams(m=16, efconstruction=200)
else:
params = schema.HNSWParams(
m=self.table_params.params.get("M", 16),
efconstruction=self.table_params.params.get("efConstruction", 200),
)
fields = []
fields.append(
schema.Field(
self.field_id,
self.mochowenum.FieldType.STRING,
primary_key=True,
partition_key=True,
auto_increment=False,
not_null=True,
)
)
fields.append(
schema.Field(
self.field_vector,
self.mochowenum.FieldType.FLOAT_VECTOR,
dimension=self.table_params.dimension,
)
)
fields.append(schema.Field(self.field_text, self.mochowenum.FieldType.STRING))
fields.append(
schema.Field(self.field_metadata, self.mochowenum.FieldType.STRING)
)
indexes = []
indexes.append(
schema.VectorIndex(
index_name=self.index_vector,
index_type=index_type,
field=self.field_vector,
metric_type=metric_type,
params=params,
)
)
self.table = self.database.create_table(
table_name=table_name,
replication=self.table_params.replication,
partition=self.mochowtable.Partition(
partition_num=self.table_params.partition
),
schema=schema.Schema(fields=fields, indexes=indexes),
)
while True:
time.sleep(1)
table = self.database.describe_table(table_name)
if table.state == self.mochowenum.TableState.NORMAL:
break
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
connection_params: Optional[ConnectionParams] = None,
table_params: Optional[TableParams] = None,
database_name: str = "LangChainDatabase",
table_name: str = "LangChainTable",
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> BaiduVectorDB:
"""Create a table, indexes it with HNSW, and insert data."""
if len(texts) == 0:
raise ValueError("texts is empty")
if connection_params is None:
raise ValueError("connection_params is empty")
try:
embeddings = embedding.embed_documents(texts[0:1])
except NotImplementedError:
embeddings = [embedding.embed_query(texts[0])]
dimension = len(embeddings[0])
if table_params is None:
table_params = TableParams(dimension=dimension)
else:
table_params.dimension = dimension
vector_db = cls(
embedding=embedding,
connection_params=connection_params,
table_params=table_params,
database_name=database_name,
table_name=table_name,
drop_old=drop_old,
)
vector_db.add_texts(texts=texts, metadatas=metadatas)
return vector_db
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Insert text data into Baidu VectorDB."""
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
pks: list[str] = []
total_count = len(embeddings)
for start in range(0, total_count, batch_size):
# Grab end index
rows = []
end = min(start + batch_size, total_count)
for id in range(start, end, 1):
metadata = "{}"
if metadatas is not None:
metadata = json.dumps(metadatas[id])
row = self.mochowtable.Row(
id="{}-{}-{}".format(time.time_ns(), hash(texts[id]), id),
vector=[float(num) for num in embeddings[id]],
text=texts[id],
metadata=metadata,
)
rows.append(row)
pks.append(str(id))
self.table.upsert(rows=rows)
# need rebuild vindex after upsert
self.table.rebuild_index(self.index_vector)
while True:
time.sleep(2)
index = self.table.describe_index(self.index_vector)
if index.state == self.mochowenum.IndexState.NORMAL:
break
return pks
def similarity_search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string."""
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score."""
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
res = self._similarity_search_with_score(
embedding=embedding, k=k, param=param, expr=expr, **kwargs
)
return res
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string."""
res = self._similarity_search_with_score(
embedding=embedding, k=k, param=param, expr=expr, **kwargs
)
return [doc for doc, _ in res]
def _similarity_search_with_score(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score."""
ef = 10 if param is None else param.get("ef", 10)
anns = self.mochowtable.AnnSearch(
vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr,
)
res = self.table.search(anns=anns)
rows = [[item] for item in res.rows]
# Organize results.
ret: List[Tuple[Document, float]] = []
if rows is None or len(rows) == 0:
return ret
for row in rows:
for result in row:
row_data = result.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(
page_content=row_data.get(self.field_text), metadata=meta
)
pair = (doc, result.get("distance", 0.0))
ret.append(pair)
return ret
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR."""
embedding = self.embedding_func.embed_query(query)
return self._max_marginal_relevance_search(
embedding=embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
param=param,
expr=expr,
**kwargs,
)
def _max_marginal_relevance_search(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR."""
ef = 10 if param is None else param.get("ef", 10)
anns = self.mochowtable.AnnSearch(
vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr,
)
res = self.table.search(anns=anns, retrieve_vector=True)
# Organize results.
documents: List[Document] = []
ordered_result_embeddings = []
rows = [[item] for item in res.rows]
if rows is None or len(rows) == 0:
return documents
for row in rows:
for result in row:
row_data = result.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(
page_content=row_data.get(self.field_text), metadata=meta
)
documents.append(doc)
ordered_result_embeddings.append(row_data.get(self.field_vector))
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
else:
ret.append(documents[x])
return ret

View File

@ -11,6 +11,7 @@ _EXPECTED = [
"AwaDB",
"AzureSearch",
"Bagel",
"BaiduVectorDB",
"BESVectorStore",
"BigQueryVectorSearch",
"Cassandra",