feature:vector store connector

This commit is contained in:
aries-ckt
2023-05-23 10:50:43 +08:00
parent b70cb8076d
commit 983a00f53a
8 changed files with 209 additions and 128 deletions

View File

@@ -48,5 +48,7 @@ VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
VECTOR_STORE_TYPE = "milvus"
# vector db type; Chroma and Milvus are currently supported
VECTOR_STORE_TYPE = "Milvus"
#vector db config
VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"}

View File

@@ -3,8 +3,7 @@ import os
from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader, markdown
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE, VECTOR_STORE_TYPE
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
@@ -12,7 +11,7 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding
import markdown
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.connector import VectorStoreConnector
class KnowledgeEmbedding:
@@ -23,6 +22,7 @@ class KnowledgeEmbedding:
self.vector_store_config = vector_store_config
self.file_type = "default"
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.local_persist = local_persist
if not self.local_persist:
self.knowledge_embedding_client = self.init_knowledge_embedding()
@@ -52,35 +52,10 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.similar_search(text, topk)
def knowledge_persist_initialization(self, append_mode):
vector_name = self.vector_store_config["vector_store_name"]
documents = self._load_knownlege(self.file_path)
if self.vector_store_config["vector_store_type"] == "Chroma":
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb")
print("vector db path: ", persist_dir)
if os.path.exists(persist_dir):
if append_mode:
print("append knowledge return vector store")
new_documents = self._load_knownlege(self.file_path)
vector_store = Chroma.from_documents(documents=new_documents,
embedding=self.embeddings,
persist_directory=persist_dir)
else:
print("directly return vector store")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
else:
print(vector_name + " is new vector store, knowledge begin load...")
vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir)
vector_store.persist()
elif self.vector_store_config["vector_store_type"] == "milvus":
vector_store = MilvusStore({"url": self.vector_store_config["url"],
"port": self.vector_store_config["port"],
"embedding": self.embeddings})
vector_store.init_schema_and_load(vector_name, documents)
return vector_store
self.vector_client = VectorStoreConnector(VECTOR_STORE_TYPE, self.vector_store_config)
self.vector_client.load_document(documents)
return self.vector_client
def _load_knownlege(self, path):
docments = []

View File

@@ -1,17 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from abc import ABC, abstractmethod
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.vectorstores import Milvus
from typing import List, Optional, Dict
from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG
from pilot.vector_store.milvus_store import MilvusStore
from pilot.configs.model_config import VECTOR_STORE_TYPE
from pilot.vector_store.connector import VectorStoreConnector
registered_methods = []
@@ -35,19 +29,8 @@ class SourceEmbedding(ABC):
self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
if VECTOR_STORE_TYPE == "milvus":
print(VECTOR_STORE_CONFIG)
if self.vector_store_config.get("text_field") is None:
self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
"port": VECTOR_STORE_CONFIG["port"],
"embedding": self.embeddings})
else:
self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content",
connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]})
else:
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
vector_store_config["embeddings"] = self.embeddings
self.vector_client = VectorStoreConnector(VECTOR_STORE_TYPE, vector_store_config)
@abstractmethod
@register
@@ -70,24 +53,12 @@ class SourceEmbedding(ABC):
@register
def index_to_store(self, docs):
"""index to vector store"""
if VECTOR_STORE_TYPE == "chroma":
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
self.vector_store.persist()
elif VECTOR_STORE_TYPE == "milvus":
self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
"port": VECTOR_STORE_CONFIG["port"],
"embedding": self.embeddings})
self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs)
self.vector_client.load_document(docs)
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_store_client.similarity_search(doc, topk)
return self.vector_client.similar_search(doc, topk)
def source_embedding(self):
if 'read' in registered_methods:

View File

