diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index da68ab332..0f9cef937 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -48,3 +48,5 @@ DB_SETTINGS = { VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 +VECTOR_STORE_TYPE = "milvus" +VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"} diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 25940a437..bcf8f6385 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -19,7 +19,8 @@ from langchain import PromptTemplate ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K +from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \ + VECTOR_STORE_CONFIG from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.connections.mysql import MySQLOperator from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding @@ -267,12 +268,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re skip_echo_len = len(prompt.replace("", " ")) + 1 if mode == conversation_types["custome"] and not db_selector: - persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb") - print("vector store path: ", persist_dir) + # persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"]) + print("vector store type: ", VECTOR_STORE_CONFIG) + print("vector store name: ", vector_store_name["vs_name"]) + vector_store_config = VECTOR_STORE_CONFIG + vector_store_config["vector_store_name"] = vector_store_name["vs_name"] + vector_store_config["text_field"] = "content" + vector_store_config["vector_store_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], local_persist=False, - vector_store_config={"vector_store_name": vector_store_name["vs_name"], - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + vector_store_config=vector_store_config) query = state.messages[-2][1] docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) context = [d.page_content for d in docs] diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 08d962908..63d6c2121 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,7 +1,7 @@ import os from bs4 import BeautifulSoup -from langchain.document_loaders import PyPDFLoader, TextLoader, markdown +from langchain.document_loaders import TextLoader, markdown from langchain.embeddings import HuggingFaceEmbeddings from langchain.vectorstores import Chroma from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE @@ -12,6 +12,7 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding import markdown from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader +from pilot.vector_store.milvus_store import MilvusStore class KnowledgeEmbedding: @@ -20,7 +21,7 @@ class KnowledgeEmbedding: self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config - self.vector_store_type = "default" + self.file_type = "default" self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.local_persist = local_persist if not self.local_persist: @@ -42,7 +43,7 @@ class KnowledgeEmbedding: elif self.file_path.endswith(".csv"): embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) - elif self.vector_store_type == "default": + elif self.file_type == "default": embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) return embedding @@ -52,25 +53,33 @@ class KnowledgeEmbedding: def knowledge_persist_initialization(self, append_mode): vector_name = self.vector_store_config["vector_store_name"] - persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb") - print("vector db path: ", persist_dir) - if os.path.exists(persist_dir): - if append_mode: - print("append knowledge return vector store") - new_documents = self._load_knownlege(self.file_path) - vector_store = Chroma.from_documents(documents=new_documents, + documents = self._load_knownlege(self.file_path) + if self.vector_store_config["vector_store_type"] == "Chroma": + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb") + print("vector db path: ", persist_dir) + if os.path.exists(persist_dir): + if append_mode: + print("append knowledge return vector store") + new_documents = self._load_knownlege(self.file_path) + vector_store = Chroma.from_documents(documents=new_documents, + embedding=self.embeddings, + persist_directory=persist_dir) + else: + print("directly return vector store") + vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) + else: + print(vector_name + " is new vector store, knowledge begin load...") + vector_store = Chroma.from_documents(documents=documents, embedding=self.embeddings, persist_directory=persist_dir) - else: - print("directly return vector store") - vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) - else: - print(vector_name + " is new vector store, knowledge begin load...") - documents = self._load_knownlege(self.file_path) - vector_store = Chroma.from_documents(documents=documents, - embedding=self.embeddings, - persist_directory=persist_dir) - vector_store.persist() + vector_store.persist() + + elif self.vector_store_config["vector_store_type"] == "milvus": + vector_store = MilvusStore({"url": self.vector_store_config["url"], + "port": self.vector_store_config["port"], + "embedding": self.embeddings}) + vector_store.init_schema_and_load(vector_name, documents) + return vector_store def _load_knownlege(self, path): diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 66bc97b6d..a253e4d78 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -5,9 +5,14 @@ from abc import ABC, abstractmethod from langchain.embeddings import HuggingFaceEmbeddings from langchain.vectorstores import Chroma +from langchain.vectorstores import Milvus from typing import List, Optional, Dict + +from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG +from pilot.vector_store.milvus_store import MilvusStore + registered_methods = [] @@ -29,9 +34,20 @@ class SourceEmbedding(ABC): self.vector_store_config = vector_store_config self.embedding_args = embedding_args self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) - persist_dir = os.path.join(self.vector_store_config["vector_store_path"], - self.vector_store_config["vector_store_name"] + ".vectordb") - self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) + + if VECTOR_STORE_TYPE == "milvus": + print(VECTOR_STORE_CONFIG) + if self.vector_store_config.get("text_field") is None: + self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"], + "port": VECTOR_STORE_CONFIG["port"], + "embedding": self.embeddings}) + else: + self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content", + connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]}) + else: + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], + self.vector_store_config["vector_store_name"] + ".vectordb") + self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) @abstractmethod @register @@ -54,10 +70,18 @@ class SourceEmbedding(ABC): @register def index_to_store(self, docs): """index to vector store""" - persist_dir = os.path.join(self.vector_store_config["vector_store_path"], - self.vector_store_config["vector_store_name"] + ".vectordb") - self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir) - self.vector_store.persist() + + if VECTOR_STORE_TYPE == "chroma": + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], + self.vector_store_config["vector_store_name"] + ".vectordb") + self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir) + self.vector_store.persist() + + elif VECTOR_STORE_TYPE == "milvus": + self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"], + "port": VECTOR_STORE_CONFIG["port"], + "embedding": self.embeddings}) + self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs) @register def similar_search(self, doc, topk): diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index eda0b4e38..6b06dcf00 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,31 +1,35 @@ +from typing import List, Optional, Iterable + from langchain.embeddings import HuggingFaceEmbeddings from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection -from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.vector_store.vector_store_base import VectorStoreBase class MilvusStore(VectorStoreBase): - def __init__(self, cfg: {}) -> None: - """Construct a milvus memory storage connection. + def __init__(self, ctx: {}) -> None: + """init a milvus storage connection. Args: - cfg (Config): MilvusStore global config. + ctx ({}): MilvusStore global config. """ # self.configure(cfg) connect_kwargs = {} self.uri = None - self.uri = cfg["url"] - self.port = cfg["port"] - self.username = cfg.get("username", None) - self.password = cfg.get("password", None) - self.collection_name = cfg["table_name"] - self.password = cfg.get("secure", None) + self.uri = ctx["url"] + self.port = ctx["port"] + self.username = ctx.get("username", None) + self.password = ctx.get("password", None) + self.collection_name = ctx.get("table_name", None) + self.secure = ctx.get("secure", None) + self.model_config = ctx.get("model_config", None) + self.embedding = ctx.get("embedding", None) + self.fields = [] # use HNSW by default. self.index_params = { - "metric_type": "IP", + "metric_type": "L2", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}, } @@ -39,20 +43,144 @@ class MilvusStore(VectorStoreBase): connect_kwargs["password"] = self.password connections.connect( - **connect_kwargs, host=self.uri or "127.0.0.1", port=self.port or "19530", alias="default" # secure=self.secure, ) + if self.collection_name is not None: + self.col = Collection(self.collection_name) + schema = self.col.schema + for x in schema.fields: + self.fields.append(x.name) + if x.auto_id: + self.fields.remove(x.name) + if x.is_primary: + self.primary_field = x.name + if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: + self.vector_field = x.name - self.init_schema() + + # self.init_schema() + # self.init_collection_schema() + + def init_schema_and_load(self, vector_name, documents): + """Create a Milvus collection, indexes it with HNSW, load document. + Args: + documents (List[str]): Text to insert. + vector_name (Embeddings): your collection name. + Returns: + VectorStore: The MilvusStore vector store. + """ + try: + from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + connections, + ) + from pymilvus.orm.types import infer_dtype_bydata + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) + # Connect to Milvus instance + if not connections.has_connection("default"): + connections.connect( + host=self.uri or "127.0.0.1", + port=self.port or "19530", + alias="default" + # secure=self.secure, + ) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + embeddings = self.embedding.embed_query(texts[0]) + dim = len(embeddings) + # Generate unique names + primary_field = "pk_id" + vector_field = "vector" + text_field = "content" + self.text_field = text_field + collection_name = vector_name + fields = [] + # Determine metadata schema + # if metadatas: + # # Check if all metadata keys line up + # key = metadatas[0].keys() + # for x in metadatas: + # if key != x.keys(): + # raise ValueError( + # "Mismatched metadata. " + # "Make sure all metadata has the same keys and datatype." + # ) + # # Create FieldSchema for each entry in singular metadata. + # for key, value in metadatas[0].items(): + # # Infer the corresponding datatype of the metadata + # dtype = infer_dtype_bydata(value) + # if dtype == DataType.UNKNOWN: + # raise ValueError(f"Unrecognized datatype for {key}.") + # elif dtype == DataType.VARCHAR: + # # Find out max length text based metadata + # max_length = 0 + # for subvalues in metadatas: + # max_length = max(max_length, len(subvalues[key])) + # fields.append( + # FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1) + # ) + # else: + # fields.append(FieldSchema(key, dtype)) + + # Find out max length of texts + max_length = 0 + for y in texts: + max_length = max(max_length, len(y)) + # Create the text field + fields.append( + FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) + ) + # Create the primary key field + fields.append( + FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True) + ) + # Create the vector field + fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) + # Create the schema for the collection + schema = CollectionSchema(fields) + # Create the collection + collection = Collection(collection_name, schema) + self.col = collection + # Index parameters for the collection + index = self.index_params + # Create the index + collection.create_index(vector_field, index) + # Create the VectorStore + # milvus = cls( + # embedding, + # kwargs.get("connection_args", {"port": 19530}), + # collection_name, + # text_field, + # ) + # Add the texts. + schema = collection.schema + for x in schema.fields: + self.fields.append(x.name) + if x.auto_id: + self.fields.remove(x.name) + if x.is_primary: + self.primary_field = x.name + if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: + self.vector_field = x.name + self._add_texts(texts, metadatas) + + return self.collection_name def init_schema(self) -> None: """Initialize collection in milvus database.""" fields = [ FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=384), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]), FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535), ] @@ -75,7 +203,7 @@ class MilvusStore(VectorStoreBase): info = self.collection.describe() self.collection.load() - def insert(self, text) -> str: + def insert(self, text, model_config) -> str: """Add an embedding of data into milvus. Args: text (str): The raw text to construct embedding index. @@ -83,10 +211,54 @@ class MilvusStore(VectorStoreBase): str: log. """ # embedding = get_ada_embedding(data) - embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"]) + embeddings = HuggingFaceEmbeddings(model_name=self.model_config["model_name"]) result = self.collection.insert([embeddings.embed_documents(text), text]) _text = ( "Inserting data into memory at primary key: " f"{result.primary_keys[0]}:\n data: {text}" ) - return _text \ No newline at end of file + return _text + + def _add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + partition_name: Optional[str] = None, + timeout: Optional[int] = None, + ) -> List[str]: + """Insert text data into Milvus. + Args: + texts (Iterable[str]): The text being embedded and inserted. + metadatas (Optional[List[dict]], optional): The metadata that + corresponds to each insert. Defaults to None. + partition_name (str, optional): The partition of the collection + to insert data into. Defaults to None. + timeout: specified timeout. + + Returns: + List[str]: The resulting keys for each inserted element. + """ + insert_dict: Any = {self.text_field: list(texts)} + try: + insert_dict[self.vector_field] = self.embedding.embed_documents( + list(texts) + ) + except NotImplementedError: + insert_dict[self.vector_field] = [ + self.embedding.embed_query(x) for x in texts + ] + # Collect the metadata into the insert dict. + if len(self.fields) > 2 and metadatas is not None: + for d in metadatas: + for key, value in d.items(): + if key in self.fields: + insert_dict.setdefault(key, []).append(value) + # Convert dict to list of lists for insertion + insert_list = [insert_dict[x] for x in self.fields] + # Insert into the collection. + res = self.col.insert( + insert_list, partition_name=partition_name, timeout=timeout + ) + # Flush to make sure newly inserted is immediately searchable. + self.col.flush() + return res.primary_keys diff --git a/requirements.txt b/requirements.txt index eac927c3d..ba31d0d04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -60,6 +60,7 @@ gTTS==2.3.1 langchain nltk python-dotenv==1.0.0 +pymilvus # Testing dependencies pytest diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index e9ecad49a..fdc754e05 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -2,8 +2,10 @@ # -*- coding: utf-8 -*- import argparse -from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \ - KNOWLEDGE_UPLOAD_ROOT_PATH +from langchain.embeddings import HuggingFaceEmbeddings +from langchain.vectorstores import Milvus + +from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding @@ -12,15 +14,15 @@ class LocalKnowledgeInit: model_name = LLM_MODEL_CONFIG["text2vec"] top_k: int = VECTOR_SEARCH_TOP_K - def __init__(self) -> None: - pass + def __init__(self, vector_store_config) -> None: + self.vector_store_config = vector_store_config - def knowledge_persist(self, file_path, vector_name, append_mode): + def knowledge_persist(self, file_path, append_mode): """ knowledge persist """ kv = KnowledgeEmbedding( file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], - vector_store_config= {"vector_store_name":vector_name, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + vector_store_config= self.vector_store_config) vector_store = kv.knowledge_persist_initialization(append_mode) return vector_store @@ -34,11 +36,15 @@ class LocalKnowledgeInit: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--vector_name", type=str, default="default") + parser.add_argument("--vector_name", type=str, default="keting") parser.add_argument("--append", type=bool, default=False) + parser.add_argument("--store_type", type=str, default="Chroma") args = parser.parse_args() vector_name = args.vector_name append_mode = args.append - kv = LocalKnowledgeInit() - vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, vector_name=vector_name, append_mode=append_mode) + store_type = args.store_type + vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type} + print(vector_store_config) + kv = LocalKnowledgeInit(vector_store_config=vector_store_config) + vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) print("your knowledge embedding success...") \ No newline at end of file