mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-18 08:17:38 +00:00
Merge branch 'main' of https://github.com/csunny/DB-GPT into dbgpt_doc
This commit is contained in:
commit
03898ee7d3
@ -81,3 +81,14 @@ DENYLISTED_PLUGINS=
|
|||||||
#*******************************************************************#
|
#*******************************************************************#
|
||||||
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
|
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
|
||||||
# CHAT_MESSAGES_ENABLED=False
|
# CHAT_MESSAGES_ENABLED=False
|
||||||
|
|
||||||
|
|
||||||
|
#*******************************************************************#
|
||||||
|
#** VECTOR STORE SETTINGS **#
|
||||||
|
#*******************************************************************#
|
||||||
|
VECTOR_STORE_TYPE=Chroma
|
||||||
|
#MILVUS_URL=127.0.0.1
|
||||||
|
#MILVUS_PORT=19530
|
||||||
|
#MILVUS_USERNAME=
|
||||||
|
#MILVUS_PASSWORD=
|
||||||
|
#MILVUS_SECURE=
|
||||||
|
@ -109,6 +109,14 @@ class Config(metaclass=Singleton):
|
|||||||
self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT))
|
self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT))
|
||||||
self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
|
self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
|
||||||
|
|
||||||
|
### Vector Store Configuration
|
||||||
|
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
|
||||||
|
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
|
||||||
|
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
|
||||||
|
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
|
||||||
|
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
|
||||||
|
|
||||||
|
|
||||||
def set_debug_mode(self, value: bool) -> None:
|
def set_debug_mode(self, value: bool) -> None:
|
||||||
"""Set the debug mode value"""
|
"""Set the debug mode value"""
|
||||||
self.debug_mode = value
|
self.debug_mode = value
|
||||||
|
@ -47,6 +47,4 @@ ISDEBUG = False
|
|||||||
VECTOR_SEARCH_TOP_K = 10
|
VECTOR_SEARCH_TOP_K = 10
|
||||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
|
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
|
||||||
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
|
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
|
||||||
VECTOR_STORE_TYPE = "milvus"
|
|
||||||
VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"}
|
|
@ -19,8 +19,7 @@ from langchain import PromptTemplate
|
|||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \
|
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
|
||||||
VECTOR_STORE_CONFIG
|
|
||||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||||
from pilot.connections.mysql import MySQLOperator
|
from pilot.connections.mysql import MySQLOperator
|
||||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||||
@ -268,13 +267,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||||
|
|
||||||
if mode == conversation_types["custome"] and not db_selector:
|
if mode == conversation_types["custome"] and not db_selector:
|
||||||
# persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"])
|
|
||||||
print("vector store type: ", VECTOR_STORE_CONFIG)
|
|
||||||
print("vector store name: ", vector_store_name["vs_name"])
|
print("vector store name: ", vector_store_name["vs_name"])
|
||||||
vector_store_config = VECTOR_STORE_CONFIG
|
vector_store_config = {"vector_store_name": vector_store_name["vs_name"], "text_field": "content",
|
||||||
vector_store_config["vector_store_name"] = vector_store_name["vs_name"]
|
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}
|
||||||
vector_store_config["text_field"] = "content"
|
|
||||||
vector_store_config["vector_store_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
|
|
||||||
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
|
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||||
local_persist=False,
|
local_persist=False,
|
||||||
vector_store_config=vector_store_config)
|
vector_store_config=vector_store_config)
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from langchain.document_loaders import TextLoader, markdown
|
from langchain.document_loaders import TextLoader, markdown, PyPDFLoader
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import Chroma
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||||
@ -11,9 +12,9 @@ from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
|||||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||||
import markdown
|
import markdown
|
||||||
|
|
||||||
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
from pilot.vector_store.milvus_store import MilvusStore
|
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
class KnowledgeEmbedding:
|
class KnowledgeEmbedding:
|
||||||
def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
|
def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
|
||||||
@ -23,6 +24,7 @@ class KnowledgeEmbedding:
|
|||||||
self.vector_store_config = vector_store_config
|
self.vector_store_config = vector_store_config
|
||||||
self.file_type = "default"
|
self.file_type = "default"
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||||
|
self.vector_store_config["embeddings"] = self.embeddings
|
||||||
self.local_persist = local_persist
|
self.local_persist = local_persist
|
||||||
if not self.local_persist:
|
if not self.local_persist:
|
||||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||||
@ -52,35 +54,10 @@ class KnowledgeEmbedding:
|
|||||||
return self.knowledge_embedding_client.similar_search(text, topk)
|
return self.knowledge_embedding_client.similar_search(text, topk)
|
||||||
|
|
||||||
def knowledge_persist_initialization(self, append_mode):
|
def knowledge_persist_initialization(self, append_mode):
|
||||||
vector_name = self.vector_store_config["vector_store_name"]
|
|
||||||
documents = self._load_knownlege(self.file_path)
|
documents = self._load_knownlege(self.file_path)
|
||||||
if self.vector_store_config["vector_store_type"] == "Chroma":
|
self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config)
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb")
|
self.vector_client.load_document(documents)
|
||||||
print("vector db path: ", persist_dir)
|
return self.vector_client
|
||||||
if os.path.exists(persist_dir):
|
|
||||||
if append_mode:
|
|
||||||
print("append knowledge return vector store")
|
|
||||||
new_documents = self._load_knownlege(self.file_path)
|
|
||||||
vector_store = Chroma.from_documents(documents=new_documents,
|
|
||||||
embedding=self.embeddings,
|
|
||||||
persist_directory=persist_dir)
|
|
||||||
else:
|
|
||||||
print("directly return vector store")
|
|
||||||
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
|
||||||
else:
|
|
||||||
print(vector_name + " is new vector store, knowledge begin load...")
|
|
||||||
vector_store = Chroma.from_documents(documents=documents,
|
|
||||||
embedding=self.embeddings,
|
|
||||||
persist_directory=persist_dir)
|
|
||||||
vector_store.persist()
|
|
||||||
|
|
||||||
elif self.vector_store_config["vector_store_type"] == "milvus":
|
|
||||||
vector_store = MilvusStore({"url": self.vector_store_config["url"],
|
|
||||||
"port": self.vector_store_config["port"],
|
|
||||||
"embedding": self.embeddings})
|
|
||||||
vector_store.init_schema_and_load(vector_name, documents)
|
|
||||||
|
|
||||||
return vector_store
|
|
||||||
|
|
||||||
def _load_knownlege(self, path):
|
def _load_knownlege(self, path):
|
||||||
docments = []
|
docments = []
|
||||||
@ -111,7 +88,7 @@ class KnowledgeEmbedding:
|
|||||||
docs[i].page_content = docs[i].page_content.replace("\n", " ")
|
docs[i].page_content = docs[i].page_content.replace("\n", " ")
|
||||||
i += 1
|
i += 1
|
||||||
elif filename.lower().endswith(".pdf"):
|
elif filename.lower().endswith(".pdf"):
|
||||||
loader = UnstructuredPaddlePDFLoader(filename)
|
loader = PyPDFLoader(filename)
|
||||||
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
||||||
docs = loader.load_and_split(textsplitter)
|
docs = loader.load_and_split(textsplitter)
|
||||||
i = 0
|
i = 0
|
||||||
|
@ -2,12 +2,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.document_loaders import PyPDFLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||||
|
|
||||||
from pilot.source_embedding import SourceEmbedding, register
|
from pilot.source_embedding import SourceEmbedding, register
|
||||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||||
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
|
|
||||||
|
|
||||||
|
|
||||||
class PDFEmbedding(SourceEmbedding):
|
class PDFEmbedding(SourceEmbedding):
|
||||||
@ -23,7 +23,8 @@ class PDFEmbedding(SourceEmbedding):
|
|||||||
@register
|
@register
|
||||||
def read(self):
|
def read(self):
|
||||||
"""Load from pdf path."""
|
"""Load from pdf path."""
|
||||||
loader = UnstructuredPaddlePDFLoader(self.file_path)
|
# loader = UnstructuredPaddlePDFLoader(self.file_path)
|
||||||
|
loader = PyPDFLoader(self.file_path)
|
||||||
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
|
@ -1,19 +1,15 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import Chroma
|
|
||||||
from langchain.vectorstores import Milvus
|
|
||||||
|
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
from pilot.vector_store.milvus_store import MilvusStore
|
|
||||||
|
|
||||||
registered_methods = []
|
registered_methods = []
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
def register(method):
|
def register(method):
|
||||||
@ -35,19 +31,8 @@ class SourceEmbedding(ABC):
|
|||||||
self.embedding_args = embedding_args
|
self.embedding_args = embedding_args
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||||
|
|
||||||
if VECTOR_STORE_TYPE == "milvus":
|
vector_store_config["embeddings"] = self.embeddings
|
||||||
print(VECTOR_STORE_CONFIG)
|
self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config)
|
||||||
if self.vector_store_config.get("text_field") is None:
|
|
||||||
self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
|
|
||||||
"port": VECTOR_STORE_CONFIG["port"],
|
|
||||||
"embedding": self.embeddings})
|
|
||||||
else:
|
|
||||||
self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content",
|
|
||||||
connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]})
|
|
||||||
else:
|
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
|
||||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
|
||||||
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@register
|
@register
|
||||||
@ -70,24 +55,12 @@ class SourceEmbedding(ABC):
|
|||||||
@register
|
@register
|
||||||
def index_to_store(self, docs):
|
def index_to_store(self, docs):
|
||||||
"""index to vector store"""
|
"""index to vector store"""
|
||||||
|
self.vector_client.load_document(docs)
|
||||||
if VECTOR_STORE_TYPE == "chroma":
|
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
|
||||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
|
||||||
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
|
|
||||||
self.vector_store.persist()
|
|
||||||
|
|
||||||
elif VECTOR_STORE_TYPE == "milvus":
|
|
||||||
self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
|
|
||||||
"port": VECTOR_STORE_CONFIG["port"],
|
|
||||||
"embedding": self.embeddings})
|
|
||||||
self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs)
|
|
||||||
|
|
||||||
@register
|
@register
|
||||||
def similar_search(self, doc, topk):
|
def similar_search(self, doc, topk):
|
||||||
"""vector store similarity_search"""
|
"""vector store similarity_search"""
|
||||||
|
return self.vector_client.similar_search(doc, topk)
|
||||||
return self.vector_store_client.similarity_search(doc, topk)
|
|
||||||
|
|
||||||
def source_embedding(self):
|
def source_embedding(self):
|
||||||
if 'read' in registered_methods:
|
if 'read' in registered_methods:
|
||||||
|
30
pilot/vector_store/chroma_store.py
Normal file
30
pilot/vector_store/chroma_store.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from langchain.vectorstores import Chroma
|
||||||
|
|
||||||
|
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
|
from pilot.logs import logger
|
||||||
|
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaStore(VectorStoreBase):
    """Vector store backed by a locally persisted Chroma collection.

    The collection is persisted under
    ``<vector_store_path>/<vector_store_name>.vectordb``.
    """

    def __init__(self, ctx: dict) -> None:
        """Create the Chroma client for the store described by ``ctx``.

        Args:
            ctx: store configuration. ``embeddings`` (the embedding function
                handed to Chroma) and ``vector_store_name`` are required.
                ``vector_store_path`` is optional and falls back to
                ``KNOWLEDGE_UPLOAD_ROOT_PATH`` when missing or empty.

        Raises:
            KeyError: if a required key is missing from ``ctx``.
        """
        self.ctx = ctx
        self.embeddings = ctx["embeddings"]
        # Callers already pass "vector_store_path" in ctx; honour it instead of
        # silently ignoring it, and keep the old default when it is absent.
        root_path = ctx.get("vector_store_path") or KNOWLEDGE_UPLOAD_ROOT_PATH
        self.persist_dir = os.path.join(
            root_path, ctx["vector_store_name"] + ".vectordb"
        )
        self.vector_store_client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
        )

    def similar_search(self, text, topk) -> list:
        """Return the result of Chroma's ``similarity_search`` for ``text``.

        Args:
            text: query text.
            topk: passed to ``similarity_search`` as its second positional
                argument (the number of results requested).
        """
        logger.info("ChromaStore similar search")
        return self.vector_store_client.similarity_search(text, topk)

    def load_document(self, documents) -> None:
        """Add ``documents`` to the collection and persist it.

        Args:
            documents: iterable of objects exposing ``page_content`` and
                ``metadata``.
        """
        logger.info("ChromaStore load document")
        # Materialise once: the input is walked twice below, which would
        # exhaust a one-shot iterator.
        documents = list(documents)
        if not documents:
            # Nothing to embed or persist.
            return
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
        self.vector_store_client.persist()
