From 061e63eef201dc11be1003bc113309769325d847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E8=BF=9C?= <90301759+19374242@users.noreply.github.com> Date: Tue, 16 Jan 2024 04:34:01 +0800 Subject: [PATCH] community[minor]: add vikingdb vecstore (#15155) --------- Co-authored-by: gaoyuan --- .../integrations/vectorstores/vikingdb.ipynb | 248 ++++++++++++ .../vectorstores/vikngdb.py | 375 ++++++++++++++++++ 2 files changed, 623 insertions(+) create mode 100644 docs/docs/integrations/vectorstores/vikingdb.ipynb create mode 100644 libs/community/langchain_community/vectorstores/vikngdb.py diff --git a/docs/docs/integrations/vectorstores/vikingdb.ipynb b/docs/docs/integrations/vectorstores/vikingdb.ipynb new file mode 100644 index 00000000000..66ab177efd5 --- /dev/null +++ b/docs/docs/integrations/vectorstores/vikingdb.ipynb @@ -0,0 +1,248 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "96ff9e912bfe9d8", + "metadata": { + "collapsed": false + }, + "source": [ + "# viking DB\n", + "\n", + ">[viking DB](https://www.volcengine.com/docs/6459/1163946) is a database that stores, indexes, and manages massive embedding vectors generated by deep neural networks and other machine learning (ML) models.\n", + "\n", + "This notebook shows how to use functionality related to the VikingDB vector database.\n", + "\n", + "To run, you should have a [viking DB instance up and running](https://www.volcengine.com/docs/6459/1165058).\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd771e02d8a93a0", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "!pip install --upgrade volcengine" + ] + }, + { + "cell_type": "markdown", + "id": "12719205caed0d18", + "metadata": { + "collapsed": false + }, + "source": [ + "We want to use VikingDBEmbeddings so we have to get the VikingDB API Key." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fbfb32665b4a3640", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-21T09:53:24.186916Z", + "start_time": "2023-12-21T09:53:24.179524Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8c983d329237fa4", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import TextLoader\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "from langchain.vectorstores.vikingdb import VikingDB, VikingDBConfig\n", + "from langchain_openai import OpenAIEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a4aea2eaeb2261", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "loader = TextLoader(\"./test.txt\")\n", + "documents = loader.load()\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)\n", + "\n", + "embeddings = OpenAIEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfd593f3deabfaf8", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "db = VikingDB.from_documents(\n", + " docs,\n", + " embeddings,\n", + " connection_args=VikingDBConfig(\n", + " host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n", + " ),\n", + " drop_old=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "50e6ee12ca7eec39", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-21T10:01:47.355894Z", + "start_time": "2023-12-21T10:01:47.334789Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = db.similarity_search(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b6b81f5995c79ef0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-21T10:01:47.771478Z", + "start_time": "2023-12-21T10:01:47.731485Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "docs[0].page_content" + ] + }, + { + "cell_type": "markdown", + "id": "a2d932c1290478ee", + "metadata": { + "collapsed": false + }, + "source": [ + "### Compartmentalize the data with viking DB Collections\n", + "\n", + "You can store different unrelated documents in different collections within same viking DB instance to maintain the context" + ] + }, + { + "cell_type": "markdown", + "id": "907de4eb10626d2a", + "metadata": { + "collapsed": false + }, + "source": [ + "Here's how you can create a new collection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f5a59ba40f7985f", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "db = VikingDB.from_documents(\n", + " docs,\n", + " embeddings,\n", + " connection_args=VikingDBConfig(\n", + " host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n", + " ),\n", + " collection_name=\"collection_1\",\n", + " drop_old=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7c8eada37b17d992", + "metadata": { + "collapsed": false + }, + "source": [ + "And here is how you retrieve that stored collection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "883ec678d47c9adc", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "db = VikingDB.from_documents(\n", + " embeddings,\n", + " connection_args=VikingDBConfig(\n", + " host=\"host\", region=\"region\", ak=\"ak\", sk=\"sk\", scheme=\"http\"\n", + " ),\n", + " collection_name=\"collection_1\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2f0be30cfe70083d", + "metadata": { + "collapsed": false + }, + "source": [ + "After retreival you can go on querying it as usual." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/vectorstores/vikngdb.py b/libs/community/langchain_community/vectorstores/vikngdb.py new file mode 100644 index 00000000000..2f235f0bf4c --- /dev/null +++ b/libs/community/langchain_community/vectorstores/vikngdb.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import logging +import uuid +from typing import Any, List, Optional, Tuple + +import numpy as np +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +from langchain_community.vectorstores.utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + + +class VikingDBConfig(object): + def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"): + self.host = host + self.region = region + self.ak = ak + self.sk = sk + self.scheme = scheme + + +class VikingDB(VectorStore): + def __init__( + self, + embedding_function: Embeddings, + collection_name: str = "LangChainCollection", + connection_args: Optional[VikingDBConfig] = None, + index_params: Optional[dict] = None, + drop_old: Optional[bool] = False, + **kwargs: Any, + ): + try: + from volcengine.viking_db import Collection, VikingDBService + except ImportError: + raise ValueError( + "Could not import volcengine python package. " + "Please install it with `pip install --upgrade volcengine`." + ) + self.embedding_func = embedding_function + self.collection_name = collection_name + self.index_name = "LangChainIndex" + self.connection_args = connection_args + self.index_params = index_params + self.drop_old = drop_old + self.service = VikingDBService( + connection_args.host, + connection_args.region, + connection_args.ak, + connection_args.sk, + connection_args.scheme, + ) + + try: + col = self.service.get_collection(collection_name) + except Exception: + col = None + self.collection = col + self.index = None + if self.collection is not None: + self.index = self.service.get_index(self.collection_name, self.index_name) + + if drop_old and isinstance(self.collection, Collection): + indexes = self.service.list_indexes(collection_name) + for index in indexes: + self.service.drop_index(collection_name, index.index_name) + self.service.drop_collection(collection_name) + self.collection = None + self.index = None + + @property + def embeddings(self) -> Embeddings: + return self.embedding_func + + def _create_collection( + self, embeddings: List, metadatas: Optional[List[dict]] = None + ) -> None: + try: + from volcengine.viking_db import Field, FieldType + except ImportError: + raise ValueError( + "Could not import volcengine python package. " + "Please install it with `pip install --upgrade volcengine`." + ) + dim = len(embeddings[0]) + fields = [] + if metadatas: + for key, value in metadatas[0].items(): + # print(key, value) + if isinstance(value, str): + fields.append(Field(key, FieldType.String)) + if isinstance(value, int): + fields.append(Field(key, FieldType.Int64)) + if isinstance(value, bool): + fields.append(Field(key, FieldType.Bool)) + if isinstance(value, list) and all( + isinstance(item, str) for item in value + ): + fields.append(Field(key, FieldType.List_String)) + if isinstance(value, list) and all( + isinstance(item, int) for item in value + ): + fields.append(Field(key, FieldType.List_Int64)) + fields.append(Field("text", FieldType.String)) + + fields.append(Field("primary_key", FieldType.String, is_primary_key=True)) + + fields.append(Field("vector", FieldType.Vector, dim=dim)) + + self.collection = self.service.create_collection(self.collection_name, fields) + + def _create_index(self) -> None: + try: + from volcengine.viking_db import VectorIndexParams + except ImportError: + raise ValueError( + "Could not import volcengine python package. " + "Please install it with `pip install --upgrade volcengine`." + ) + cpu_quota = 2 + vector_index = VectorIndexParams() + partition_by = "" + scalar_index = None + if self.index_params is not None: + if self.index_params.get("cpu_quota") is not None: + cpu_quota = self.index_params["cpu_quota"] + if self.index_params.get("vector_index") is not None: + vector_index = self.index_params["vector_index"] + if self.index_params.get("partition_by") is not None: + partition_by = self.index_params["partition_by"] + if self.index_params.get("scalar_index") is not None: + scalar_index = self.index_params["scalar_index"] + + self.index = self.service.create_index( + self.collection_name, + self.index_name, + vector_index=vector_index, + cpu_quota=cpu_quota, + partition_by=partition_by, + scalar_index=scalar_index, + ) + + def add_texts( + self, + texts: List[str], + metadatas: Optional[List[dict]] = None, + batch_size: int = 1000, + **kwargs: Any, + ) -> List[str]: + try: + from volcengine.viking_db import Data + except ImportError: + raise ValueError( + "Could not import volcengine python package. " + "Please install it with `pip install --upgrade volcengine`." + ) + texts = list(texts) + try: + embeddings = self.embedding_func.embed_documents(texts) + except NotImplementedError: + embeddings = [self.embedding_func.embed_query(x) for x in texts] + if len(embeddings) == 0: + logger.debug("Nothing to insert, skipping.") + return [] + if self.collection is None: + self._create_collection(embeddings, metadatas) + self._create_index() + + # insert data + data = [] + pks: List[str] = [] + for index in range(len(embeddings)): + primary_key = str(uuid.uuid4()) + pks.append(primary_key) + field = { + "text": texts[index], + "primary_key": primary_key, + "vector": embeddings[index], + } + if metadatas is not None and index < len(metadatas): + names = list(metadatas[index].keys()) + for name in names: + field[name] = metadatas[index].get(name) + data.append(Data(field)) + + total_count = len(data) + for i in range(0, total_count, batch_size): + end = min(i + batch_size, total_count) + insert_data = data[i:end] + # print(insert_data) + self.collection.upsert_data(insert_data) + return pks + + def similarity_search( + self, + query: str, + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + res = self.similarity_search_with_score(query=query, params=params, **kwargs) + return [doc for doc, _ in res] + + def similarity_search_with_score( + self, + query: str, + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + embedding = self.embedding_func.embed_query(query) + + res = self.similarity_search_with_score_by_vector( + embedding=embedding, params=params, **kwargs + ) + return res + + def similarity_search_by_vector( + self, + embedding: List[float], + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + res = self.similarity_search_with_score_by_vector( + embedding=embedding, params=params, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + if self.collection is None: + logger.debug("No existing collection to search.") + return [] + + filter = None + limit = 10 + output_fields = None + partition = "default" + if params is not None: + if params.get("filter") is not None: + filter = params["filter"] + if params.get("limit") is not None: + limit = params["limit"] + if params.get("output_fields") is not None: + output_fields = params["output_fields"] + if params.get("partition") is not None: + partition = params["partition"] + + res = self.index.search_by_vector( + embedding, + filter=filter, + limit=limit, + output_fields=output_fields, + partition=partition, + ) + + ret = [] + for item in res: + item.fields.pop("primary_key") + item.fields.pop("vector") + page_content = item.fields.pop("text") + doc = Document(page_content=page_content, metadata=item.fields) + pair = (doc, item.score) + ret.append(pair) + return ret + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + lambda_mult: float = 0.5, + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + embedding = self.embedding_func.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding=embedding, + k=k, + lambda_mult=lambda_mult, + params=params, + **kwargs, + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + lambda_mult: float = 0.5, + params: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + if self.collection is None: + logger.debug("No existing collection to search.") + return [] + filter = None + limit = 10 + output_fields = None + partition = "default" + if params is not None: + if params.get("filter") is not None: + filter = params["filter"] + if params.get("limit") is not None: + limit = params["limit"] + if params.get("output_fields") is not None: + output_fields = params["output_fields"] + if params.get("partition") is not None: + partition = params["partition"] + + res = self.index.search_by_vector( + embedding, + filter=filter, + limit=limit, + output_fields=output_fields, + partition=partition, + ) + documents = [] + ordered_result_embeddings = [] + for item in res: + ordered_result_embeddings.append(item.fields.pop("vector")) + item.fields.pop("primary_key") + page_content = item.fields.pop("text") + doc = Document(page_content=page_content, metadata=item.fields) + documents.append(doc) + + new_ordering = maximal_marginal_relevance( + np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult + ) + # Reorder the values and return. + ret = [] + for x in new_ordering: + # Function can return -1 index + if x == -1: + break + else: + ret.append(documents[x]) + return ret + + def delete( + self, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + if self.collection is None: + logger.debug("No existing collection to search.") + self.collection.delete_data(ids) + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + connection_args: Optional[VikingDBConfig] = None, + metadatas: Optional[List[dict]] = None, + collection_name: str = "LangChainCollection", + index_params: Optional[dict] = None, + drop_old: bool = False, + **kwargs: Any, + ): + if connection_args is None: + raise Exception("VikingDBConfig does not exists") + vector_db = cls( + embedding_function=embedding, + collection_name=collection_name, + connection_args=connection_args, + index_params=index_params, + drop_old=drop_old, + **kwargs, + ) + vector_db.add_texts(texts=texts, metadatas=metadatas) + return vector_db