community[minor]: add vikingdb vecstore (#15155)

---------

Co-authored-by: gaoyuan <gaoyuan.20001218@bytedance.com>
This commit is contained in:
高远 2024-01-16 04:34:01 +08:00 committed by GitHub
parent d196646811
commit 061e63eef2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 623 additions and 0 deletions

View File

@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "96ff9e912bfe9d8",
"metadata": {
"collapsed": false
},
"source": [
"# viking DB\n",
"\n",
">[viking DB](https://www.volcengine.com/docs/6459/1163946) is a database that stores, indexes, and manages massive embedding vectors generated by deep neural networks and other machine learning (ML) models.\n",
"\n",
"This notebook shows how to use functionality related to the VikingDB vector database.\n",
"\n",
"To run, you should have a [viking DB instance up and running](https://www.volcengine.com/docs/6459/1165058).\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd771e02d8a93a0",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!pip install --upgrade volcengine"
]
},
{
"cell_type": "markdown",
"id": "12719205caed0d18",
"metadata": {
"collapsed": false
},
"source": [
"We want to use VikingDBEmbeddings so we have to get the VikingDB API Key."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fbfb32665b4a3640",
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-21T09:53:24.186916Z",
"start_time": "2023-12-21T09:53:24.179524Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8c983d329237fa4",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.vectorstores.vikingdb import VikingDB, VikingDBConfig\n",
"from langchain_openai import OpenAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a4aea2eaeb2261",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"loader = TextLoader(\"./test.txt\")\n",
"documents = loader.load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfd593f3deabfaf8",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"db = VikingDB.from_documents(\n",
" docs,\n",
" embeddings,\n",
" connection_args=VikingDBConfig(\n",
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
" ),\n",
" drop_old=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "50e6ee12ca7eec39",
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-21T10:01:47.355894Z",
"start_time": "2023-12-21T10:01:47.334789Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b6b81f5995c79ef0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-21T10:01:47.771478Z",
"start_time": "2023-12-21T10:01:47.731485Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"docs[0].page_content"
]
},
{
"cell_type": "markdown",
"id": "a2d932c1290478ee",
"metadata": {
"collapsed": false
},
"source": [
"### Compartmentalize the data with viking DB Collections\n",
"\n",
"You can store different unrelated documents in different collections within same viking DB instance to maintain the context"
]
},
{
"cell_type": "markdown",
"id": "907de4eb10626d2a",
"metadata": {
"collapsed": false
},
"source": [
"Here's how you can create a new collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f5a59ba40f7985f",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"db = VikingDB.from_documents(\n",
" docs,\n",
" embeddings,\n",
" connection_args=VikingDBConfig(\n",
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
" ),\n",
" collection_name=\"collection_1\",\n",
" drop_old=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7c8eada37b17d992",
"metadata": {
"collapsed": false
},
"source": [
"And here is how you retrieve that stored collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "883ec678d47c9adc",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"db = VikingDB.from_documents(\n",
" embeddings,\n",
" connection_args=VikingDBConfig(\n",
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
" ),\n",
" collection_name=\"collection_1\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2f0be30cfe70083d",
"metadata": {
"collapsed": false
},
"source": [
"After retreival you can go on querying it as usual."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,375 @@
from __future__ import annotations
import logging
import uuid
from typing import Any, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
class VikingDBConfig(object):
def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
self.host = host
self.region = region
self.ak = ak
self.sk = sk
self.scheme = scheme
class VikingDB(VectorStore):
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[VikingDBConfig] = None,
index_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
**kwargs: Any,
):
try:
from volcengine.viking_db import Collection, VikingDBService
except ImportError:
raise ValueError(
"Could not import volcengine python package. "
"Please install it with `pip install --upgrade volcengine`."
)
self.embedding_func = embedding_function
self.collection_name = collection_name
self.index_name = "LangChainIndex"
self.connection_args = connection_args
self.index_params = index_params
self.drop_old = drop_old
self.service = VikingDBService(
connection_args.host,
connection_args.region,
connection_args.ak,
connection_args.sk,
connection_args.scheme,
)
try:
col = self.service.get_collection(collection_name)
except Exception:
col = None
self.collection = col
self.index = None
if self.collection is not None:
self.index = self.service.get_index(self.collection_name, self.index_name)
if drop_old and isinstance(self.collection, Collection):
indexes = self.service.list_indexes(collection_name)
for index in indexes:
self.service.drop_index(collection_name, index.index_name)
self.service.drop_collection(collection_name)
self.collection = None
self.index = None
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
def _create_collection(
self, embeddings: List, metadatas: Optional[List[dict]] = None
) -> None:
try:
from volcengine.viking_db import Field, FieldType
except ImportError:
raise ValueError(
"Could not import volcengine python package. "
"Please install it with `pip install --upgrade volcengine`."
)
dim = len(embeddings[0])
fields = []
if metadatas:
for key, value in metadatas[0].items():
# print(key, value)
if isinstance(value, str):
fields.append(Field(key, FieldType.String))
if isinstance(value, int):
fields.append(Field(key, FieldType.Int64))
if isinstance(value, bool):
fields.append(Field(key, FieldType.Bool))
if isinstance(value, list) and all(
isinstance(item, str) for item in value
):
fields.append(Field(key, FieldType.List_String))
if isinstance(value, list) and all(
isinstance(item, int) for item in value
):
fields.append(Field(key, FieldType.List_Int64))
fields.append(Field("text", FieldType.String))
fields.append(Field("primary_key", FieldType.String, is_primary_key=True))
fields.append(Field("vector", FieldType.Vector, dim=dim))
self.collection = self.service.create_collection(self.collection_name, fields)
def _create_index(self) -> None:
try:
from volcengine.viking_db import VectorIndexParams
except ImportError:
raise ValueError(
"Could not import volcengine python package. "
"Please install it with `pip install --upgrade volcengine`."
)
cpu_quota = 2
vector_index = VectorIndexParams()
partition_by = ""
scalar_index = None
if self.index_params is not None:
if self.index_params.get("cpu_quota") is not None:
cpu_quota = self.index_params["cpu_quota"]
if self.index_params.get("vector_index") is not None:
vector_index = self.index_params["vector_index"]
if self.index_params.get("partition_by") is not None:
partition_by = self.index_params["partition_by"]
if self.index_params.get("scalar_index") is not None:
scalar_index = self.index_params["scalar_index"]
self.index = self.service.create_index(
self.collection_name,
self.index_name,
vector_index=vector_index,
cpu_quota=cpu_quota,
partition_by=partition_by,
scalar_index=scalar_index,
)
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
try:
from volcengine.viking_db import Data
except ImportError:
raise ValueError(
"Could not import volcengine python package. "
"Please install it with `pip install --upgrade volcengine`."
)
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
if self.collection is None:
self._create_collection(embeddings, metadatas)
self._create_index()
# insert data
data = []
pks: List[str] = []
for index in range(len(embeddings)):
primary_key = str(uuid.uuid4())
pks.append(primary_key)
field = {
"text": texts[index],
"primary_key": primary_key,
"vector": embeddings[index],
}
if metadatas is not None and index < len(metadatas):
names = list(metadatas[index].keys())
for name in names:
field[name] = metadatas[index].get(name)
data.append(Data(field))
total_count = len(data)
for i in range(0, total_count, batch_size):
end = min(i + batch_size, total_count)
insert_data = data[i:end]
# print(insert_data)
self.collection.upsert_data(insert_data)
return pks
def similarity_search(
self,
query: str,
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
res = self.similarity_search_with_score(query=query, params=params, **kwargs)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding = self.embedding_func.embed_query(query)
res = self.similarity_search_with_score_by_vector(
embedding=embedding, params=params, **kwargs
)
return res
def similarity_search_by_vector(
self,
embedding: List[float],
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
res = self.similarity_search_with_score_by_vector(
embedding=embedding, params=params, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
if self.collection is None:
logger.debug("No existing collection to search.")
return []
filter = None
limit = 10
output_fields = None
partition = "default"
if params is not None:
if params.get("filter") is not None:
filter = params["filter"]
if params.get("limit") is not None:
limit = params["limit"]
if params.get("output_fields") is not None:
output_fields = params["output_fields"]
if params.get("partition") is not None:
partition = params["partition"]
res = self.index.search_by_vector(
embedding,
filter=filter,
limit=limit,
output_fields=output_fields,
partition=partition,
)
ret = []
for item in res:
item.fields.pop("primary_key")
item.fields.pop("vector")
page_content = item.fields.pop("text")
doc = Document(page_content=page_content, metadata=item.fields)
pair = (doc, item.score)
ret.append(pair)
return ret
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
lambda_mult: float = 0.5,
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
embedding = self.embedding_func.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding=embedding,
k=k,
lambda_mult=lambda_mult,
params=params,
**kwargs,
)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
lambda_mult: float = 0.5,
params: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
if self.collection is None:
logger.debug("No existing collection to search.")
return []
filter = None
limit = 10
output_fields = None
partition = "default"
if params is not None:
if params.get("filter") is not None:
filter = params["filter"]
if params.get("limit") is not None:
limit = params["limit"]
if params.get("output_fields") is not None:
output_fields = params["output_fields"]
if params.get("partition") is not None:
partition = params["partition"]
res = self.index.search_by_vector(
embedding,
filter=filter,
limit=limit,
output_fields=output_fields,
partition=partition,
)
documents = []
ordered_result_embeddings = []
for item in res:
ordered_result_embeddings.append(item.fields.pop("vector"))
item.fields.pop("primary_key")
page_content = item.fields.pop("text")
doc = Document(page_content=page_content, metadata=item.fields)
documents.append(doc)
new_ordering = maximal_marginal_relevance(
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
else:
ret.append(documents[x])
return ret
def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if self.collection is None:
logger.debug("No existing collection to search.")
self.collection.delete_data(ids)
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
connection_args: Optional[VikingDBConfig] = None,
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
index_params: Optional[dict] = None,
drop_old: bool = False,
**kwargs: Any,
):
if connection_args is None:
raise Exception("VikingDBConfig does not exists")
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
connection_args=connection_args,
index_params=index_params,
drop_old=drop_old,
**kwargs,
)
vector_db.add_texts(texts=texts, metadatas=metadatas)
return vector_db