|
||||||
|
|
22
pilot/vector_store/connector.py
Normal file
22
pilot/vector_store/connector.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from pilot.vector_store.chroma_store import ChromaStore
|
||||||
|
from pilot.vector_store.milvus_store import MilvusStore
|
||||||
|
|
||||||
|
connector = {
|
||||||
|
"Chroma": ChromaStore,
|
||||||
|
"Milvus": MilvusStore
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStoreConnector:
    """Facade over the configured vector database.

    Resolves the store implementation registered in the module-level
    ``connector`` mapping and forwards ``load_document`` and
    ``similar_search`` to an instance of it.
    """

    def __init__(self, vector_store_type, ctx: dict) -> None:
        """Instantiate the store registered under ``vector_store_type``.

        Args:
            vector_store_type: key of the ``connector`` mapping. An exact
                match is preferred; otherwise the name is matched
                case-insensitively (so "chroma" resolves to "Chroma").
            ctx: configuration dict passed unchanged to the store class.

        Raises:
            KeyError: if no registered store matches ``vector_store_type``.
        """
        self.ctx = ctx
        store_class = connector.get(vector_store_type)
        if store_class is None and isinstance(vector_store_type, str):
            # Tolerate differently-cased values coming from env/config files.
            wanted = vector_store_type.strip().lower()
            store_class = next(
                (cls for name, cls in connector.items() if name.lower() == wanted),
                None,
            )
        if store_class is None:
            # Same exception type as the plain dict lookup, but with a hint.
            raise KeyError(
                f"unsupported vector store type {vector_store_type!r}; "
                f"expected one of {sorted(connector)}"
            )
        self.connector_class = store_class
        self.client = self.connector_class(ctx)

    def load_document(self, docs):
        """Forward ``docs`` to the underlying store's ``load_document``."""
        self.client.load_document(docs)

    def similar_search(self, docs, topk):
        """Return the underlying store's ``similar_search(docs, topk)`` result."""
        return self.client.similar_search(docs, topk)
|
@ -1,12 +1,15 @@
|
|||||||
from typing import List, Optional, Iterable
|
from typing import List, Optional, Iterable, Tuple, Any
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from pymilvus import connections, Collection, DataType
|
||||||
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
|
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
from pilot.vector_store.vector_store_base import VectorStoreBase
|
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
class MilvusStore(VectorStoreBase):
|
class MilvusStore(VectorStoreBase):
|
||||||
|
"""Milvus database"""
|
||||||
def __init__(self, ctx: {}) -> None:
|
def __init__(self, ctx: {}) -> None:
|
||||||
"""init a milvus storage connection.
|
"""init a milvus storage connection.
|
||||||
|
|
||||||
@ -16,15 +19,13 @@ class MilvusStore(VectorStoreBase):
|
|||||||
# self.configure(cfg)
|
# self.configure(cfg)
|
||||||
|
|
||||||
connect_kwargs = {}
|
connect_kwargs = {}
|
||||||
self.uri = None
|
self.uri = CFG.MILVUS_URL
|
||||||
self.uri = ctx["url"]
|
self.port = CFG.MILVUS_PORT
|
||||||
self.port = ctx["port"]
|
self.username = CFG.MILVUS_USERNAME
|
||||||
self.username = ctx.get("username", None)
|
self.password = CFG.MILVUS_PASSWORD
|
||||||
self.password = ctx.get("password", None)
|
self.collection_name = ctx.get("vector_store_name", None)
|
||||||
self.collection_name = ctx.get("table_name", None)
|
|
||||||
self.secure = ctx.get("secure", None)
|
self.secure = ctx.get("secure", None)
|
||||||
self.model_config = ctx.get("model_config", None)
|
self.embedding = ctx.get("embeddings", None)
|
||||||
self.embedding = ctx.get("embedding", None)
|
|
||||||
self.fields = []
|
self.fields = []
|
||||||
|
|
||||||
# use HNSW by default.
|
# use HNSW by default.
|
||||||
@ -33,6 +34,20 @@ class MilvusStore(VectorStoreBase):
|
|||||||
"index_type": "HNSW",
|
"index_type": "HNSW",
|
||||||
"params": {"M": 8, "efConstruction": 64},
|
"params": {"M": 8, "efConstruction": 64},
|
||||||
}
|
}
|
||||||
|
# default search params for each supported index type.
|
||||||
|
self.index_params_map = {
|
||||||
|
"IVF_FLAT": {"params": {"nprobe": 10}},
|
||||||
|
"IVF_SQ8": {"params": {"nprobe": 10}},
|
||||||
|
"IVF_PQ": {"params": {"nprobe": 10}},
|
||||||
|
"HNSW": {"params": {"ef": 10}},
|
||||||
|
"RHNSW_FLAT": {"params": {"ef": 10}},
|
||||||
|
"RHNSW_SQ": {"params": {"ef": 10}},
|
||||||
|
"RHNSW_PQ": {"params": {"ef": 10}},
|
||||||
|
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
||||||
|
"ANNOY": {"params": {"search_k": 10}},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.text_field = "content"
|
||||||
|
|
||||||
if (self.username is None) != (self.password is None):
|
if (self.username is None) != (self.password is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -48,21 +63,6 @@ class MilvusStore(VectorStoreBase):
|
|||||||
alias="default"
|
alias="default"
|
||||||
# secure=self.secure,
|
# secure=self.secure,
|
||||||
)
|
)
|
||||||
if self.collection_name is not None:
|
|
||||||
self.col = Collection(self.collection_name)
|
|
||||||
schema = self.col.schema
|
|
||||||
for x in schema.fields:
|
|
||||||
self.fields.append(x.name)
|
|
||||||
if x.auto_id:
|
|
||||||
self.fields.remove(x.name)
|
|
||||||
if x.is_primary:
|
|
||||||
self.primary_field = x.name
|
|
||||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
|
||||||
self.vector_field = x.name
|
|
||||||
|
|
||||||
|
|
||||||
# self.init_schema()
|
|
||||||
# self.init_collection_schema()
|
|
||||||
|
|
||||||
def init_schema_and_load(self, vector_name, documents):
|
def init_schema_and_load(self, vector_name, documents):
|
||||||
"""Create a Milvus collection, indexes it with HNSW, load document.
|
"""Create a Milvus collection, indexes it with HNSW, load document.
|
||||||
@ -86,7 +86,6 @@ class MilvusStore(VectorStoreBase):
|
|||||||
"Could not import pymilvus python package. "
|
"Could not import pymilvus python package. "
|
||||||
"Please install it with `pip install pymilvus`."
|
"Please install it with `pip install pymilvus`."
|
||||||
)
|
)
|
||||||
# Connect to Milvus instance
|
|
||||||
if not connections.has_connection("default"):
|
if not connections.has_connection("default"):
|
||||||
connections.connect(
|
connections.connect(
|
||||||
host=self.uri or "127.0.0.1",
|
host=self.uri or "127.0.0.1",
|
||||||
@ -140,11 +139,11 @@ class MilvusStore(VectorStoreBase):
|
|||||||
fields.append(
|
fields.append(
|
||||||
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
|
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
|
||||||
)
|
)
|
||||||
# Create the primary key field
|
# create the primary key field
|
||||||
fields.append(
|
fields.append(
|
||||||
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
||||||
)
|
)
|
||||||
# Create the vector field
|
# create the vector field
|
||||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||||
# Create the schema for the collection
|
# Create the schema for the collection
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
@ -176,32 +175,44 @@ class MilvusStore(VectorStoreBase):
|
|||||||
|
|
||||||
return self.collection_name
|
return self.collection_name
|
||||||
|
|
||||||
def init_schema(self) -> None:
|
# def init_schema(self) -> None:
|
||||||
"""Initialize collection in milvus database."""
|
# """Initialize collection in milvus database."""
|
||||||
fields = [
|
# fields = [
|
||||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
# FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
|
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
|
||||||
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
# FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
||||||
]
|
# ]
|
||||||
|
#
|
||||||
# create collection if not exist and load it.
|
# # create collection if not exist and load it.
|
||||||
self.schema = CollectionSchema(fields, "db-gpt memory storage")
|
# self.schema = CollectionSchema(fields, "db-gpt memory storage")
|
||||||
self.collection = Collection(self.collection_name, self.schema)
|
# self.collection = Collection(self.collection_name, self.schema)
|
||||||
self.index_params = {
|
# self.index_params_map = {
|
||||||
"metric_type": "IP",
|
# "IVF_FLAT": {"params": {"nprobe": 10}},
|
||||||
"index_type": "HNSW",
|
# "IVF_SQ8": {"params": {"nprobe": 10}},
|
||||||
"params": {"M": 8, "efConstruction": 64},
|
# "IVF_PQ": {"params": {"nprobe": 10}},
|
||||||
}
|
# "HNSW": {"params": {"ef": 10}},
|
||||||
# create index if not exist.
|
# "RHNSW_FLAT": {"params": {"ef": 10}},
|
||||||
if not self.collection.has_index():
|
# "RHNSW_SQ": {"params": {"ef": 10}},
|
||||||
self.collection.release()
|
# "RHNSW_PQ": {"params": {"ef": 10}},
|
||||||
self.collection.create_index(
|
# "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
||||||
"vector",
|
# "ANNOY": {"params": {"search_k": 10}},
|
||||||
self.index_params,
|
# }
|
||||||
index_name="vector",
|
#
|
||||||
)
|
# self.index_params = {
|
||||||
info = self.collection.describe()
|
# "metric_type": "IP",
|
||||||
self.collection.load()
|
# "index_type": "HNSW",
|
||||||
|
# "params": {"M": 8, "efConstruction": 64},
|
||||||
|
# }
|
||||||
|
# # create index if not exist.
|
||||||
|
# if not self.collection.has_index():
|
||||||
|
# self.collection.release()
|
||||||
|
# self.collection.create_index(
|
||||||
|
# "vector",
|
||||||
|
# self.index_params,
|
||||||
|
# index_name="vector",
|
||||||
|
# )
|
||||||
|
# info = self.collection.describe()
|
||||||
|
# self.collection.load()
|
||||||
|
|
||||||
# def insert(self, text, model_config) -> str:
|
# def insert(self, text, model_config) -> str:
|
||||||
# """Add an embedding of data into milvus.
|
# """Add an embedding of data into milvus.
|
||||||
@ -226,17 +237,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
partition_name: Optional[str] = None,
|
partition_name: Optional[str] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Insert text data into Milvus.
|
"""add text data into Milvus.
|
||||||
Args:
|
|
||||||
texts (Iterable[str]): The text being embedded and inserted.
|
|
||||||
metadatas (Optional[List[dict]], optional): The metadata that
|
|
||||||
corresponds to each insert. Defaults to None.
|
|
||||||
partition_name (str, optional): The partition of the collection
|
|
||||||
to insert data into. Defaults to None.
|
|
||||||
timeout: specified timeout.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: The resulting keys for each inserted element.
|
|
||||||
"""
|
"""
|
||||||
insert_dict: Any = {self.text_field: list(texts)}
|
insert_dict: Any = {self.text_field: list(texts)}
|
||||||
try:
|
try:
|
||||||
@ -259,6 +260,72 @@ class MilvusStore(VectorStoreBase):
|
|||||||
res = self.col.insert(
|
res = self.col.insert(
|
||||||
insert_list, partition_name=partition_name, timeout=timeout
|
insert_list, partition_name=partition_name, timeout=timeout
|
||||||
)
|
)
|
||||||
# Flush to make sure newly inserted is immediately searchable.
|
# make sure data is searchable.
|
||||||
self.col.flush()
|
self.col.flush()
|
||||||
return res.primary_keys
|
return res.primary_keys
|
||||||
|
|
||||||
|
def load_document(self, documents) -> None:
    """Load ``documents`` into the collection named ``self.collection_name``.

    Delegates to ``init_schema_and_load``; its return value is discarded.
    """
    target_collection = self.collection_name
    self.init_schema_and_load(target_collection, documents)
|
||||||
|
|
||||||
|
def similar_search(self, text, topk) -> list:
    """Search the collection ``self.collection_name`` for ``text``.

    Re-reads the collection schema to refresh ``self.fields``,
    ``self.primary_field`` and ``self.vector_field``, then delegates to
    ``_search``.

    Args:
        text: query text.
        topk: number of hits requested, passed to ``_search`` as ``k``.

    Returns:
        The ``Document`` objects from ``_search``, without distances or ids.
    """
    self.col = Collection(self.collection_name)
    # Rebuild the field list from scratch on every call. Appending to the
    # existing list duplicated every name on repeated searches, so the vector
    # field survived the single ``remove`` done in ``_search``.
    fields = []
    for field in self.col.schema.fields:
        if not field.auto_id:
            # Auto-generated ids are never listed as a field.
            fields.append(field.name)
        if field.is_primary:
            self.primary_field = field.name
        if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
            self.vector_field = field.name
    self.fields = fields
    _, docs_and_scores = self._search(text, topk)
    return [doc for doc, _, _ in docs_and_scores]
|
||||||
|
|
||||||
|
def _search(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    partition_names: Optional[List[str]] = None,
    round_decimal: int = -1,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
    """Embed ``query`` and run a vector search against ``self.col``.

    Args:
        query: text embedded with ``self.embedding.embed_query``.
        k: passed to ``self.col.search`` as its fourth positional argument.
        param: search parameters. When ``None`` they are looked up in
            ``self.index_params_map`` using the index type of the
            collection's first index.
        expr: forwarded unchanged to ``self.col.search``.
        partition_names: forwarded unchanged to ``self.col.search``.
        round_decimal: forwarded unchanged to ``self.col.search``.
        timeout: forwarded unchanged to ``self.col.search``.
        **kwargs: forwarded unchanged to ``self.col.search``.

    Returns:
        ``(query_vector, hits)``, where each hit is a
        ``(Document, distance, id)`` tuple. The document's metadata holds
        every output field except ``self.text_field``.
    """
    collection = self.col
    collection.load()

    # Fall back to the default parameters registered for the index type.
    search_param = param
    if search_param is None:
        first_index = collection.indexes[0]
        search_param = self.index_params_map[first_index.params["index_type"]]

    # Embed the query text.
    query_vectors = [self.embedding.embed_query(query)]

    # Every known field except the vector column is returned with each hit.
    wanted_fields = self.fields[:]
    wanted_fields.remove(self.vector_field)

    hits = collection.search(
        query_vectors,
        self.vector_field,
        search_param,
        k,
        expr=expr,
        output_fields=wanted_fields,
        partition_names=partition_names,
        round_decimal=round_decimal,
        timeout=timeout,
        **kwargs,
    )

    documents = []
    for hit in hits[0]:
        payload = {}
        for field_name in wanted_fields:
            payload[field_name] = hit.entity.get(field_name)
        content = payload.pop(self.text_field)
        documents.append(
            (Document(page_content=content, metadata=payload), hit.distance, hit.id)
        )

    return query_vectors[0], documents
|
||||||
|
@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
|
|
||||||
class VectorStoreBase(ABC):
    """Base class for vector store databases.

    Concrete stores implement ``load_document`` to index documents and
    ``similar_search`` to query them.
    """

    @abstractmethod
    def load_document(self, documents) -> None:
        """Load ``documents`` into the vector database."""

    @abstractmethod
    def similar_search(self, text, topk) -> list:
        """Return the documents matching ``text``; ``topk`` is the number requested."""
|
@ -61,7 +61,7 @@ gTTS==2.3.1
|
|||||||
langchain
|
langchain
|
||||||
nltk
|
nltk
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
pymilvus
|
pymilvus==2.2.1
|
||||||
vcrpy
|
vcrpy
|
||||||
chromadb
|
chromadb
|
||||||
markdown2
|
markdown2
|
||||||
|
@ -2,10 +2,8 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG, \
|
||||||
from langchain.vectorstores import Milvus
|
VECTOR_STORE_TYPE
|
||||||
|
|
||||||
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG
|
|
||||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||||
|
|
||||||
|
|
||||||
@ -42,8 +40,8 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
vector_name = args.vector_name
|
vector_name = args.vector_name
|
||||||
append_mode = args.append
|
append_mode = args.append
|
||||||
store_type = args.store_type
|
store_type = VECTOR_STORE_TYPE
|
||||||
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type}
|
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name}
|
||||||
print(vector_store_config)
|
print(vector_store_config)
|
||||||
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
||||||
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
||||||
|
Loading…
Reference in New Issue
Block a user