Merge branch 'main' of https://github.com/csunny/DB-GPT into dbgpt_doc

This commit is contained in:
csunny 2023-05-24 10:36:18 +08:00
commit 03898ee7d3
13 changed files with 242 additions and 156 deletions

View File

@ -81,3 +81,14 @@ DENYLISTED_PLUGINS=
#*******************************************************************# #*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False) # CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False # CHAT_MESSAGES_ENABLED=False
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
#MILVUS_URL=127.0.0.1
#MILVUS_PORT=19530
#MILVUS_USERNAME
#MILVUS_PASSWORD
#MILVUS_SECURE=

View File

@ -109,6 +109,14 @@ class Config(metaclass=Singleton):
self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)) self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT))
self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True" self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
### Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
def set_debug_mode(self, value: bool) -> None: def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value""" """Set the debug mode value"""
self.debug_mode = value self.debug_mode = value

View File

@ -48,5 +48,3 @@ VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
VECTOR_STORE_TYPE = "milvus"
VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"}

View File

@ -19,8 +19,7 @@ from langchain import PromptTemplate
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \ from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
VECTOR_STORE_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@ -268,13 +267,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
skip_echo_len = len(prompt.replace("</s>", " ")) + 1 skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector: if mode == conversation_types["custome"] and not db_selector:
# persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"])
print("vector store type: ", VECTOR_STORE_CONFIG)
print("vector store name: ", vector_store_name["vs_name"]) print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = VECTOR_STORE_CONFIG vector_store_config = {"vector_store_name": vector_store_name["vs_name"], "text_field": "content",
vector_store_config["vector_store_name"] = vector_store_name["vs_name"] "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}
vector_store_config["text_field"] = "content"
vector_store_config["vector_store_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False, local_persist=False,
vector_store_config=vector_store_config) vector_store_config=vector_store_config)

View File

@ -1,9 +1,10 @@
import os import os
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader, markdown from langchain.document_loaders import TextLoader, markdown, PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from pilot.configs.config import Config
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.csv_embedding import CSVEmbedding
@ -11,9 +12,9 @@ from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding
import markdown import markdown
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader from pilot.vector_store.connector import VectorStoreConnector
from pilot.vector_store.milvus_store import MilvusStore
CFG = Config()
class KnowledgeEmbedding: class KnowledgeEmbedding:
def __init__(self, file_path, model_name, vector_store_config, local_persist=True): def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
@ -23,6 +24,7 @@ class KnowledgeEmbedding:
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.file_type = "default" self.file_type = "default"
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.local_persist = local_persist self.local_persist = local_persist
if not self.local_persist: if not self.local_persist:
self.knowledge_embedding_client = self.init_knowledge_embedding() self.knowledge_embedding_client = self.init_knowledge_embedding()
@ -52,35 +54,10 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.similar_search(text, topk) return self.knowledge_embedding_client.similar_search(text, topk)
def knowledge_persist_initialization(self, append_mode): def knowledge_persist_initialization(self, append_mode):
vector_name = self.vector_store_config["vector_store_name"]
documents = self._load_knownlege(self.file_path) documents = self._load_knownlege(self.file_path)
if self.vector_store_config["vector_store_type"] == "Chroma": self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config)
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb") self.vector_client.load_document(documents)
print("vector db path: ", persist_dir) return self.vector_client
if os.path.exists(persist_dir):
if append_mode:
print("append knowledge return vector store")
new_documents = self._load_knownlege(self.file_path)
vector_store = Chroma.from_documents(documents=new_documents,
embedding=self.embeddings,
persist_directory=persist_dir)
else:
print("directly return vector store")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
else:
print(vector_name + " is new vector store, knowledge begin load...")
vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir)
vector_store.persist()
elif self.vector_store_config["vector_store_type"] == "milvus":
vector_store = MilvusStore({"url": self.vector_store_config["url"],
"port": self.vector_store_config["port"],
"embedding": self.embeddings})
vector_store.init_schema_and_load(vector_name, documents)
return vector_store
def _load_knownlege(self, path): def _load_knownlege(self, path):
docments = [] docments = []
@ -111,7 +88,7 @@ class KnowledgeEmbedding:
docs[i].page_content = docs[i].page_content.replace("\n", " ") docs[i].page_content = docs[i].page_content.replace("\n", " ")
i += 1 i += 1
elif filename.lower().endswith(".pdf"): elif filename.lower().endswith(".pdf"):
loader = UnstructuredPaddlePDFLoader(filename) loader = PyPDFLoader(filename)
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
docs = loader.load_and_split(textsplitter) docs = loader.load_and_split(textsplitter)
i = 0 i = 0

View File

@ -2,12 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import List from typing import List
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document from langchain.schema import Document
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
class PDFEmbedding(SourceEmbedding): class PDFEmbedding(SourceEmbedding):
@ -23,7 +23,8 @@ class PDFEmbedding(SourceEmbedding):
@register @register
def read(self): def read(self):
"""Load from pdf path.""" """Load from pdf path."""
loader = UnstructuredPaddlePDFLoader(self.file_path) # loader = UnstructuredPaddlePDFLoader(self.file_path)
loader = PyPDFLoader(self.file_path)
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)

View File

@ -1,19 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.vectorstores import Milvus
from typing import List, Optional, Dict from typing import List, Optional, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG from pilot.vector_store.connector import VectorStoreConnector
from pilot.vector_store.milvus_store import MilvusStore
registered_methods = [] registered_methods = []
CFG = Config()
def register(method): def register(method):
@ -35,19 +31,8 @@ class SourceEmbedding(ABC):
self.embedding_args = embedding_args self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
if VECTOR_STORE_TYPE == "milvus": vector_store_config["embeddings"] = self.embeddings
print(VECTOR_STORE_CONFIG) self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config)
if self.vector_store_config.get("text_field") is None:
self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
"port": VECTOR_STORE_CONFIG["port"],
"embedding": self.embeddings})
else:
self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content",
connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]})
else:
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
@abstractmethod @abstractmethod
@register @register
@ -70,24 +55,12 @@ class SourceEmbedding(ABC):
@register @register
def index_to_store(self, docs): def index_to_store(self, docs):
"""index to vector store""" """index to vector store"""
self.vector_client.load_document(docs)
if VECTOR_STORE_TYPE == "chroma":
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
self.vector_store.persist()
elif VECTOR_STORE_TYPE == "milvus":
self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
"port": VECTOR_STORE_CONFIG["port"],
"embedding": self.embeddings})
self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs)
@register @register
def similar_search(self, doc, topk): def similar_search(self, doc, topk):
"""vector store similarity_search""" """vector store similarity_search"""
return self.vector_client.similar_search(doc, topk)
return self.vector_store_client.similarity_search(doc, topk)
def source_embedding(self): def source_embedding(self):
if 'read' in registered_methods: if 'read' in registered_methods:

View File

@ -0,0 +1,30 @@
import os
from langchain.vectorstores import Chroma
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase
class ChromaStore(VectorStoreBase):
"""chroma database"""
def __init__(self, ctx: {}) -> None:
self.ctx = ctx
self.embeddings = ctx["embeddings"]
self.persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH,
ctx["vector_store_name"] + ".vectordb")
self.vector_store_client = Chroma(persist_directory=self.persist_dir, embedding_function=self.embeddings)
def similar_search(self, text, topk) -> None:
logger.info("ChromaStore similar search")
return self.vector_store_client.similarity_search(text, topk)
def load_document(self, documents):
logger.info("ChromaStore load document")
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
self.vector_store_client.persist()

View File

@ -0,0 +1,22 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
connector = {
"Chroma": ChromaStore,
"Milvus": MilvusStore
}
class VectorStoreConnector:
""" vector store connector, can connect different vector db provided load document api and similar search api
"""
def __init__(self, vector_store_type, ctx: {}) -> None:
self.ctx = ctx
self.connector_class = connector[vector_store_type]
self.client = self.connector_class(ctx)
def load_document(self, docs):
self.client.load_document(docs)
def similar_search(self, docs, topk):
return self.client.similar_search(docs, topk)

View File

@ -1,12 +1,15 @@
from typing import List, Optional, Iterable from typing import List, Optional, Iterable, Tuple, Any
from langchain.embeddings import HuggingFaceEmbeddings from pymilvus import connections, Collection, DataType
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
from langchain.docstore.document import Document
from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
class MilvusStore(VectorStoreBase): class MilvusStore(VectorStoreBase):
"""Milvus database"""
def __init__(self, ctx: {}) -> None: def __init__(self, ctx: {}) -> None:
"""init a milvus storage connection. """init a milvus storage connection.
@ -16,15 +19,13 @@ class MilvusStore(VectorStoreBase):
# self.configure(cfg) # self.configure(cfg)
connect_kwargs = {} connect_kwargs = {}
self.uri = None self.uri = CFG.MILVUS_URL
self.uri = ctx["url"] self.port = CFG.MILVUS_PORT
self.port = ctx["port"] self.username = CFG.MILVUS_USERNAME
self.username = ctx.get("username", None) self.password = CFG.MILVUS_PASSWORD
self.password = ctx.get("password", None) self.collection_name = ctx.get("vector_store_name", None)
self.collection_name = ctx.get("table_name", None)
self.secure = ctx.get("secure", None) self.secure = ctx.get("secure", None)
self.model_config = ctx.get("model_config", None) self.embedding = ctx.get("embeddings", None)
self.embedding = ctx.get("embedding", None)
self.fields = [] self.fields = []
# use HNSW by default. # use HNSW by default.
@ -33,6 +34,20 @@ class MilvusStore(VectorStoreBase):
"index_type": "HNSW", "index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64}, "params": {"M": 8, "efConstruction": 64},
} }
# use HNSW by default.
self.index_params_map = {
"IVF_FLAT": {"params": {"nprobe": 10}},
"IVF_SQ8": {"params": {"nprobe": 10}},
"IVF_PQ": {"params": {"nprobe": 10}},
"HNSW": {"params": {"ef": 10}},
"RHNSW_FLAT": {"params": {"ef": 10}},
"RHNSW_SQ": {"params": {"ef": 10}},
"RHNSW_PQ": {"params": {"ef": 10}},
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}},
}
self.text_field = "content"
if (self.username is None) != (self.password is None): if (self.username is None) != (self.password is None):
raise ValueError( raise ValueError(
@ -48,21 +63,6 @@ class MilvusStore(VectorStoreBase):
alias="default" alias="default"
# secure=self.secure, # secure=self.secure,
) )
if self.collection_name is not None:
self.col = Collection(self.collection_name)
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
# self.init_schema()
# self.init_collection_schema()
def init_schema_and_load(self, vector_name, documents): def init_schema_and_load(self, vector_name, documents):
"""Create a Milvus collection, indexes it with HNSW, load document. """Create a Milvus collection, indexes it with HNSW, load document.
@ -86,7 +86,6 @@ class MilvusStore(VectorStoreBase):
"Could not import pymilvus python package. " "Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`." "Please install it with `pip install pymilvus`."
) )
# Connect to Milvus instance
if not connections.has_connection("default"): if not connections.has_connection("default"):
connections.connect( connections.connect(
host=self.uri or "127.0.0.1", host=self.uri or "127.0.0.1",
@ -140,11 +139,11 @@ class MilvusStore(VectorStoreBase):
fields.append( fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
) )
# Create the primary key field # create the primary key field
fields.append( fields.append(
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True) FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
) )
# Create the vector field # create the vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# Create the schema for the collection # Create the schema for the collection
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
@ -176,32 +175,44 @@ class MilvusStore(VectorStoreBase):
return self.collection_name return self.collection_name
def init_schema(self) -> None: # def init_schema(self) -> None:
"""Initialize collection in milvus database.""" # """Initialize collection in milvus database."""
fields = [ # fields = [
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), # FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]), # FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535), # FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
] # ]
#
# create collection if not exist and load it. # # create collection if not exist and load it.
self.schema = CollectionSchema(fields, "db-gpt memory storage") # self.schema = CollectionSchema(fields, "db-gpt memory storage")
self.collection = Collection(self.collection_name, self.schema) # self.collection = Collection(self.collection_name, self.schema)
self.index_params = { # self.index_params_map = {
"metric_type": "IP", # "IVF_FLAT": {"params": {"nprobe": 10}},
"index_type": "HNSW", # "IVF_SQ8": {"params": {"nprobe": 10}},
"params": {"M": 8, "efConstruction": 64}, # "IVF_PQ": {"params": {"nprobe": 10}},
} # "HNSW": {"params": {"ef": 10}},
# create index if not exist. # "RHNSW_FLAT": {"params": {"ef": 10}},
if not self.collection.has_index(): # "RHNSW_SQ": {"params": {"ef": 10}},
self.collection.release() # "RHNSW_PQ": {"params": {"ef": 10}},
self.collection.create_index( # "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"vector", # "ANNOY": {"params": {"search_k": 10}},
self.index_params, # }
index_name="vector", #
) # self.index_params = {
info = self.collection.describe() # "metric_type": "IP",
self.collection.load() # "index_type": "HNSW",
# "params": {"M": 8, "efConstruction": 64},
# }
# # create index if not exist.
# if not self.collection.has_index():
# self.collection.release()
# self.collection.create_index(
# "vector",
# self.index_params,
# index_name="vector",
# )
# info = self.collection.describe()
# self.collection.load()
# def insert(self, text, model_config) -> str: # def insert(self, text, model_config) -> str:
# """Add an embedding of data into milvus. # """Add an embedding of data into milvus.
@ -226,17 +237,7 @@ class MilvusStore(VectorStoreBase):
partition_name: Optional[str] = None, partition_name: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> List[str]: ) -> List[str]:
"""Insert text data into Milvus. """add text data into Milvus.
Args:
texts (Iterable[str]): The text being embedded and inserted.
metadatas (Optional[List[dict]], optional): The metadata that
corresponds to each insert. Defaults to None.
partition_name (str, optional): The partition of the collection
to insert data into. Defaults to None.
timeout: specified timeout.
Returns:
List[str]: The resulting keys for each inserted element.
""" """
insert_dict: Any = {self.text_field: list(texts)} insert_dict: Any = {self.text_field: list(texts)}
try: try:
@ -259,6 +260,72 @@ class MilvusStore(VectorStoreBase):
res = self.col.insert( res = self.col.insert(
insert_list, partition_name=partition_name, timeout=timeout insert_list, partition_name=partition_name, timeout=timeout
) )
# Flush to make sure newly inserted is immediately searchable. # make sure data is searchable.
self.col.flush() self.col.flush()
return res.primary_keys return res.primary_keys
def load_document(self, documents) -> None:
"""load document in vector database."""
self.init_schema_and_load(self.collection_name, documents)
def similar_search(self, text, topk) -> None:
"""similar_search in vector database."""
self.col = Collection(self.collection_name)
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
return [doc for doc, _, _ in docs_and_scores]
def _search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
partition_names: Optional[List[str]] = None,
round_decimal: int = -1,
timeout: Optional[int] = None,
**kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
self.col.load()
# use default index params.
if param is None:
index_type = self.col.indexes[0].params["index_type"]
param = self.index_params_map[index_type]
# query text embedding.
data = [self.embedding.embed_query(query)]
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self.vector_field)
# milvus search.
res = self.col.search(
data,
self.vector_field,
param,
k,
expr=expr,
output_fields=output_fields,
partition_names=partition_names,
round_decimal=round_decimal,
timeout=timeout,
**kwargs,
)
ret = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
ret.append(
(
Document(page_content=meta.pop(self.text_field), metadata=meta),
result.distance,
result.id,
)
)
return data[0], ret

View File

@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
class VectorStoreBase(ABC): class VectorStoreBase(ABC):
"""base class for vector store database"""
@abstractmethod @abstractmethod
def init_schema(self) -> None: def load_document(self, documents) -> None:
"""load document in vector database."""
pass
@abstractmethod
def similar_search(self, text, topk) -> None:
"""Initialize schema in vector database.""" """Initialize schema in vector database."""
pass pass

View File

@ -61,7 +61,7 @@ gTTS==2.3.1
langchain langchain
nltk nltk
python-dotenv==1.0.0 python-dotenv==1.0.0
pymilvus pymilvus==2.2.1
vcrpy vcrpy
chromadb chromadb
markdown2 markdown2

View File

@ -2,10 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse import argparse
from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG, \
from langchain.vectorstores import Milvus VECTOR_STORE_TYPE
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@ -42,8 +40,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
vector_name = args.vector_name vector_name = args.vector_name
append_mode = args.append append_mode = args.append
store_type = args.store_type store_type = VECTOR_STORE_TYPE
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type} vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name}
print(vector_store_config) print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)