From e6aa46fc875159e4aeda10bf9d28ca1e0f995c75 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 11 Jul 2023 16:33:48 +0800 Subject: [PATCH] refactor:refactor knowledge api 1.delete CFG in embedding_engine api 2.add a text_splitter param in embedding_engine api --- docs/modules/knowledge.rst | 8 ++-- docs/use_cases/knownledge_based_qa.md | 8 +++- pilot/embedding_engine/csv_embedding.py | 11 ++---- pilot/embedding_engine/embedding_engine.py | 9 ++--- pilot/embedding_engine/knowledge_type.py | 35 +++++++++++++++++- pilot/embedding_engine/markdown_embedding.py | 30 ++++++--------- pilot/embedding_engine/pdf_embedding.py | 37 ++++++------------- pilot/embedding_engine/ppt_embedding.py | 34 ++++++----------- pilot/embedding_engine/source_embedding.py | 18 +++++---- pilot/embedding_engine/string_embedding.py | 7 ++-- pilot/embedding_engine/url_embedding.py | 28 ++++++-------- pilot/embedding_engine/word_embedding.py | 27 +++++--------- pilot/scene/chat_knowledge/custom/chat.py | 4 +- pilot/scene/chat_knowledge/default/chat.py | 3 +- pilot/scene/chat_knowledge/url/chat.py | 3 +- pilot/scene/chat_knowledge/v1/chat.py | 3 +- pilot/server/knowledge/service.py | 4 +- pilot/server/webserver.py | 1 + pilot/summary/db_summary_client.py | 17 ++++++++- pilot/vector_store/chroma_store.py | 3 +- pilot/vector_store/connector.py | 4 +- pilot/vector_store/milvus_store.py | 10 ++--- .../embedding_engine/test_url_embedding.py | 4 ++ tools/knowledge_init.py | 4 +- 24 files changed, 161 insertions(+), 151 deletions(-) diff --git a/docs/modules/knowledge.rst b/docs/modules/knowledge.rst index 313df1512..756486f89 100644 --- a/docs/modules/knowledge.rst +++ b/docs/modules/knowledge.rst @@ -26,7 +26,7 @@ before execution: :: url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html" - embedding_model = "text2vec" + embedding_model = "your_model_path/all-MiniLM-L6-v2" vector_store_config = { "vector_store_name": your_name, } @@ -43,9 +43,11 @@ Document type can be .txt, .pdf, .md, .doc, .ppt. :: document_path = "your_path/test.md" - embedding_model = "text2vec" + embedding_model = "your_model_path/all-MiniLM-L6-v2" vector_store_config = { "vector_store_name": your_name, + "vector_store_type": "Chroma", + "chroma_persist_path": "your_persist_dir", } embedding_engine = EmbeddingEngine( knowledge_source=document_path, @@ -59,7 +61,7 @@ Document type can be .txt, .pdf, .md, .doc, .ppt. :: raw_text = "a long passage" - embedding_model = "text2vec" + embedding_model = "your_model_path/all-MiniLM-L6-v2" vector_store_config = { "vector_store_name": your_name, } diff --git a/docs/use_cases/knownledge_based_qa.md b/docs/use_cases/knownledge_based_qa.md index 2cafed421..6c4f1d56e 100644 --- a/docs/use_cases/knownledge_based_qa.md +++ b/docs/use_cases/knownledge_based_qa.md @@ -32,11 +32,17 @@ Below is an example of using the knowledge base API to query knowledge: ``` vector_store_config = { - "vector_store_name": name + "vector_store_name": your_name, + "vector_store_type": "Chroma", + "chroma_persist_path": "your_persist_dir", } +integrate + query = "your query" +embedding_model = "your_model_path/all-MiniLM-L6-v2" + embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config) embedding_engine.similar_search(query, 10) diff --git a/pilot/embedding_engine/csv_embedding.py b/pilot/embedding_engine/csv_embedding.py index 0e0aa54ec..ad2ca4333 100644 --- a/pilot/embedding_engine/csv_embedding.py +++ b/pilot/embedding_engine/csv_embedding.py @@ -9,17 +9,12 @@ from pilot.embedding_engine import SourceEmbedding, register class CSVEmbedding(SourceEmbedding): """csv embedding for read csv document.""" - def __init__( - self, - file_path, - vector_store_config, - embedding_args: Optional[Dict] = None, - ): + def __init__(self, file_path, vector_store_config, text_splitter=None): """Initialize with csv path.""" - super().__init__(file_path, vector_store_config) + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config - self.embedding_args = embedding_args + self.text_splitter = text_splitter or None @register def read(self): diff --git a/pilot/embedding_engine/embedding_engine.py b/pilot/embedding_engine/embedding_engine.py index 066611732..4d42fc3d0 100644 --- a/pilot/embedding_engine/embedding_engine.py +++ b/pilot/embedding_engine/embedding_engine.py @@ -3,12 +3,9 @@ from typing import Optional from chromadb.errors import NotEnoughElementsException from langchain.embeddings import HuggingFaceEmbeddings -from pilot.configs.config import Config from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType from pilot.vector_store.connector import VectorStoreConnector -CFG = Config() - class EmbeddingEngine: def __init__( @@ -45,7 +42,7 @@ class EmbeddingEngine: def similar_search(self, text, topk): vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) try: ans = vector_client.similar_search(text, topk) @@ -55,12 +52,12 @@ class EmbeddingEngine: def vector_exist(self): vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) return vector_client.vector_name_exists() def delete_by_ids(self, ids): vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) vector_client.delete_by_ids(ids=ids) diff --git a/pilot/embedding_engine/knowledge_type.py b/pilot/embedding_engine/knowledge_type.py index a2b3d563c..5becaa30c 100644 --- a/pilot/embedding_engine/knowledge_type.py +++ b/pilot/embedding_engine/knowledge_type.py @@ -11,6 +11,7 @@ from pilot.embedding_engine.word_embedding import WordEmbedding DocumentEmbeddingType = { ".txt": (MarkdownEmbedding, {}), ".md": (MarkdownEmbedding, {}), + ".html": (MarkdownEmbedding, {}), ".pdf": (PDFEmbedding, {}), ".doc": (WordEmbedding, {}), ".docx": (WordEmbedding, {}), @@ -25,7 +26,18 @@ class KnowledgeType(Enum): URL = "URL" TEXT = "TEXT" OSS = "OSS" + S3 = "S3" NOTION = "NOTION" + MYSQL = "MYSQL" + TIDB = "TIDB" + CLICKHOUSE = "CLICKHOUSE" + OCEANBASE = "OCEANBASE" + ELASTICSEARCH = "ELASTICSEARCH" + HIVE = "HIVE" + PRESTO = "PRESTO" + KAFKA = "KAFKA" + SPARK = "SPARK" + YOUTUBE = "YOUTUBE" def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config): @@ -55,8 +67,29 @@ def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_confi return embedding case KnowledgeType.OSS.value: raise Exception("OSS have not integrate") + case KnowledgeType.S3.value: + raise Exception("S3 have not integrate") case KnowledgeType.NOTION.value: raise Exception("NOTION have not integrate") - + case KnowledgeType.MYSQL.value: + raise Exception("MYSQL have not integrate") + case KnowledgeType.TIDB.value: + raise Exception("TIDB have not integrate") + case KnowledgeType.CLICKHOUSE.value: + raise Exception("CLICKHOUSE have not integrate") + case KnowledgeType.OCEANBASE.value: + raise Exception("OCEANBASE have not integrate") + case KnowledgeType.ELASTICSEARCH.value: + raise Exception("ELASTICSEARCH have not integrate") + case KnowledgeType.HIVE.value: + raise Exception("HIVE have not integrate") + case KnowledgeType.PRESTO.value: + raise Exception("PRESTO have not integrate") + case KnowledgeType.KAFKA.value: + raise Exception("KAFKA have not integrate") + case KnowledgeType.SPARK.value: + raise Exception("SPARK have not integrate") + case KnowledgeType.YOUTUBE.value: + raise Exception("YOUTUBE have not integrate") case _: raise Exception("unknown knowledge type") diff --git a/pilot/embedding_engine/markdown_embedding.py b/pilot/embedding_engine/markdown_embedding.py index e3afe8b13..03969a925 100644 --- a/pilot/embedding_engine/markdown_embedding.py +++ b/pilot/embedding_engine/markdown_embedding.py @@ -12,46 +12,38 @@ from langchain.text_splitter import ( RecursiveCharacterTextSplitter, ) -from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader -CFG = Config() - class MarkdownEmbedding(SourceEmbedding): """markdown embedding for read markdown document.""" - def __init__(self, file_path, vector_store_config): - """Initialize with markdown path.""" - super().__init__(file_path, vector_store_config) + def __init__(self, file_path, vector_store_config, text_splitter=None): + """Initialize raw text word path.""" + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None # self.encoding = encoding @register def read(self): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) - - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, - length_function=len, - ) - else: + if self.text_splitter is None: try: - text_splitter = SpacyTextSplitter( + self.text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_size=100, chunk_overlap=100, ) except Exception: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100, chunk_overlap=50 ) - return loader.load_and_split(text_splitter) + + return loader.load_and_split(self.text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/pdf_embedding.py b/pilot/embedding_engine/pdf_embedding.py index 21b48f41d..5928ade43 100644 --- a/pilot/embedding_engine/pdf_embedding.py +++ b/pilot/embedding_engine/pdf_embedding.py @@ -6,51 +6,36 @@ from langchain.document_loaders import PyPDFLoader from langchain.schema import Document from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter -from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -CFG = Config() - class PDFEmbedding(SourceEmbedding): """pdf embedding for read pdf document.""" - def __init__(self, file_path, vector_store_config): - """Initialize with pdf path.""" - super().__init__(file_path, vector_store_config) + def __init__(self, file_path, vector_store_config, text_splitter=None): + """Initialize pdf word path.""" + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None @register def read(self): """Load from pdf path.""" loader = PyPDFLoader(self.file_path) - # textsplitter = CHNDocumentSplitter( - # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE - # ) - # textsplitter = SpacyTextSplitter( - # pipeline="zh_core_web_sm", - # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - # chunk_overlap=100, - # ) - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, - length_function=len, - ) - else: + if self.text_splitter is None: try: - text_splitter = SpacyTextSplitter( + self.text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_size=100, chunk_overlap=100, ) except Exception: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100, chunk_overlap=50 ) - return loader.load_and_split(text_splitter) + + return loader.load_and_split(self.text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index a181f8d37..7dd6f057f 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -6,48 +6,36 @@ from langchain.document_loaders import UnstructuredPowerPointLoader from langchain.schema import Document from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter -from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -CFG = Config() - class PPTEmbedding(SourceEmbedding): """ppt embedding for read ppt document.""" - def __init__(self, file_path, vector_store_config): - """Initialize with pdf path.""" - super().__init__(file_path, vector_store_config) + def __init__(self, file_path, vector_store_config, text_splitter=None): + """Initialize ppt word path.""" + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None @register def read(self): """Load from ppt path.""" loader = UnstructuredPowerPointLoader(self.file_path) - # textsplitter = SpacyTextSplitter( - # pipeline="zh_core_web_sm", - # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - # chunk_overlap=200, - # ) - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, - length_function=len, - ) - else: + if self.text_splitter is None: try: - text_splitter = SpacyTextSplitter( + self.text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_size=100, chunk_overlap=100, ) except Exception: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100, chunk_overlap=50 ) - return loader.load_and_split(text_splitter) + + return loader.load_and_split(self.text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/source_embedding.py b/pilot/embedding_engine/source_embedding.py index 6d7500007..6eb5b2265 100644 --- a/pilot/embedding_engine/source_embedding.py +++ b/pilot/embedding_engine/source_embedding.py @@ -4,11 +4,11 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional from chromadb.errors import NotEnoughElementsException -from pilot.configs.config import Config +from langchain.text_splitter import TextSplitter + from pilot.vector_store.connector import VectorStoreConnector registered_methods = [] -CFG = Config() def register(method): @@ -25,12 +25,14 @@ class SourceEmbedding(ABC): def __init__( self, file_path, - vector_store_config, + vector_store_config: {}, + text_splitter: TextSplitter = None, embedding_args: Optional[Dict] = None, ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter self.embedding_args = embedding_args self.embeddings = vector_store_config["embeddings"] @@ -44,8 +46,8 @@ class SourceEmbedding(ABC): """pre process data.""" @register - def text_split(self, text): - """text split chunk""" + def text_splitter(self, text_splitter: TextSplitter): + """add text split chunk""" pass @register @@ -57,7 +59,7 @@ class SourceEmbedding(ABC): def index_to_store(self, docs): """index to vector store""" self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) return self.vector_client.load_document(docs) @@ -65,7 +67,7 @@ class SourceEmbedding(ABC): def similar_search(self, doc, topk): """vector store similarity_search""" self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) try: ans = self.vector_client.similar_search(doc, topk) @@ -75,7 +77,7 @@ class SourceEmbedding(ABC): def vector_name_exist(self): self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config + self.vector_store_config["vector_store_type"], self.vector_store_config ) return self.vector_client.vector_name_exists() diff --git a/pilot/embedding_engine/string_embedding.py b/pilot/embedding_engine/string_embedding.py index 5839290fe..f6506c153 100644 --- a/pilot/embedding_engine/string_embedding.py +++ b/pilot/embedding_engine/string_embedding.py @@ -8,11 +8,12 @@ from pilot.embedding_engine import SourceEmbedding, register class StringEmbedding(SourceEmbedding): """string embedding for read string document.""" - def __init__(self, file_path, vector_store_config): - """Initialize with pdf path.""" - super().__init__(file_path, vector_store_config) + def __init__(self, file_path, vector_store_config, text_splitter=None): + """Initialize raw text word path.""" + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None @register def read(self): diff --git a/pilot/embedding_engine/url_embedding.py b/pilot/embedding_engine/url_embedding.py index 8b8976d03..8c79b05c5 100644 --- a/pilot/embedding_engine/url_embedding.py +++ b/pilot/embedding_engine/url_embedding.py @@ -5,43 +5,37 @@ from langchain.document_loaders import WebBaseLoader from langchain.schema import Document from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter -from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -CFG = Config() class URLEmbedding(SourceEmbedding): """url embedding for read url document.""" - def __init__(self, file_path, vector_store_config): - """Initialize with url path.""" - super().__init__(file_path, vector_store_config) + def __init__(self, file_path, vector_store_config, text_splitter=None): + """Initialize url word path.""" + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None @register def read(self): """Load from url path.""" loader = WebBaseLoader(web_path=self.file_path) - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, - length_function=len, - ) - else: + if self.text_splitter is None: try: - text_splitter = SpacyTextSplitter( + self.text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_size=100, chunk_overlap=100, ) except Exception: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100, chunk_overlap=50 ) - return loader.load_and_split(text_splitter) + + return loader.load_and_split(self.text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index 232c9c6f7..d7995bbf3 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -6,43 +6,36 @@ from langchain.document_loaders import UnstructuredWordDocumentLoader from langchain.schema import Document from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter -from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -CFG = Config() - class WordEmbedding(SourceEmbedding): """word embedding for read word document.""" - def __init__(self, file_path, vector_store_config): + def __init__(self, file_path, vector_store_config, text_splitter=None): """Initialize with word path.""" - super().__init__(file_path, vector_store_config) + super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path self.vector_store_config = vector_store_config + self.text_splitter = text_splitter or None @register def read(self): """Load from word path.""" loader = UnstructuredWordDocumentLoader(self.file_path) - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, - length_function=len, - ) - else: + if self.text_splitter is None: try: - text_splitter = SpacyTextSplitter( + self.text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_size=100, chunk_overlap=100, ) except Exception: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=100, chunk_overlap=50 ) - return loader.load_and_split(text_splitter) + + return loader.load_and_split(self.text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 4a887164c..bc121a7c1 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -37,8 +37,8 @@ class ChatNewKnowledge(BaseChat): self.knowledge_name = knowledge_name vector_store_config = { "vector_store_name": knowledge_name, - "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG["text2vec"], diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 3b0c2fa1e..1e45eec95 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -38,7 +38,8 @@ class ChatDefaultKnowledge(BaseChat): ) vector_store_config = { "vector_store_name": "default", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG["text2vec"], diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 8903400a2..21698b8b6 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -38,7 +38,8 @@ class ChatUrlKnowledge(BaseChat): self.url = url vector_store_config = { "vector_store_name": url.replace(":", ""), - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 0bb80d97b..e075634de 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -38,7 +38,8 @@ class ChatKnowledge(BaseChat): ) vector_store_config = { "vector_store_name": knowledge_space, - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 0c13ae1cf..2647f5571 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -2,7 +2,7 @@ import threading from datetime import datetime from pilot.configs.config import Config -from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.logs import logger from pilot.server.knowledge.chunk_db import ( @@ -128,6 +128,8 @@ class KnowledgeService: model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config={ "vector_store_name": space_name, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, }, ) chunk_docs = client.read() diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 6161984bb..ee2723405 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -665,6 +665,7 @@ def knowledge_embedding_store(vs_id, files): model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"], + "vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, }, ) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index b1346a097..710c2101b 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -4,7 +4,7 @@ import uuid from langchain.embeddings import HuggingFaceEmbeddings, logger from pilot.configs.config import Config -from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.scene.base import ChatScene from pilot.scene.base_chat import BaseChat from pilot.embedding_engine.embedding_engine import EmbeddingEngine @@ -33,6 +33,8 @@ class DBSummaryClient: ) vector_store_config = { "vector_store_name": dbname + "_summary", + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "embeddings": embeddings, } embedding = StringEmbedding( @@ -60,6 +62,8 @@ class DBSummaryClient: ) in db_summary_client.get_table_summary().items(): table_vector_store_config = { "vector_store_name": dbname + "_" + table_name + "_ts", + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "embeddings": embeddings, } embedding = StringEmbedding( @@ -73,6 +77,9 @@ class DBSummaryClient: def get_db_summary(self, dbname, query, topk): vector_store_config = { "vector_store_name": dbname + "_profile", + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], @@ -86,6 +93,9 @@ class DBSummaryClient: """get user query related tables info""" vector_store_config = { "vector_store_name": dbname + "_summary", + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], @@ -109,6 +119,9 @@ class DBSummaryClient: for table in related_tables: vector_store_config = { "vector_store_name": dbname + "_" + table + "_ts", + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } knowledge_embedding_client = EmbeddingEngine( file_path="", @@ -128,6 +141,8 @@ class DBSummaryClient: def init_db_profile(self, db_summary_client, dbname, embeddings): profile_store_config = { "vector_store_name": dbname + "_profile", + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + "vector_store_type": CFG.VECTOR_STORE_TYPE, "embeddings": embeddings, } embedding = StringEmbedding( diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index 35016aa09..6dc8eebad 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -1,7 +1,6 @@ import os from langchain.vectorstores import Chroma -from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.logs import logger from pilot.vector_store.vector_store_base import VectorStoreBase @@ -13,7 +12,7 @@ class ChromaStore(VectorStoreBase): self.ctx = ctx self.embeddings = ctx["embeddings"] self.persist_dir = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb" + ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb" ) self.vector_store_client = Chroma( persist_directory=self.persist_dir, embedding_function=self.embeddings diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 8eecb74e0..c8d7d3f6f 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,8 +1,8 @@ from pilot.vector_store.chroma_store import ChromaStore -# from pilot.vector_store.milvus_store import MilvusStore +from pilot.vector_store.milvus_store import MilvusStore -connector = {"Chroma": ChromaStore, "Milvus": None} +connector = {"Chroma": ChromaStore, "Milvus": MilvusStore} class VectorStoreConnector: diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 4535ea30a..1230bede9 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -3,11 +3,9 @@ from typing import Any, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from pymilvus import Collection, DataType, connections, utility -from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase -CFG = Config() class MilvusStore(VectorStoreBase): @@ -22,10 +20,10 @@ class MilvusStore(VectorStoreBase): # self.configure(cfg) connect_kwargs = {} - self.uri = CFG.MILVUS_URL - self.port = CFG.MILVUS_PORT - self.username = CFG.MILVUS_USERNAME - self.password = CFG.MILVUS_PASSWORD + self.uri = ctx.get("milvus_url", None) + self.port = ctx.get("milvus_port", None) + self.username = ctx.get("milvus_username", None) + self.password = ctx.get("milvus_password", None) self.collection_name = ctx.get("vector_store_name", None) self.secure = ctx.get("secure", None) self.embedding = ctx.get("embeddings", None) diff --git a/tests/unit/embedding_engine/test_url_embedding.py b/tests/unit/embedding_engine/test_url_embedding.py index b281e1004..4cd7dcbd8 100644 --- a/tests/unit/embedding_engine/test_url_embedding.py +++ b/tests/unit/embedding_engine/test_url_embedding.py @@ -2,8 +2,12 @@ from pilot import EmbeddingEngine, KnowledgeType url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html" embedding_model = "text2vec" +vector_store_type = "Chroma" +chroma_persist_path = "your_persist_path" vector_store_config = { "vector_store_name": url.replace(":", ""), + "vector_store_type": vector_store_type, + "chroma_persist_path": chroma_persist_path } embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config) diff --git a/tools/knowledge_init.py b/tools/knowledge_init.py index 2f18fcf93..8ccd567c4 100644 --- a/tools/knowledge_init.py +++ b/tools/knowledge_init.py @@ -14,7 +14,7 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.configs.config import Config from pilot.configs.model_config import ( DATASETS_DIR, - LLM_MODEL_CONFIG, + LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH, ) from pilot.embedding_engine.embedding_engine import EmbeddingEngine @@ -68,7 +68,7 @@ if __name__ == "__main__": args = parser.parse_args() vector_name = args.vector_name store_type = CFG.VECTOR_STORE_TYPE - vector_store_config = {"vector_store_name": vector_name} + vector_store_config = {"vector_store_name": vector_name, "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH} print(vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config) kv.knowledge_persist(file_path=DATASETS_DIR)