diff --git a/.env.template b/.env.template
index d809a362b..3fe762e73 100644
--- a/.env.template
+++ b/.env.template
@@ -81,3 +81,14 @@ DENYLISTED_PLUGINS=
 #*******************************************************************#
 # CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
 # CHAT_MESSAGES_ENABLED=False
+
+
+#*******************************************************************#
+#** VECTOR STORE SETTINGS **#
+#*******************************************************************#
+VECTOR_STORE_TYPE=Chroma
+#MILVUS_URL=127.0.0.1
+#MILVUS_PORT=19530
+#MILVUS_USERNAME
+#MILVUS_PASSWORD
+#MILVUS_SECURE=
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index b914390f7..e9ec2bd48 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -109,6 +109,14 @@ class Config(metaclass=Singleton):
         self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT))
         self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
 
+        ### Vector Store Configuration
+        self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
+        self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
+        self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
+        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
+        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
+
+
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 7c4928304..ebd8513e4 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -47,6 +47,4 @@ ISDEBUG = False
 VECTOR_SEARCH_TOP_K = 10
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
 KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
-KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
-VECTOR_STORE_TYPE = "milvus"
-VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"}
+KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
\ No newline at end of file
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 270eff67f..1ac32ab26 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -19,8 +19,7 @@ from langchain import PromptTemplate
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
-from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \
-    VECTOR_STORE_CONFIG
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@@ -268,13 +267,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
     if mode == conversation_types["custome"] and not db_selector:
-        # persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"])
-        print("vector store type: ", VECTOR_STORE_CONFIG)
         print("vector store name: ", vector_store_name["vs_name"])
-        vector_store_config = VECTOR_STORE_CONFIG
-        vector_store_config["vector_store_name"] = vector_store_name["vs_name"]
-        vector_store_config["text_field"] = "content"
-        vector_store_config["vector_store_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
+        vector_store_config = {"vector_store_name": vector_store_name["vs_name"], "text_field": "content",
+                               "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}
         knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
                                                         local_persist=False,
                                                         vector_store_config=vector_store_config)
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index 63d6c2121..2f313a35a 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -1,9 +1,10 @@
 import os
 
 from bs4 import BeautifulSoup
-from langchain.document_loaders import TextLoader, markdown
+from langchain.document_loaders import TextLoader, markdown, PyPDFLoader
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import Chroma
+
+from pilot.configs.config import Config
 from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 from pilot.source_embedding.csv_embedding import CSVEmbedding
@@ -11,9 +12,9 @@ from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
 import markdown
 
-from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
-from pilot.vector_store.milvus_store import MilvusStore
+from pilot.vector_store.connector import VectorStoreConnector
 
+CFG = Config()
 
 class KnowledgeEmbedding:
     def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
@@ -23,6 +24,7 @@ class KnowledgeEmbedding:
         self.vector_store_config = vector_store_config
         self.file_type = "default"
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
+        self.vector_store_config["embeddings"] = self.embeddings
         self.local_persist = local_persist
         if not self.local_persist:
             self.knowledge_embedding_client = self.init_knowledge_embedding()
@@ -52,35 +54,10 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.similar_search(text, topk)
 
     def knowledge_persist_initialization(self, append_mode):
-        vector_name = self.vector_store_config["vector_store_name"]
         documents = self._load_knownlege(self.file_path)
-        if self.vector_store_config["vector_store_type"] == "Chroma":
-            persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb")
-            print("vector db path: ", persist_dir)
-            if os.path.exists(persist_dir):
-                if append_mode:
-                    print("append knowledge return vector store")
-                    new_documents = self._load_knownlege(self.file_path)
-                    vector_store = Chroma.from_documents(documents=new_documents,
-                                                         embedding=self.embeddings,
-                                                         persist_directory=persist_dir)
-                else:
-                    print("directly return vector store")
-                    vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
-            else:
-                print(vector_name + " is new vector store, knowledge begin load...")
-                vector_store = Chroma.from_documents(documents=documents,
-                                                     embedding=self.embeddings,
-                                                     persist_directory=persist_dir)
-                vector_store.persist()
-
-        elif self.vector_store_config["vector_store_type"] == "milvus":
-            vector_store = MilvusStore({"url": self.vector_store_config["url"],
-                                        "port": self.vector_store_config["port"],
-                                        "embedding": self.embeddings})
-            vector_store.init_schema_and_load(vector_name, documents)
-
-        return vector_store
+        self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config)
+        self.vector_client.load_document(documents)
+        return self.vector_client
 
     def _load_knownlege(self, path):
         docments = []
@@ -111,7 +88,7 @@ class KnowledgeEmbedding:
                 docs[i].page_content = docs[i].page_content.replace("\n", " ")
                 i += 1
         elif filename.lower().endswith(".pdf"):
-            loader = UnstructuredPaddlePDFLoader(filename)
+            loader = PyPDFLoader(filename)
             textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
             docs = loader.load_and_split(textsplitter)
             i = 0
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index a8749695b..75d17c4c6 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -2,12 +2,12 @@
 # -*- coding: utf-8 -*-
 from typing import List
 
+from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document
 
 from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
 from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
-from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
 
 
 class PDFEmbedding(SourceEmbedding):
@@ -23,7 +23,8 @@ class PDFEmbedding(SourceEmbedding):
     @register
     def read(self):
         """Load from pdf path."""
-        loader = UnstructuredPaddlePDFLoader(self.file_path)
+        # loader = UnstructuredPaddlePDFLoader(self.file_path)
+        loader = PyPDFLoader(self.file_path)
         textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
         return loader.load_and_split(textsplitter)
 
diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py
index a253e4d78..a84282009 100644
--- a/pilot/source_embedding/source_embedding.py
+++ b/pilot/source_embedding/source_embedding.py
@@ -1,19 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import os
 from abc import ABC, abstractmethod
 
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import Chroma
-from langchain.vectorstores import Milvus
-
 from typing import List, Optional, Dict
-
-from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG
-from pilot.vector_store.milvus_store import MilvusStore
+from pilot.configs.config import Config
+from pilot.vector_store.connector import VectorStoreConnector
 
 
 registered_methods = []
+CFG = Config()
 
 
 def register(method):
@@ -35,19 +31,8 @@ class SourceEmbedding(ABC):
         self.embedding_args = embedding_args
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
 
-        if VECTOR_STORE_TYPE == "milvus":
-            print(VECTOR_STORE_CONFIG)
-            if self.vector_store_config.get("text_field") is None:
-                self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
-                                                        "port": VECTOR_STORE_CONFIG["port"],
-                                                        "embedding": self.embeddings})
-            else:
-                self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content",
-                                                  connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]})
-        else:
-            persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
-                                       self.vector_store_config["vector_store_name"] + ".vectordb")
-            self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
+        vector_store_config["embeddings"] = self.embeddings
+        self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config)
 
     @abstractmethod
     @register
@@ -70,24 +55,12 @@ class SourceEmbedding(ABC):
     @register
     def index_to_store(self, docs):
         """index to vector store"""
-
-        if VECTOR_STORE_TYPE == "chroma":
-            persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
-                                       self.vector_store_config["vector_store_name"] + ".vectordb")
-            self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
-            self.vector_store.persist()
-
-        elif VECTOR_STORE_TYPE == "milvus":
-            self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
-                                             "port": VECTOR_STORE_CONFIG["port"],
-                                             "embedding": self.embeddings})
-            self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs)
+        self.vector_client.load_document(docs)
 
     @register
     def similar_search(self, doc, topk):
         """vector store similarity_search"""
-
-        return self.vector_store_client.similarity_search(doc, topk)
+        return self.vector_client.similar_search(doc, topk)
 
     def source_embedding(self):
         if 'read' in registered_methods:
diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py
new file mode 100644
index 000000000..9a91659f1
--- /dev/null
+++ b/pilot/vector_store/chroma_store.py
@@ -0,0 +1,30 @@
+import os
+
+from langchain.vectorstores import Chroma
+
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
+from pilot.logs import logger
+from pilot.vector_store.vector_store_base import VectorStoreBase
+
+
+class ChromaStore(VectorStoreBase):
+    """chroma database"""
+
+    def __init__(self, ctx: {}) -> None:
+        self.ctx = ctx
+        self.embeddings = ctx["embeddings"]
+        self.persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH,
+                                        ctx["vector_store_name"] + ".vectordb")
+        self.vector_store_client = Chroma(persist_directory=self.persist_dir, embedding_function=self.embeddings)
+
+    def similar_search(self, text, topk) -> None:
+        logger.info("ChromaStore similar search")
+        return self.vector_store_client.similarity_search(text, topk)
+
+    def load_document(self, documents):
+        logger.info("ChromaStore load document")
+        texts = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
+        self.vector_store_client.persist()
+
diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py
new file mode 100644
index 000000000..003415712
--- /dev/null
+++ b/pilot/vector_store/connector.py
@@ -0,0 +1,22 @@
+from pilot.vector_store.chroma_store import ChromaStore
+from pilot.vector_store.milvus_store import MilvusStore
+
+connector = {
+    "Chroma": ChromaStore,
+    "Milvus": MilvusStore
+    }
+
+
+class VectorStoreConnector:
+    """ vector store connector, can connect different vector db provided load document api and similar search api
+    """
+    def __init__(self, vector_store_type, ctx: {}) -> None:
+        self.ctx = ctx
+        self.connector_class = connector[vector_store_type]
+        self.client = self.connector_class(ctx)
+
+    def load_document(self, docs):
+        self.client.load_document(docs)
+
+    def similar_search(self, docs, topk):
+        return self.client.similar_search(docs, topk)
diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py
index 5204e6b11..a61027850 100644
--- a/pilot/vector_store/milvus_store.py
+++ b/pilot/vector_store/milvus_store.py
@@ -1,12 +1,15 @@
-from typing import List, Optional, Iterable
+from typing import List, Optional, Iterable, Tuple, Any
 
-from langchain.embeddings import HuggingFaceEmbeddings
-from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
+from pymilvus import connections, Collection, DataType
+from langchain.docstore.document import Document
+
+from pilot.configs.config import Config
 
 from pilot.vector_store.vector_store_base import VectorStoreBase
-
+CFG = Config()
 
 
 class MilvusStore(VectorStoreBase):
+    """Milvus database"""
     def __init__(self, ctx: {}) -> None:
         """init a milvus storage connection.
@@ -16,15 +19,13 @@ class MilvusStore(VectorStoreBase):
 
         # self.configure(cfg)
         connect_kwargs = {}
-        self.uri = None
-        self.uri = ctx["url"]
-        self.port = ctx["port"]
-        self.username = ctx.get("username", None)
-        self.password = ctx.get("password", None)
-        self.collection_name = ctx.get("table_name", None)
+        self.uri = CFG.MILVUS_URL
+        self.port = CFG.MILVUS_PORT
+        self.username = CFG.MILVUS_USERNAME
+        self.password = CFG.MILVUS_PASSWORD
+        self.collection_name = ctx.get("vector_store_name", None)
         self.secure = ctx.get("secure", None)
-        self.model_config = ctx.get("model_config", None)
-        self.embedding = ctx.get("embedding", None)
+        self.embedding = ctx.get("embeddings", None)
         self.fields = []
 
         # use HNSW by default.
@@ -33,6 +34,20 @@ class MilvusStore(VectorStoreBase):
             "index_type": "HNSW",
             "params": {"M": 8, "efConstruction": 64},
         }
+        # use HNSW by default.
+        self.index_params_map = {
+            "IVF_FLAT": {"params": {"nprobe": 10}},
+            "IVF_SQ8": {"params": {"nprobe": 10}},
+            "IVF_PQ": {"params": {"nprobe": 10}},
+            "HNSW": {"params": {"ef": 10}},
+            "RHNSW_FLAT": {"params": {"ef": 10}},
+            "RHNSW_SQ": {"params": {"ef": 10}},
+            "RHNSW_PQ": {"params": {"ef": 10}},
+            "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
+            "ANNOY": {"params": {"search_k": 10}},
+        }
+
+        self.text_field = "content"
 
         if (self.username is None) != (self.password is None):
             raise ValueError(
@@ -48,21 +63,6 @@ class MilvusStore(VectorStoreBase):
             alias="default"
             # secure=self.secure,
         )
-        if self.collection_name is not None:
-            self.col = Collection(self.collection_name)
-            schema = self.col.schema
-            for x in schema.fields:
-                self.fields.append(x.name)
-                if x.auto_id:
-                    self.fields.remove(x.name)
-                if x.is_primary:
-                    self.primary_field = x.name
-                if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
-                    self.vector_field = x.name
-
-
-        # self.init_schema()
-        # self.init_collection_schema()
 
     def init_schema_and_load(self, vector_name, documents):
         """Create a Milvus collection, indexes it with HNSW, load document.
@@ -86,7 +86,6 @@ class MilvusStore(VectorStoreBase):
                 "Could not import pymilvus python package. "
                 "Please install it with `pip install pymilvus`."
             )
-        # Connect to Milvus instance
         if not connections.has_connection("default"):
             connections.connect(
                 host=self.uri or "127.0.0.1",
@@ -140,11 +139,11 @@ class MilvusStore(VectorStoreBase):
         fields.append(
             FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
         )
-        # Create the primary key field
+        # create the primary key field
        fields.append(
             FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
         )
-        # Create the vector field
+        # create the vector field
         fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
         # Create the schema for the collection
         schema = CollectionSchema(fields)
@@ -176,32 +175,44 @@ class MilvusStore(VectorStoreBase):
 
         return self.collection_name
 
-    def init_schema(self) -> None:
-        """Initialize collection in milvus database."""
-        fields = [
-            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
-            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
-            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
-        ]
-
-        # create collection if not exist and load it.
-        self.schema = CollectionSchema(fields, "db-gpt memory storage")
-        self.collection = Collection(self.collection_name, self.schema)
-        self.index_params = {
-            "metric_type": "IP",
-            "index_type": "HNSW",
-            "params": {"M": 8, "efConstruction": 64},
-        }
-        # create index if not exist.
-        if not self.collection.has_index():
-            self.collection.release()
-            self.collection.create_index(
-                "vector",
-                self.index_params,
-                index_name="vector",
-            )
-        info = self.collection.describe()
-        self.collection.load()
+    # def init_schema(self) -> None:
+    #     """Initialize collection in milvus database."""
+    #     fields = [
+    #         FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
+    #         FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
+    #         FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
+    #     ]
+    #
+    #     # create collection if not exist and load it.
+    #     self.schema = CollectionSchema(fields, "db-gpt memory storage")
+    #     self.collection = Collection(self.collection_name, self.schema)
+    #     self.index_params_map = {
+    #         "IVF_FLAT": {"params": {"nprobe": 10}},
+    #         "IVF_SQ8": {"params": {"nprobe": 10}},
+    #         "IVF_PQ": {"params": {"nprobe": 10}},
+    #         "HNSW": {"params": {"ef": 10}},
+    #         "RHNSW_FLAT": {"params": {"ef": 10}},
+    #         "RHNSW_SQ": {"params": {"ef": 10}},
+    #         "RHNSW_PQ": {"params": {"ef": 10}},
+    #         "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
+    #         "ANNOY": {"params": {"search_k": 10}},
+    #     }
+    #
+    #     self.index_params = {
+    #         "metric_type": "IP",
+    #         "index_type": "HNSW",
+    #         "params": {"M": 8, "efConstruction": 64},
+    #     }
+    #     # create index if not exist.
+    #     if not self.collection.has_index():
+    #         self.collection.release()
+    #         self.collection.create_index(
+    #             "vector",
+    #             self.index_params,
+    #             index_name="vector",
+    #         )
+    #     info = self.collection.describe()
+    #     self.collection.load()
 
     # def insert(self, text, model_config) -> str:
     #     """Add an embedding of data into milvus.
@@ -226,17 +237,7 @@ class MilvusStore(VectorStoreBase):
         partition_name: Optional[str] = None,
         timeout: Optional[int] = None,
     ) -> List[str]:
-        """Insert text data into Milvus.
-        Args:
-            texts (Iterable[str]): The text being embedded and inserted.
-            metadatas (Optional[List[dict]], optional): The metadata that
-                corresponds to each insert. Defaults to None.
-            partition_name (str, optional): The partition of the collection
-                to insert data into. Defaults to None.
-            timeout: specified timeout.
-
-        Returns:
-            List[str]: The resulting keys for each inserted element.
+        """add text data into Milvus.
         """
         insert_dict: Any = {self.text_field: list(texts)}
         try:
@@ -259,6 +260,72 @@ class MilvusStore(VectorStoreBase):
         res = self.col.insert(
             insert_list, partition_name=partition_name, timeout=timeout
         )
-        # Flush to make sure newly inserted is immediately searchable.
+        # make sure data is searchable.
         self.col.flush()
         return res.primary_keys
+
+    def load_document(self, documents) -> None:
+        """load document in vector database."""
+        self.init_schema_and_load(self.collection_name, documents)
+
+    def similar_search(self, text, topk) -> None:
+        """similar_search in vector database."""
+        self.col = Collection(self.collection_name)
+        schema = self.col.schema
+        for x in schema.fields:
+            self.fields.append(x.name)
+            if x.auto_id:
+                self.fields.remove(x.name)
+            if x.is_primary:
+                self.primary_field = x.name
+            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
+                self.vector_field = x.name
+        _, docs_and_scores = self._search(text, topk)
+        return [doc for doc, _, _ in docs_and_scores]
+
+    def _search(
+        self,
+        query: str,
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        partition_names: Optional[List[str]] = None,
+        round_decimal: int = -1,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
+        self.col.load()
+        # use default index params.
+        if param is None:
+            index_type = self.col.indexes[0].params["index_type"]
+            param = self.index_params_map[index_type]
+        # query text embedding.
+        data = [self.embedding.embed_query(query)]
+        # Determine result metadata fields.
+        output_fields = self.fields[:]
+        output_fields.remove(self.vector_field)
+        # milvus search.
+        res = self.col.search(
+            data,
+            self.vector_field,
+            param,
+            k,
+            expr=expr,
+            output_fields=output_fields,
+            partition_names=partition_names,
+            round_decimal=round_decimal,
+            timeout=timeout,
+            **kwargs,
+        )
+        ret = []
+        for result in res[0]:
+            meta = {x: result.entity.get(x) for x in output_fields}
+            ret.append(
+                (
+                    Document(page_content=meta.pop(self.text_field), metadata=meta),
+                    result.distance,
+                    result.id,
+                )
+            )
+
+        return data[0], ret
diff --git a/pilot/vector_store/vector_store_base.py b/pilot/vector_store/vector_store_base.py
index 818730f0f..b483b3116 100644
--- a/pilot/vector_store/vector_store_base.py
+++ b/pilot/vector_store/vector_store_base.py
@@ -2,8 +2,14 @@ from abc import ABC, abstractmethod
 
 
 class VectorStoreBase(ABC):
+    """base class for vector store database"""
 
     @abstractmethod
-    def init_schema(self) -> None:
+    def load_document(self, documents) -> None:
+        """load document in vector database."""
+        pass
+
+    @abstractmethod
+    def similar_search(self, text, topk) -> None:
         """Initialize schema in vector database."""
         pass
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 29e792451..aea4f00e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,7 +61,7 @@ gTTS==2.3.1
 langchain
 nltk
 python-dotenv==1.0.0
-pymilvus
+pymilvus==2.2.1
 vcrpy
 chromadb
 markdown2
diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py
index 60010e4de..23ca33a80 100644
--- a/tools/knowlege_init.py
+++ b/tools/knowlege_init.py
@@ -2,10 +2,8 @@
 # -*- coding: utf-8 -*-
 import argparse
 
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import Milvus
-
-from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG
+from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG, \
+    VECTOR_STORE_TYPE
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 
 
@@ -42,8 +40,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     vector_name = args.vector_name
    append_mode = args.append
-    store_type = args.store_type
-    vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type}
+    store_type = VECTOR_STORE_TYPE
+    vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
     vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
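Usage sketch (not part of the patch): a minimal illustration of how the new VectorStoreConnector can be driven directly after this change. The embedding model name and the sample document are placeholder assumptions; the ctx keys ("vector_store_name", "embeddings") follow the ChromaStore/MilvusStore constructors in this diff, and MilvusStore additionally reads MILVUS_URL/MILVUS_PORT from Config rather than from ctx.

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector

CFG = Config()

# ctx carries everything a store needs: ChromaStore uses "vector_store_name"
# and "embeddings"; MilvusStore also picks up CFG.MILVUS_URL / CFG.MILVUS_PORT.
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")  # placeholder model name
ctx = {"vector_store_name": "demo_kb", "embeddings": embeddings}

client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, ctx)  # "Chroma" or "Milvus", per VECTOR_STORE_TYPE
client.load_document([Document(page_content="DB-GPT stores knowledge here.", metadata={"source": "demo"})])
print(client.similar_search("knowledge", 1))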
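Design note: because connector.py keys plain class objects by type name, adding another backend is a VectorStoreBase subclass plus one registry entry. The toy in-memory store below is purely hypothetical and only illustrates the load_document/similar_search contract introduced by this patch.

from pilot.vector_store.connector import connector
from pilot.vector_store.vector_store_base import VectorStoreBase


class InMemoryStore(VectorStoreBase):
    """Toy backend used only to illustrate the VectorStoreBase contract."""

    def __init__(self, ctx: {}) -> None:
        self.embeddings = ctx["embeddings"]
        self.docs = []

    def load_document(self, documents) -> None:
        self.docs.extend(documents)

    def similar_search(self, text, topk):
        # naive keyword match stands in for a real vector search
        return [doc for doc in self.docs if text in doc.page_content][:topk]


connector["InMemory"] = InMemoryStore  # then set VECTOR_STORE_TYPE=InMemory in .env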