diff --git a/pilot/configs/config.py b/pilot/configs/config.py index b914390f7..e9ec2bd48 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -109,6 +109,14 @@ class Config(metaclass=Singleton): self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)) self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True" + ### Vector Store Configuration + self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma") + self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1") + self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530") + self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) + self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + + def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 6e32daefc..ebd8513e4 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -47,8 +47,4 @@ ISDEBUG = False VECTOR_SEARCH_TOP_K = 10 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") -KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 -#vector db type, now provided Chroma and Milvus -VECTOR_STORE_TYPE = "Milvus" -#vector db config -VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"} +KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 \ No newline at end of file diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 85db5ab02..cb1fcb504 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -3,6 +3,8 @@ import os from bs4 import BeautifulSoup from langchain.document_loaders import TextLoader, markdown from langchain.embeddings import HuggingFaceEmbeddings + +from pilot.configs.config import Config from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE, VECTOR_STORE_TYPE from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.csv_embedding import CSVEmbedding @@ -13,6 +15,7 @@ import markdown from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader from pilot.vector_store.connector import VectorStoreConnector +CFG = Config() class KnowledgeEmbedding: def __init__(self, file_path, model_name, vector_store_config, local_persist=True): @@ -53,7 +56,7 @@ class KnowledgeEmbedding: def knowledge_persist_initialization(self, append_mode): documents = self._load_knownlege(self.file_path) - self.vector_client = VectorStoreConnector(VECTOR_STORE_TYPE, self.vector_store_config) + self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config) self.vector_client.load_document(documents) return self.vector_client diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index ddefd4f1e..a84282009 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -4,10 +4,12 @@ from abc import ABC, abstractmethod from langchain.embeddings import HuggingFaceEmbeddings from typing import List, Optional, Dict -from pilot.configs.model_config import VECTOR_STORE_TYPE + +from pilot.configs.config import Config from pilot.vector_store.connector import VectorStoreConnector registered_methods = [] +CFG = Config() def register(method): @@ -30,7 +32,7 @@ class SourceEmbedding(ABC): self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) vector_store_config["embeddings"] = self.embeddings - self.vector_client = VectorStoreConnector(VECTOR_STORE_TYPE, vector_store_config) + self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config) @abstractmethod @register diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 1c6d6bdbc..a61027850 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -2,11 +2,12 @@ from typing import List, Optional, Iterable, Tuple, Any from pymilvus import connections, Collection, DataType -from pilot.configs.model_config import VECTOR_STORE_CONFIG from langchain.docstore.document import Document + +from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase - +CFG = Config() class MilvusStore(VectorStoreBase): """Milvus database""" def __init__(self, ctx: {}) -> None: @@ -18,11 +19,10 @@ class MilvusStore(VectorStoreBase): # self.configure(cfg) connect_kwargs = {} - self.uri = None - self.uri = ctx.get("url", VECTOR_STORE_CONFIG["url"]) - self.port = ctx.get("port", VECTOR_STORE_CONFIG["port"]) - self.username = ctx.get("username", None) - self.password = ctx.get("password", None) + self.uri = CFG.MILVUS_URL + self.port = CFG.MILVUS_PORT + self.username = CFG.MILVUS_USERNAME + self.password = CFG.MILVUS_PASSWORD self.collection_name = ctx.get("vector_store_name", None) self.secure = ctx.get("secure", None) self.embedding = ctx.get("embeddings", None) @@ -238,16 +238,6 @@ class MilvusStore(VectorStoreBase): timeout: Optional[int] = None, ) -> List[str]: """add text data into Milvus. - Args: - texts (Iterable[str]): The text being embedded and inserted. - metadatas (Optional[List[dict]], optional): The metadata that - corresponds to each insert. Defaults to None. - partition_name (str, optional): The partition of the collection - to insert data into. Defaults to None. - timeout: specified timeout. - - Returns: - List[str]: The resulting keys for each inserted element. """ insert_dict: Any = {self.text_field: list(texts)} try: @@ -279,6 +269,7 @@ class MilvusStore(VectorStoreBase): self.init_schema_and_load(self.collection_name, documents) def similar_search(self, text, topk) -> None: + """similar_search in vector database.""" self.col = Collection(self.collection_name) schema = self.col.schema for x in schema.fields: @@ -326,7 +317,6 @@ class MilvusStore(VectorStoreBase): timeout=timeout, **kwargs, ) - # Organize results. ret = [] for result in res[0]: meta = {x: result.entity.get(x) for x in output_fields}