mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 03:01:29 +00:00
community[minor]: add vikingdb vecstore (#15155)
--------- Co-authored-by: gaoyuan <gaoyuan.20001218@bytedance.com>
This commit is contained in:
parent
d196646811
commit
061e63eef2
248
docs/docs/integrations/vectorstores/vikingdb.ipynb
Normal file
248
docs/docs/integrations/vectorstores/vikingdb.ipynb
Normal file
@ -0,0 +1,248 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "96ff9e912bfe9d8",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"# viking DB\n",
|
||||
"\n",
|
||||
">[viking DB](https://www.volcengine.com/docs/6459/1163946) is a database that stores, indexes, and manages massive embedding vectors generated by deep neural networks and other machine learning (ML) models.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use functionality related to the VikingDB vector database.\n",
|
||||
"\n",
|
||||
"To run, you should have a [viking DB instance up and running](https://www.volcengine.com/docs/6459/1165058).\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd771e02d8a93a0",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --upgrade volcengine"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12719205caed0d18",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"We want to use VikingDBEmbeddings so we have to get the VikingDB API Key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "fbfb32665b4a3640",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-12-21T09:53:24.186916Z",
|
||||
"start_time": "2023-12-21T09:53:24.179524Z"
|
||||
},
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d8c983d329237fa4",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain.vectorstores.vikingdb import VikingDB, VikingDBConfig\n",
|
||||
"from langchain_openai import OpenAIEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a4aea2eaeb2261",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = TextLoader(\"./test.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bfd593f3deabfaf8",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = VikingDB.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" connection_args=VikingDBConfig(\n",
|
||||
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
|
||||
" ),\n",
|
||||
" drop_old=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "50e6ee12ca7eec39",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-12-21T10:01:47.355894Z",
|
||||
"start_time": "2023-12-21T10:01:47.334789Z"
|
||||
},
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "b6b81f5995c79ef0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-12-21T10:01:47.771478Z",
|
||||
"start_time": "2023-12-21T10:01:47.731485Z"
|
||||
},
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs[0].page_content"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a2d932c1290478ee",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"### Compartmentalize the data with viking DB Collections\n",
|
||||
"\n",
|
||||
"You can store different unrelated documents in different collections within same viking DB instance to maintain the context"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "907de4eb10626d2a",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Here's how you can create a new collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4f5a59ba40f7985f",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = VikingDB.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" connection_args=VikingDBConfig(\n",
|
||||
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
|
||||
" ),\n",
|
||||
" collection_name=\"collection_1\",\n",
|
||||
" drop_old=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7c8eada37b17d992",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"And here is how you retrieve that stored collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "883ec678d47c9adc",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = VikingDB.from_documents(\n",
|
||||
" embeddings,\n",
|
||||
" connection_args=VikingDBConfig(\n",
|
||||
" host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n",
|
||||
" ),\n",
|
||||
" collection_name=\"collection_1\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f0be30cfe70083d",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"After retreival you can go on querying it as usual."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
375
libs/community/langchain_community/vectorstores/vikngdb.py
Normal file
375
libs/community/langchain_community/vectorstores/vikngdb.py
Normal file
@ -0,0 +1,375 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VikingDBConfig(object):
|
||||
def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
|
||||
self.host = host
|
||||
self.region = region
|
||||
self.ak = ak
|
||||
self.sk = sk
|
||||
self.scheme = scheme
|
||||
|
||||
|
||||
class VikingDB(VectorStore):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Embeddings,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: Optional[VikingDBConfig] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
drop_old: Optional[bool] = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from volcengine.viking_db import Collection, VikingDBService
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import volcengine python package. "
|
||||
"Please install it with `pip install --upgrade volcengine`."
|
||||
)
|
||||
self.embedding_func = embedding_function
|
||||
self.collection_name = collection_name
|
||||
self.index_name = "LangChainIndex"
|
||||
self.connection_args = connection_args
|
||||
self.index_params = index_params
|
||||
self.drop_old = drop_old
|
||||
self.service = VikingDBService(
|
||||
connection_args.host,
|
||||
connection_args.region,
|
||||
connection_args.ak,
|
||||
connection_args.sk,
|
||||
connection_args.scheme,
|
||||
)
|
||||
|
||||
try:
|
||||
col = self.service.get_collection(collection_name)
|
||||
except Exception:
|
||||
col = None
|
||||
self.collection = col
|
||||
self.index = None
|
||||
if self.collection is not None:
|
||||
self.index = self.service.get_index(self.collection_name, self.index_name)
|
||||
|
||||
if drop_old and isinstance(self.collection, Collection):
|
||||
indexes = self.service.list_indexes(collection_name)
|
||||
for index in indexes:
|
||||
self.service.drop_index(collection_name, index.index_name)
|
||||
self.service.drop_collection(collection_name)
|
||||
self.collection = None
|
||||
self.index = None
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_func
|
||||
|
||||
def _create_collection(
|
||||
self, embeddings: List, metadatas: Optional[List[dict]] = None
|
||||
) -> None:
|
||||
try:
|
||||
from volcengine.viking_db import Field, FieldType
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import volcengine python package. "
|
||||
"Please install it with `pip install --upgrade volcengine`."
|
||||
)
|
||||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
if metadatas:
|
||||
for key, value in metadatas[0].items():
|
||||
# print(key, value)
|
||||
if isinstance(value, str):
|
||||
fields.append(Field(key, FieldType.String))
|
||||
if isinstance(value, int):
|
||||
fields.append(Field(key, FieldType.Int64))
|
||||
if isinstance(value, bool):
|
||||
fields.append(Field(key, FieldType.Bool))
|
||||
if isinstance(value, list) and all(
|
||||
isinstance(item, str) for item in value
|
||||
):
|
||||
fields.append(Field(key, FieldType.List_String))
|
||||
if isinstance(value, list) and all(
|
||||
isinstance(item, int) for item in value
|
||||
):
|
||||
fields.append(Field(key, FieldType.List_Int64))
|
||||
fields.append(Field("text", FieldType.String))
|
||||
|
||||
fields.append(Field("primary_key", FieldType.String, is_primary_key=True))
|
||||
|
||||
fields.append(Field("vector", FieldType.Vector, dim=dim))
|
||||
|
||||
self.collection = self.service.create_collection(self.collection_name, fields)
|
||||
|
||||
def _create_index(self) -> None:
|
||||
try:
|
||||
from volcengine.viking_db import VectorIndexParams
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import volcengine python package. "
|
||||
"Please install it with `pip install --upgrade volcengine`."
|
||||
)
|
||||
cpu_quota = 2
|
||||
vector_index = VectorIndexParams()
|
||||
partition_by = ""
|
||||
scalar_index = None
|
||||
if self.index_params is not None:
|
||||
if self.index_params.get("cpu_quota") is not None:
|
||||
cpu_quota = self.index_params["cpu_quota"]
|
||||
if self.index_params.get("vector_index") is not None:
|
||||
vector_index = self.index_params["vector_index"]
|
||||
if self.index_params.get("partition_by") is not None:
|
||||
partition_by = self.index_params["partition_by"]
|
||||
if self.index_params.get("scalar_index") is not None:
|
||||
scalar_index = self.index_params["scalar_index"]
|
||||
|
||||
self.index = self.service.create_index(
|
||||
self.collection_name,
|
||||
self.index_name,
|
||||
vector_index=vector_index,
|
||||
cpu_quota=cpu_quota,
|
||||
partition_by=partition_by,
|
||||
scalar_index=scalar_index,
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: List[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
try:
|
||||
from volcengine.viking_db import Data
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import volcengine python package. "
|
||||
"Please install it with `pip install --upgrade volcengine`."
|
||||
)
|
||||
texts = list(texts)
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
if len(embeddings) == 0:
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
if self.collection is None:
|
||||
self._create_collection(embeddings, metadatas)
|
||||
self._create_index()
|
||||
|
||||
# insert data
|
||||
data = []
|
||||
pks: List[str] = []
|
||||
for index in range(len(embeddings)):
|
||||
primary_key = str(uuid.uuid4())
|
||||
pks.append(primary_key)
|
||||
field = {
|
||||
"text": texts[index],
|
||||
"primary_key": primary_key,
|
||||
"vector": embeddings[index],
|
||||
}
|
||||
if metadatas is not None and index < len(metadatas):
|
||||
names = list(metadatas[index].keys())
|
||||
for name in names:
|
||||
field[name] = metadatas[index].get(name)
|
||||
data.append(Data(field))
|
||||
|
||||
total_count = len(data)
|
||||
for i in range(0, total_count, batch_size):
|
||||
end = min(i + batch_size, total_count)
|
||||
insert_data = data[i:end]
|
||||
# print(insert_data)
|
||||
self.collection.upsert_data(insert_data)
|
||||
return pks
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
res = self.similarity_search_with_score(query=query, params=params, **kwargs)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, params=params, **kwargs
|
||||
)
|
||||
return res
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, params=params, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
if self.collection is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
filter = None
|
||||
limit = 10
|
||||
output_fields = None
|
||||
partition = "default"
|
||||
if params is not None:
|
||||
if params.get("filter") is not None:
|
||||
filter = params["filter"]
|
||||
if params.get("limit") is not None:
|
||||
limit = params["limit"]
|
||||
if params.get("output_fields") is not None:
|
||||
output_fields = params["output_fields"]
|
||||
if params.get("partition") is not None:
|
||||
partition = params["partition"]
|
||||
|
||||
res = self.index.search_by_vector(
|
||||
embedding,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
output_fields=output_fields,
|
||||
partition=partition,
|
||||
)
|
||||
|
||||
ret = []
|
||||
for item in res:
|
||||
item.fields.pop("primary_key")
|
||||
item.fields.pop("vector")
|
||||
page_content = item.fields.pop("text")
|
||||
doc = Document(page_content=page_content, metadata=item.fields)
|
||||
pair = (doc, item.score)
|
||||
ret.append(pair)
|
||||
return ret
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
lambda_mult: float = 0.5,
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
lambda_mult: float = 0.5,
|
||||
params: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.collection is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
filter = None
|
||||
limit = 10
|
||||
output_fields = None
|
||||
partition = "default"
|
||||
if params is not None:
|
||||
if params.get("filter") is not None:
|
||||
filter = params["filter"]
|
||||
if params.get("limit") is not None:
|
||||
limit = params["limit"]
|
||||
if params.get("output_fields") is not None:
|
||||
output_fields = params["output_fields"]
|
||||
if params.get("partition") is not None:
|
||||
partition = params["partition"]
|
||||
|
||||
res = self.index.search_by_vector(
|
||||
embedding,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
output_fields=output_fields,
|
||||
partition=partition,
|
||||
)
|
||||
documents = []
|
||||
ordered_result_embeddings = []
|
||||
for item in res:
|
||||
ordered_result_embeddings.append(item.fields.pop("vector"))
|
||||
item.fields.pop("primary_key")
|
||||
page_content = item.fields.pop("text")
|
||||
doc = Document(page_content=page_content, metadata=item.fields)
|
||||
documents.append(doc)
|
||||
|
||||
new_ordering = maximal_marginal_relevance(
|
||||
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
# Reorder the values and return.
|
||||
ret = []
|
||||
for x in new_ordering:
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
else:
|
||||
ret.append(documents[x])
|
||||
return ret
|
||||
|
||||
def delete(
|
||||
self,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.collection is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
self.collection.delete_data(ids)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
connection_args: Optional[VikingDBConfig] = None,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
index_params: Optional[dict] = None,
|
||||
drop_old: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if connection_args is None:
|
||||
raise Exception("VikingDBConfig does not exists")
|
||||
vector_db = cls(
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
connection_args=connection_args,
|
||||
index_params=index_params,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
||||
vector_db.add_texts(texts=texts, metadatas=metadatas)
|
||||
return vector_db
|
Loading…
Reference in New Issue
Block a user