@@ -0,0 +1,30 @@
import os
from langchain.vectorstores import Chroma
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase
class ChromaStore(VectorStoreBase):
    """Vector store backed by a locally persisted Chroma database."""

    def __init__(self, ctx: dict) -> None:
        """Open the Chroma store described by ``ctx``.

        Args:
            ctx: connector configuration. ``ctx["embeddings"]`` is passed to
                Chroma as the embedding function and ``ctx["vector_store_name"]``
                names the persist directory
                (``<KNOWLEDGE_UPLOAD_ROOT_PATH>/<name>.vectordb``).

        Raises:
            KeyError: if ``"embeddings"`` or ``"vector_store_name"`` is missing.
        """
        self.ctx = ctx
        self.embeddings = ctx["embeddings"]
        self.persist_dir = os.path.join(
            KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb"
        )
        self.vector_store_client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
        )

    def similar_search(self, text, topk) -> list:
        """Return the result of Chroma's ``similarity_search`` for ``text``.

        Args:
            text: query text.
            topk: forwarded to Chroma as the number of results requested.
        """
        logger.info("ChromaStore similar search")
        return self.vector_store_client.similarity_search(text, topk)

    def load_document(self, documents) -> None:
        """Add ``documents`` to the store and persist it to disk.

        Args:
            documents: iterable of objects exposing ``page_content`` and
                ``metadata``.
        """
        logger.info("ChromaStore load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
        self.vector_store_client.persist()

View File

@@ -0,0 +1,22 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
# Registry of supported vector store back ends, keyed by the (case-sensitive)
# store type name.
connector = {
    "Chroma": ChromaStore,
    "Milvus": MilvusStore,
}


class VectorStoreConnector:
    """Facade over the registered vector stores.

    Picks the store class registered under ``vector_store_type`` and forwards
    the load-document and similar-search calls to an instance of it.
    """

    def __init__(self, vector_store_type, ctx: dict) -> None:
        """Instantiate the store registered for ``vector_store_type``.

        Args:
            vector_store_type: key into ``connector`` (e.g. ``"Chroma"``).
            ctx: configuration dict passed unchanged to the store constructor.

        Raises:
            KeyError: if ``vector_store_type`` is not a registered store type.
        """
        try:
            connector_class = connector[vector_store_type]
        except KeyError:
            # Keep the exception type, but say what went wrong: keys are
            # case-sensitive, so e.g. "milvus" is not "Milvus".
            raise KeyError(
                f"unsupported vector store type {vector_store_type!r}; "
                f"expected one of {sorted(connector)}"
            ) from None
        self.ctx = ctx
        self.connector_class = connector_class
        self.client = connector_class(ctx)

    def load_document(self, docs):
        """Forward ``docs`` to the underlying store's ``load_document``."""
        self.client.load_document(docs)

    def similar_search(self, docs, topk):
        """Return the underlying store's ``similar_search(docs, topk)`` result."""
        return self.client.similar_search(docs, topk)

View File

@@ -1,12 +1,14 @@
from typing import List, Optional, Iterable
from typing import List, Optional, Iterable, Tuple, Any
from langchain.embeddings import HuggingFaceEmbeddings
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
from pymilvus import connections, Collection, DataType
from pilot.configs.model_config import VECTOR_STORE_CONFIG
from langchain.docstore.document import Document
from pilot.vector_store.vector_store_base import VectorStoreBase
class MilvusStore(VectorStoreBase):
"""Milvus database"""
def __init__(self, ctx: {}) -> None:
"""init a milvus storage connection.
@@ -17,14 +19,13 @@ class MilvusStore(VectorStoreBase):
connect_kwargs = {}
self.uri = None
self.uri = ctx["url"]
self.port = ctx["port"]
self.uri = ctx.get("url", VECTOR_STORE_CONFIG["url"])
self.port = ctx.get("port", VECTOR_STORE_CONFIG["port"])
self.username = ctx.get("username", None)
self.password = ctx.get("password", None)
self.collection_name = ctx.get("table_name", None)
self.collection_name = ctx.get("vector_store_name", None)
self.secure = ctx.get("secure", None)
self.model_config = ctx.get("model_config", None)
self.embedding = ctx.get("embedding", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
# use HNSW by default.
@@ -33,6 +34,20 @@ class MilvusStore(VectorStoreBase):
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
# default search params for each supported index type.
self.index_params_map = {
"IVF_FLAT": {"params": {"nprobe": 10}},
"IVF_SQ8": {"params": {"nprobe": 10}},
"IVF_PQ": {"params": {"nprobe": 10}},
"HNSW": {"params": {"ef": 10}},
"RHNSW_FLAT": {"params": {"ef": 10}},
"RHNSW_SQ": {"params": {"ef": 10}},
"RHNSW_PQ": {"params": {"ef": 10}},
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}},
}
self.text_field = "content"
if (self.username is None) != (self.password is None):
raise ValueError(
@@ -48,21 +63,6 @@ class MilvusStore(VectorStoreBase):
alias="default"
# secure=self.secure,
)
if self.collection_name is not None:
self.col = Collection(self.collection_name)
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
# self.init_schema()
# self.init_collection_schema()
def init_schema_and_load(self, vector_name, documents):
"""Create a Milvus collection, indexes it with HNSW, load document.
@@ -86,7 +86,6 @@ class MilvusStore(VectorStoreBase):
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
# Connect to Milvus instance
if not connections.has_connection("default"):
connections.connect(
host=self.uri or "127.0.0.1",
@@ -140,11 +139,11 @@ class MilvusStore(VectorStoreBase):
fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
)
# Create the primary key field
# create the primary key field
fields.append(
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
)
# Create the vector field
# create the vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)
@@ -176,32 +175,44 @@ class MilvusStore(VectorStoreBase):
return self.collection_name
def init_schema(self) -> None:
"""Initialize collection in milvus database."""
fields = [
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
]
# create collection if not exist and load it.
self.schema = CollectionSchema(fields, "db-gpt memory storage")
self.collection = Collection(self.collection_name, self.schema)
self.index_params = {
"metric_type": "IP",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
# create index if not exist.
if not self.collection.has_index():
self.collection.release()
self.collection.create_index(
"vector",
self.index_params,
index_name="vector",
)
info = self.collection.describe()
self.collection.load()
# def init_schema(self) -> None:
# """Initialize collection in milvus database."""
# fields = [
# FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
# FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
# ]
#
# # create collection if not exist and load it.
# self.schema = CollectionSchema(fields, "db-gpt memory storage")
# self.collection = Collection(self.collection_name, self.schema)
# self.index_params_map = {
# "IVF_FLAT": {"params": {"nprobe": 10}},
# "IVF_SQ8": {"params": {"nprobe": 10}},
# "IVF_PQ": {"params": {"nprobe": 10}},
# "HNSW": {"params": {"ef": 10}},
# "RHNSW_FLAT": {"params": {"ef": 10}},
# "RHNSW_SQ": {"params": {"ef": 10}},
# "RHNSW_PQ": {"params": {"ef": 10}},
# "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
# "ANNOY": {"params": {"search_k": 10}},
# }
#
# self.index_params = {
# "metric_type": "IP",
# "index_type": "HNSW",
# "params": {"M": 8, "efConstruction": 64},
# }
# # create index if not exist.
# if not self.collection.has_index():
# self.collection.release()
# self.collection.create_index(
# "vector",
# self.index_params,
# index_name="vector",
# )
# info = self.collection.describe()
# self.collection.load()
# def insert(self, text, model_config) -> str:
# """Add an embedding of data into milvus.
@@ -226,7 +237,7 @@ class MilvusStore(VectorStoreBase):
partition_name: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[str]:
"""Insert text data into Milvus.
"""add text data into Milvus.
Args:
texts (Iterable[str]): The text being embedded and inserted.
metadatas (Optional[List[dict]], optional): The metadata that
@@ -259,6 +270,72 @@ class MilvusStore(VectorStoreBase):
res = self.col.insert(
insert_list, partition_name=partition_name, timeout=timeout
)
# Flush to make sure newly inserted is immediately searchable.
# make sure data is searchable.
self.col.flush()
return res.primary_keys
def load_document(self, documents) -> None:
"""Load ``documents`` into the vector database.

Delegates to ``init_schema_and_load`` using ``self.collection_name`` as the
collection name.
"""
self.init_schema_and_load(self.collection_name, documents)
def similar_search(self, text, topk) -> list:
    """Return the documents most similar to ``text``.

    Opens the collection named by ``self.collection_name``, refreshes the
    field bookkeeping (``self.fields``, ``self.primary_field``,
    ``self.vector_field``) from its schema, and runs ``_search``.

    Args:
        text: query text.
        topk: passed to ``_search`` as the number of results requested.

    Returns:
        The documents from the search results; distances and ids are dropped.
    """
    self.col = Collection(self.collection_name)
    # Rebuild the field list from scratch on every call. Appending to the
    # existing list made names accumulate across searches, so from the second
    # call on the vector field and duplicates ended up in the output fields.
    field_names = []
    for field in self.col.schema.fields:
        # Auto-generated ids are never supplied or requested explicitly.
        if not field.auto_id:
            field_names.append(field.name)
        if field.is_primary:
            self.primary_field = field.name
        if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
            self.vector_field = field.name
    self.fields = field_names
    _, docs_and_scores = self._search(text, topk)
    return [doc for doc, _, _ in docs_and_scores]
def _search(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    partition_names: Optional[List[str]] = None,
    round_decimal: int = -1,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
    """Embed ``query`` and search ``self.col`` for its nearest entries.

    Args:
        query: text to embed with ``self.embedding.embed_query``.
        k: forwarded to ``Collection.search`` as the result limit.
        param: search params; when ``None`` they are looked up in
            ``self.index_params_map`` by the index type of the collection's
            first index.
        expr, partition_names, round_decimal, timeout, **kwargs: forwarded
            unchanged to ``Collection.search``.

    Returns:
        A pair ``(query_embedding, hits)`` where each hit is a
        ``(Document, distance, id)`` tuple.
    """
    self.col.load()

    # Fall back to the default params registered for the collection's index.
    if param is None:
        first_index_type = self.col.indexes[0].params["index_type"]
        param = self.index_params_map[first_index_type]

    # Single-query batch holding the embedded query text.
    query_batch = [self.embedding.embed_query(query)]

    # Request every known field except the vector itself.
    wanted_fields = list(self.fields)
    wanted_fields.remove(self.vector_field)

    search_result = self.col.search(
        query_batch,
        self.vector_field,
        param,
        k,
        expr=expr,
        output_fields=wanted_fields,
        partition_names=partition_names,
        round_decimal=round_decimal,
        timeout=timeout,
        **kwargs,
    )

    # Turn each hit into (Document, distance, id); the text field becomes the
    # page content and the remaining fields become metadata.
    hits = []
    for hit in search_result[0]:
        payload = {name: hit.entity.get(name) for name in wanted_fields}
        page_text = payload.pop(self.text_field)
        document = Document(page_content=page_text, metadata=payload)
        hits.append((document, hit.distance, hit.id))

    return query_batch[0], hits

View File

@@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
class VectorStoreBase(ABC):
"""base class for vector store database"""
@abstractmethod
def init_schema(self) -> None:
def load_document(self, documents) -> None:
"""load document in vector database."""
pass
@abstractmethod
# NOTE(review): annotated as returning None, but the implementations visible in
# this change (ChromaStore, MilvusStore) return their search results — confirm
# and update the annotation.
def similar_search(self, text, topk) -> None:
"""Search the vector database for entries similar to ``text``; ``topk`` is the requested number of results."""
pass

View File

@@ -2,10 +2,8 @@
# -*- coding: utf-8 -*-
import argparse
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Milvus
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG, \
VECTOR_STORE_TYPE
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@@ -42,8 +40,8 @@ if __name__ == "__main__":
args = parser.parse_args()
vector_name = args.vector_name
append_mode = args.append
store_type = args.store_type
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type}
store_type = VECTOR_STORE_TYPE
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)