diff --git a/.env.template b/.env.template
index 3e8ae536b..234b12738 100644
--- a/.env.template
+++ b/.env.template
@@ -28,8 +28,12 @@ MAX_POSITION_EMBEDDINGS=4096
# FAST_LLM_MODEL=chatglm-6b
-### EMBEDDINGS
-## EMBEDDING_MODEL - Model to use for creating embeddings
+#*******************************************************************#
+#** EMBEDDING SETTINGS **#
+#*******************************************************************#
+EMBEDDING_MODEL=text2vec
+KNOWLEDGE_CHUNK_SIZE=500
+KNOWLEDGE_SEARCH_TOP_SIZE=5
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
diff --git a/.gitignore b/.gitignore
index 2f91f3757..82fa7fe62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,6 +141,7 @@ logs
nltk_data
.vectordb
pilot/data/
+pilot/nltk_data
logswebserver.log.*
-.history/*
\ No newline at end of file
+.history/*
diff --git a/README.md b/README.md
index c826619ad..129814fae 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
-# DB-GPT: A LLM Tool for Multi Databases
+# DB-GPT: Revolutionizing Database Interactions with Private LLM Technology
+
-[](https://star-history.com/#csunny/DB-GPT)
-
## What is DB-GPT?
As large models are released and iterated upon, they are becoming increasingly intelligent. However, in the process of using large models, we face significant challenges in data security and privacy. We need to ensure that our sensitive data and environments remain completely controlled and avoid any data privacy leaks or security risks. Based on this, we have launched the DB-GPT project to build a complete private large model solution for all database-based scenarios. This solution supports local deployment, allowing it to be applied not only in independent private environments but also to be independently deployed and isolated according to business modules, ensuring that the ability of large models is absolutely private, secure, and controllable.
@@ -53,7 +52,18 @@ Currently, we have released multiple key features, which are listed below to dem
## Demo
-Run on an RTX 4090 GPU. [YouTube](https://www.youtube.com/watch?v=1PWI6F89LPo)
+Run on an RTX 4090 GPU.
+
+
+
+
+
+
+
+
+
+
+
## Introduction
DB-GPT creates a vast model operating system using [FastChat](https://github.com/lm-sys/FastChat) and offers a large language model powered by [Vicuna](https://huggingface.co/Tribbiani/vicuna-7b). In addition, we provide private domain knowledge base question-answering capability through LangChain. Furthermore, we also provide support for additional plugins, and our design natively supports the Auto-GPT plugin.
@@ -61,7 +71,7 @@ DB-GPT creates a vast model operating system using [FastChat](https://github.com
Is the architecture of the entire DB-GPT shown in the following figure:
-
+
The core capabilities mainly consist of the following parts:
@@ -216,3 +226,5 @@ The MIT License (MIT)
## Contact Information
We are working on building a community, if you have any ideas about building the community, feel free to contact us. [Discord](https://discord.gg/kMFf77FH)
+
+[](https://star-history.com/#csunny/DB-GPT)
diff --git a/README.zh.md b/README.zh.md
index 59f41646f..0c25c1d9d 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -1,4 +1,4 @@
-# DB-GPT: 数据库的 LLM 工具
+# DB-GPT: 用私有化LLM技术定义数据库下一代交互方式
-[](https://star-history.com/#csunny/DB-GPT)
-
## DB-GPT 是什么?
随着大模型的发布迭代,大模型变得越来越智能,在使用大模型的过程当中,遇到极大的数据安全与隐私挑战。在利用大模型能力的过程中我们的私密数据跟环境需要掌握自己的手里,完全可控,避免任何的数据隐私泄露以及安全风险。基于此,我们发起了DB-GPT项目,为所有以数据库为基础的场景,构建一套完整的私有大模型解决方案。 此方案因为支持本地部署,所以不仅仅可以应用于独立私有环境,而且还可以根据业务模块独立部署隔离,让大模型的能力绝对私有、安全、可控。
@@ -51,7 +49,22 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
## 效果演示
-示例通过 RTX 4090 GPU 演示,[YouTube 地址](https://www.youtube.com/watch?v=1PWI6F89LPo)
+示例通过 RTX 4090 GPU 演示
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
## 架构方案
DB-GPT基于 [FastChat](https://github.com/lm-sys/FastChat) 构建大模型运行环境,并提供 vicuna 作为基础的大语言模型。此外,我们通过LangChain提供私域知识库问答能力。同时我们支持插件模式, 在设计上原生支持Auto-GPT插件。
@@ -220,3 +233,6 @@ Run the Python interpreter and type the commands:
## Licence
The MIT License (MIT)
+
+[](https://star-history.com/#csunny/DB-GPT)
+
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 3762b43c1..c4458eaf7 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+ self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
+ self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 4b8b85a62..759245864 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -34,7 +34,6 @@ LLM_MODEL_CONFIG = {
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
- "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"proxyllm": "proxyllm",
}
diff --git a/pilot/conversation.py b/pilot/conversation.py
index 40759ffc8..2fab65dba 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -295,7 +295,7 @@ default_conversation = conv_default
chat_mode_title = {
- "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
+ "sql_generate_diagnostics": get_lang_text("sql_generate_diagnostics"),
"chat_use_plugin": get_lang_text("chat_use_plugin"),
"knowledge_qa": get_lang_text("knowledge_qa"),
}
diff --git a/pilot/data/__init__.py b/pilot/data/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pilot/datasets/mysql/url.md b/pilot/datasets/mysql/url.md
index e69de29bb..20592cb72 100644
--- a/pilot/datasets/mysql/url.md
+++ b/pilot/datasets/mysql/url.md
@@ -0,0 +1 @@
+LlamaIndex是一个数据框架,旨在帮助您构建LLM应用程序。它包括一个向量存储索引和一个简单的目录阅读器,可以帮助您处理和操作数据。此外,LlamaIndex还提供了一个GPT Index,可以用于数据增强和生成更好的LM模型。
\ No newline at end of file
diff --git a/pilot/language/lang_content_mapping.py b/pilot/language/lang_content_mapping.py
index 86aa3fa3c..afcfaeaba 100644
--- a/pilot/language/lang_content_mapping.py
+++ b/pilot/language/lang_content_mapping.py
@@ -44,13 +44,13 @@ lang_dicts = {
"learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B",
"model_control_param": "Model Parameters",
"sql_generate_mode_direct": "Execute directly",
- "sql_generate_mode_none": "Execute without model",
+ "sql_generate_mode_none": "Execute without mode",
"max_input_token_size": "Maximum output token size",
"please_choose_database": "Please choose database",
"sql_generate_diagnostics": "SQL Generation & Diagnostics",
"knowledge_qa_type_llm_native_dialogue": "LLM native dialogue",
"knowledge_qa_type_default_knowledge_base_dialogue": "Default documents",
- "knowledge_qa_type_add_knowledge_base_dialogue": "Added documents",
+ "knowledge_qa_type_add_knowledge_base_dialogue": "New documents",
"knowledge_qa_type_url_knowledge_dialogue": "Chat with url",
"dialogue_use_plugin": "Dialogue Extension",
"create_knowledge_base": "Create Knowledge Base",
@@ -60,7 +60,7 @@ lang_dicts = {
"sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn",
"chat_use_plugin": "Plugin Mode",
"select_plugin": "Select Plugin",
- "knowledge_qa": "Documents QA",
+ "knowledge_qa": "Documents Chat",
"configure_knowledge_base": "Configure Documents",
"url_input_label": "Please input url",
"new_klg_name": "New document name",
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 64d3617bf..05c55fa74 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -97,6 +97,20 @@ class GuanacoAdapter(BaseLLMAdaper):
return model, tokenizer
+class GuanacoAdapter(BaseLLMAdaper):
+ """TODO Support guanaco"""
+
+ def match(self, model_path: str):
+ return "guanaco" in model_path
+
+ def loader(self, model_path: str, from_pretrained_kwargs: dict):
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+ )
+ return model, tokenizer
+
+
class CodeGenAdapter(BaseLLMAdaper):
pass
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 1f4597789..73c732713 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat):
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
+ client = DBSummaryClient()
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
- # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+ # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py
index 66b751533..cb2425ea9 100644
--- a/pilot/scene/chat_db/professional_qa/chat.py
+++ b/pilot/scene/chat_db/professional_qa/chat.py
@@ -35,7 +35,13 @@ class ChatWithDbQA(BaseChat):
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
- self.top_k: int = 5
+ self.tables = self.database.get_table_names()
+
+ self.top_k = (
+ CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+ if len(self.tables) > CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+ else len(self.tables)
+ )
def generate_input_values(self):
table_info = ""
@@ -45,7 +51,8 @@ class ChatWithDbQA(BaseChat):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
if self.db_name:
- table_info = DBSummaryClient.get_similar_tables(
+ client = DBSummaryClient()
+ table_info = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
# table_info = self.database.table_simple_info(self.db_connect)
diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py
index 7600bab79..a56b2a098 100644
--- a/pilot/scene/chat_knowledge/custom/chat.py
+++ b/pilot/scene/chat_knowledge/custom/chat.py
@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
- VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_knowledge.custom.prompt import prompt
@@ -46,15 +45,13 @@ class ChatNewKnowledge(BaseChat):
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
- file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
- local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(
- self.current_user_input, VECTOR_SEARCH_TOP_K
+ self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
context = context[:2000]
diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py
index 110250221..4892e28cd 100644
--- a/pilot/scene/chat_knowledge/custom/prompt.py
+++ b/pilot/scene/chat_knowledge/custom/prompt.py
@@ -14,13 +14,23 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+ known information:
+ {context}
+ question:
+ {question}
+"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
PROMPT_SEP = SeparatorStyle.SINGLE.value
diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py
index 1a482b154..325b03783 100644
--- a/pilot/scene/chat_knowledge/default/chat.py
+++ b/pilot/scene/chat_knowledge/default/chat.py
@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
- VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_knowledge.default.prompt import prompt
@@ -42,15 +41,13 @@ class ChatDefaultKnowledge(BaseChat):
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
- file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
- local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(
- self.current_user_input, VECTOR_SEARCH_TOP_K
+ self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
context = context[:2000]
diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py
index 0526be69b..0fd9f9ff3 100644
--- a/pilot/scene/chat_knowledge/default/prompt.py
+++ b/pilot/scene/chat_knowledge/default/prompt.py
@@ -15,13 +15,23 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+ known information:
+ {context}
+ question:
+ {question}
+"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
PROMPT_SEP = SeparatorStyle.SINGLE.value
diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py
index cc8d89d4a..88dc7ad0b 100644
--- a/pilot/scene/chat_knowledge/url/chat.py
+++ b/pilot/scene/chat_knowledge/url/chat.py
@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
- VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_knowledge.url.prompt import prompt
@@ -40,15 +39,13 @@ class ChatUrlKnowledge(BaseChat):
self.url = url
vector_store_config = {
"vector_store_name": url,
- "text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
- file_path=url,
- file_type="url",
model_name=LLM_MODEL_CONFIG["text2vec"],
- local_persist=False,
vector_store_config=vector_store_config,
+ file_type="url",
+ file_path=url,
)
# url soruce in vector
@@ -58,7 +55,7 @@ class ChatUrlKnowledge(BaseChat):
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(
- self.current_user_input, VECTOR_SEARCH_TOP_K
+ self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
context = context[:2000]
diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py
index 38d5dfe35..3e9659130 100644
--- a/pilot/scene/chat_knowledge/url/prompt.py
+++ b/pilot/scene/chat_knowledge/url/prompt.py
@@ -14,20 +14,23 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
-
-# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。
-# known information:
-# {context}
-# question:
-# {question}
-# """
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+ known information:
+ {context}
+ question:
+ {question}
+"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
PROMPT_SEP = SeparatorStyle.SINGLE.value
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 4dec22655..d4ab8ae09 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -110,6 +110,7 @@ register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
+
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
index 9faae5eb8..2a09e6a98 100644
--- a/pilot/server/vectordb_qa.py
+++ b/pilot/server/vectordb_qa.py
@@ -3,12 +3,14 @@
from langchain.prompts import PromptTemplate
-from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.configs.config import Config
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
from pilot.logs import logger
from pilot.model.llm_out.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
+CFG = Config()
+
class KnownLedgeBaseQA:
def __init__(self) -> None:
@@ -22,7 +24,7 @@ class KnownLedgeBaseQA:
)
retriever = self.vector_store.as_retriever(
- search_kwargs={"k": VECTOR_SEARCH_TOP_K}
+ search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
)
docs = retriever.get_relevant_documents(query=query)
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index f7655fd7d..270f3b681 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import signal
+import threading
import traceback
import argparse
import datetime
@@ -205,7 +207,13 @@ def post_process_code(code):
def get_chat_mode(selected, param=None) -> ChatScene:
if chat_mode_title["chat_use_plugin"] == selected:
return ChatScene.ChatExecution
- elif chat_mode_title["knowledge_qa"] == selected:
+ elif chat_mode_title["sql_generate_diagnostics"] == selected:
+ sql_mode = param
+ if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+ return ChatScene.ChatWithDbExecute
+ else:
+ return ChatScene.ChatWithDbQA
+ else:
mode = param
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
@@ -215,12 +223,6 @@ def get_chat_mode(selected, param=None) -> ChatScene:
return ChatScene.ChatUrlKnowledge
else:
return ChatScene.ChatNormal
- else:
- sql_mode = param
- if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
- return ChatScene.ChatWithDbExecute
- else:
- return ChatScene.ChatWithDbQA
def chatbot_callback(state, message):
@@ -244,12 +246,13 @@ def http_bot(
logger.info(
f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
)
- if chat_mode_title["knowledge_qa"] == selected:
- scene: ChatScene = get_chat_mode(selected, mode)
+ if chat_mode_title["sql_generate_diagnostics"] == selected:
+ scene: ChatScene = get_chat_mode(selected, sql_mode)
elif chat_mode_title["chat_use_plugin"] == selected:
scene: ChatScene = get_chat_mode(selected)
else:
- scene: ChatScene = get_chat_mode(selected, sql_mode)
+ scene: ChatScene = get_chat_mode(selected, mode)
+
print(f"chat scene:{scene.value}")
if ChatScene.ChatWithDbExecute == scene:
@@ -402,58 +405,6 @@ def build_single_model_ui():
tabs.select(on_select, None, selected)
with tabs:
- tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
- with tab_sql:
- # TODO A selector to choose database
- with gr.Row(elem_id="db_selector"):
- db_selector = gr.Dropdown(
- label=get_lang_text("please_choose_database"),
- choices=dbs,
- value=dbs[0] if len(models) > 0 else "",
- interactive=True,
- show_label=True,
- ).style(container=False)
-
- db_selector.change(fn=db_selector_changed, inputs=db_selector)
-
- sql_mode = gr.Radio(
- [
- get_lang_text("sql_generate_mode_direct"),
- get_lang_text("sql_generate_mode_none"),
- ],
- show_label=False,
- value=get_lang_text("sql_generate_mode_none"),
- )
- sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
- sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
-
- tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
- # tab_plugin.select(change_func)
- with tab_plugin:
- print("tab_plugin in...")
- with gr.Row(elem_id="plugin_selector"):
- # TODO
- plugin_selector = gr.Dropdown(
- label=get_lang_text("select_plugin"),
- choices=list(plugins_select_info().keys()),
- value="",
- interactive=True,
- show_label=True,
- type="value",
- ).style(container=False)
-
- def plugin_change(
- evt: gr.SelectData,
- ): # SelectData is a subclass of EventData
- print(f"You selected {evt.value} at {evt.index} from {evt.target}")
- print(f"user plugin:{plugins_select_info().get(evt.value)}")
- return plugins_select_info().get(evt.value)
-
- plugin_selected = gr.Textbox(
- show_label=False, visible=False, placeholder="Selected"
- )
- plugin_selector.select(plugin_change, None, plugin_selected)
-
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
mode = gr.Radio(
@@ -516,6 +467,58 @@ def build_single_model_ui():
get_lang_text("upload_and_load_to_klg")
)
+ tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
+ with tab_sql:
+ # TODO A selector to choose database
+ with gr.Row(elem_id="db_selector"):
+ db_selector = gr.Dropdown(
+ label=get_lang_text("please_choose_database"),
+ choices=dbs,
+ value=dbs[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=True,
+ ).style(container=False)
+
+ # db_selector.change(fn=db_selector_changed, inputs=db_selector)
+
+ sql_mode = gr.Radio(
+ [
+ get_lang_text("sql_generate_mode_direct"),
+ get_lang_text("sql_generate_mode_none"),
+ ],
+ show_label=False,
+ value=get_lang_text("sql_generate_mode_none"),
+ )
+ sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
+ sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
+
+ tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
+ # tab_plugin.select(change_func)
+ with tab_plugin:
+ print("tab_plugin in...")
+ with gr.Row(elem_id="plugin_selector"):
+ # TODO
+ plugin_selector = gr.Dropdown(
+ label=get_lang_text("select_plugin"),
+ choices=list(plugins_select_info().keys()),
+ value="",
+ interactive=True,
+ show_label=True,
+ type="value",
+ ).style(container=False)
+
+ def plugin_change(
+ evt: gr.SelectData,
+ ): # SelectData is a subclass of EventData
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+ print(f"user plugin:{plugins_select_info().get(evt.value)}")
+ return plugins_select_info().get(evt.value)
+
+ plugin_selected = gr.Textbox(
+ show_label=False, visible=False, placeholder="Selected"
+ )
+ plugin_selector.select(plugin_change, None, plugin_selected)
+
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
@@ -618,10 +621,6 @@ def save_vs_name(vs_name):
return vs_name
-def db_selector_changed(dbname):
- DBSummaryClient.db_summary_embedding(dbname)
-
-
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
@@ -634,7 +633,6 @@ def knowledge_embedding_store(vs_id, files):
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["text2vec"],
- local_persist=False,
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -646,6 +644,17 @@ def knowledge_embedding_store(vs_id, files):
return vs_id
+def async_db_summery():
+ client = DBSummaryClient()
+ thread = threading.Thread(target=client.init_db_summary)
+ thread.start()
+
+
+def signal_handler(sig, frame):
+ print("in order to avoid chroma db atexit problem")
+ os._exit(0)
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -662,7 +671,8 @@ if __name__ == "__main__":
cfg = Config()
dbs = cfg.local_db.get_database_list()
-
+ signal.signal(signal.SIGINT, signal_handler)
+ async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令
diff --git a/pilot/source_embedding/EncodeTextLoader.py b/pilot/source_embedding/EncodeTextLoader.py
new file mode 100644
index 000000000..2b7344f18
--- /dev/null
+++ b/pilot/source_embedding/EncodeTextLoader.py
@@ -0,0 +1,26 @@
+from typing import List, Optional
+import chardet
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.base import BaseLoader
+
+
+class EncodeTextLoader(BaseLoader):
+ """Load text files."""
+
+ def __init__(self, file_path: str, encoding: Optional[str] = None):
+ """Initialize with file path."""
+ self.file_path = file_path
+ self.encoding = encoding
+
+ def load(self) -> List[Document]:
+ """Load from file path."""
+ with open(self.file_path, "rb") as f:
+ raw_text = f.read()
+ result = chardet.detect(raw_text)
+ if result["encoding"] is None:
+ text = raw_text.decode("utf-8")
+ else:
+ text = raw_text.decode(result["encoding"])
+ metadata = {"source": self.file_path}
+ return [Document(page_content=text, metadata=metadata)]
diff --git a/pilot/source_embedding/csv_embedding.py b/pilot/source_embedding/csv_embedding.py
index 8b2e25ff3..0e69574b4 100644
--- a/pilot/source_embedding/csv_embedding.py
+++ b/pilot/source_embedding/csv_embedding.py
@@ -12,14 +12,12 @@ class CSVEmbedding(SourceEmbedding):
def __init__(
self,
file_path,
- model_name,
vector_store_config,
embedding_args: Optional[Dict] = None,
):
"""Initialize with csv path."""
- super().__init__(file_path, model_name, vector_store_config)
+ super().__init__(file_path, vector_store_config)
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index f58742ee9..7ec0de76c 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -1,30 +1,34 @@
-import os
+from typing import Optional
-import markdown
-from bs4 import BeautifulSoup
-from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
-from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
-from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
+from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
+KnowledgeEmbeddingType = {
+ ".txt": (MarkdownEmbedding, {}),
+ ".md": (MarkdownEmbedding, {}),
+ ".pdf": (PDFEmbedding, {}),
+ ".doc": (WordEmbedding, {}),
+ ".docx": (WordEmbedding, {}),
+ ".csv": (CSVEmbedding, {}),
+}
+
class KnowledgeEmbedding:
def __init__(
self,
- file_path,
model_name,
vector_store_config,
- local_persist=True,
- file_type="default",
+ file_type: Optional[str] = "default",
+ file_path: Optional[str] = None,
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
@@ -33,11 +37,9 @@ class KnowledgeEmbedding:
self.file_type = file_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
- self.local_persist = local_persist
- if not self.local_persist:
- self.knowledge_embedding_client = self.init_knowledge_embedding()
def knowledge_embedding(self):
+ self.knowledge_embedding_client = self.init_knowledge_embedding()
self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self):
@@ -47,98 +49,29 @@ class KnowledgeEmbedding:
if self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
- model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
- elif self.file_path.endswith(".pdf"):
- embedding = PDFEmbedding(
- file_path=self.file_path,
- model_name=self.model_name,
+ return embedding
+ extension = "." + self.file_path.rsplit(".", 1)[-1]
+ if extension in KnowledgeEmbeddingType:
+ knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+ embedding = knowledge_class(
+ self.file_path,
vector_store_config=self.vector_store_config,
+ **knowledge_args,
)
- elif self.file_path.endswith(".md"):
- embedding = MarkdownEmbedding(
- file_path=self.file_path,
- model_name=self.model_name,
- vector_store_config=self.vector_store_config,
- )
-
- elif self.file_path.endswith(".csv"):
- embedding = CSVEmbedding(
- file_path=self.file_path,
- model_name=self.model_name,
- vector_store_config=self.vector_store_config,
- )
-
- elif self.file_type == "default":
- embedding = MarkdownEmbedding(
- file_path=self.file_path,
- model_name=self.model_name,
- vector_store_config=self.vector_store_config,
- )
-
+ return embedding
+ raise ValueError(f"Unsupported knowledge file type '{extension}'")
return embedding
def similar_search(self, text, topk):
- return self.knowledge_embedding_client.similar_search(text, topk)
-
- def vector_exist(self):
- return self.knowledge_embedding_client.vector_name_exist()
-
- def knowledge_persist_initialization(self, append_mode):
- documents = self._load_knownlege(self.file_path)
- self.vector_client = VectorStoreConnector(
+ vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
- self.vector_client.load_document(documents)
- return self.vector_client
+ return vector_client.similar_search(text, topk)
- def _load_knownlege(self, path):
- docments = []
- for root, _, files in os.walk(path, topdown=False):
- for file in files:
- filename = os.path.join(root, file)
- docs = self._load_file(filename)
- new_docs = []
- for doc in docs:
- doc.metadata = {
- "source": doc.metadata["source"].replace(DATASETS_DIR, "")
- }
- print("doc is embedding...", doc.metadata)
- new_docs.append(doc)
- docments += new_docs
- return docments
-
- def _load_file(self, filename):
- if filename.lower().endswith(".md"):
- loader = TextLoader(filename)
- text_splitter = CHNDocumentSplitter(
- pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
- )
- docs = loader.load_and_split(text_splitter)
- i = 0
- for d in docs:
- content = markdown.markdown(d.page_content)
- soup = BeautifulSoup(content, "html.parser")
- for tag in soup(["!doctype", "meta", "i.fa"]):
- tag.extract()
- docs[i].page_content = soup.get_text()
- docs[i].page_content = docs[i].page_content.replace("\n", " ")
- i += 1
- elif filename.lower().endswith(".pdf"):
- loader = PyPDFLoader(filename)
- textsplitter = CHNDocumentSplitter(
- pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
- )
- docs = loader.load_and_split(textsplitter)
- i = 0
- for d in docs:
- docs[i].page_content = d.page_content.replace("\n", " ").replace(
- "�", ""
- )
- i += 1
- else:
- loader = TextLoader(filename)
- text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
- docs = loader.load_and_split(text_splitor)
- return docs
+ def vector_exist(self):
+ vector_client = VectorStoreConnector(
+ CFG.VECTOR_STORE_TYPE, self.vector_store_config
+ )
+ return vector_client.vector_name_exists()
diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py
index 3db6cdbf5..e2851d122 100644
--- a/pilot/source_embedding/markdown_embedding.py
+++ b/pilot/source_embedding/markdown_embedding.py
@@ -8,27 +8,30 @@ from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader
from langchain.schema import Document
-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+CFG = Config()
+
class MarkdownEmbedding(SourceEmbedding):
"""markdown embedding for read markdown document."""
- def __init__(self, file_path, model_name, vector_store_config):
+ def __init__(self, file_path, vector_store_config):
"""Initialize with markdown path."""
- super().__init__(file_path, model_name, vector_store_config)
+ super().__init__(file_path, vector_store_config)
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
+ # self.encoding = encoding
@register
def read(self):
"""Load from markdown path."""
- loader = TextLoader(self.file_path)
+ loader = EncodeTextLoader(self.file_path)
text_splitter = CHNDocumentSplitter(
- pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+ pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
)
return loader.load_and_split(text_splitter)
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index de1767c51..54f77fb81 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -5,19 +5,20 @@ from typing import List
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+CFG = Config()
+
class PDFEmbedding(SourceEmbedding):
"""pdf embedding for read pdf document."""
- def __init__(self, file_path, model_name, vector_store_config):
+ def __init__(self, file_path, vector_store_config):
"""Initialize with pdf path."""
- super().__init__(file_path, model_name, vector_store_config)
+ super().__init__(file_path, vector_store_config)
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
@register
@@ -26,7 +27,7 @@ class PDFEmbedding(SourceEmbedding):
# loader = UnstructuredPaddlePDFLoader(self.file_path)
loader = PyPDFLoader(self.file_path)
textsplitter = CHNDocumentSplitter(
- pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+ pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
)
return loader.load_and_split(textsplitter)
diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py
index 7db92ea9b..50c7044f9 100644
--- a/pilot/source_embedding/source_embedding.py
+++ b/pilot/source_embedding/source_embedding.py
@@ -23,13 +23,11 @@ class SourceEmbedding(ABC):
def __init__(
self,
file_path,
- model_name,
vector_store_config,
embedding_args: Optional[Dict] = None,
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
self.embeddings = vector_store_config["embeddings"]
diff --git a/pilot/source_embedding/string_embedding.py b/pilot/source_embedding/string_embedding.py
index b4d7b1228..a1d18ee82 100644
--- a/pilot/source_embedding/string_embedding.py
+++ b/pilot/source_embedding/string_embedding.py
@@ -8,11 +8,10 @@ from pilot import SourceEmbedding, register
class StringEmbedding(SourceEmbedding):
"""string embedding for read string document."""
- def __init__(self, file_path, model_name, vector_store_config):
+ def __init__(self, file_path, vector_store_config):
"""Initialize with pdf path."""
- super().__init__(file_path, model_name, vector_store_config)
+ super().__init__(file_path, vector_store_config)
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
@register
diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py
index 39224a9f4..a315e6e45 100644
--- a/pilot/source_embedding/url_embedding.py
+++ b/pilot/source_embedding/url_embedding.py
@@ -16,11 +16,10 @@ CFG = Config()
class URLEmbedding(SourceEmbedding):
"""url embedding for read url document."""
- def __init__(self, file_path, model_name, vector_store_config):
+ def __init__(self, file_path, vector_store_config):
"""Initialize with url path."""
- super().__init__(file_path, model_name, vector_store_config)
+ super().__init__(file_path, vector_store_config)
self.file_path = file_path
- self.model_name = model_name
self.vector_store_config = vector_store_config
@register
@@ -29,7 +28,7 @@ class URLEmbedding(SourceEmbedding):
loader = WebBaseLoader(web_path=self.file_path)
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
- chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+ chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
diff --git a/pilot/source_embedding/word_embedding.py b/pilot/source_embedding/word_embedding.py
new file mode 100644
index 000000000..1f30f241c
--- /dev/null
+++ b/pilot/source_embedding/word_embedding.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.schema import Document
+
+from pilot.configs.config import Config
+from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+
+CFG = Config()
+
+
+class WordEmbedding(SourceEmbedding):
+ """word embedding for read word document."""
+
+ def __init__(self, file_path, vector_store_config):
+ """Initialize with word path."""
+ super().__init__(file_path, vector_store_config)
+ self.file_path = file_path
+ self.vector_store_config = vector_store_config
+
+ @register
+ def read(self):
+ """Load from word path."""
+ loader = UnstructuredWordDocumentLoader(self.file_path)
+ textsplitter = CHNDocumentSplitter(
+ pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+ )
+ return loader.load_and_split(textsplitter)
+
+ @register
+ def data_process(self, documents: List[Document]):
+ i = 0
+ for d in documents:
+ documents[i].page_content = d.page_content.replace("\n", "")
+ i += 1
+ return documents
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 91805ddd4..51f124f62 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -21,8 +21,10 @@ class DBSummaryClient:
, get_similar_tables method(get user query related tables info)
"""
- @staticmethod
- def db_summary_embedding(dbname):
+ def __init__(self):
+ pass
+
+ def db_summary_embedding(self, dbname):
"""put db profile and table profile summary into vector store"""
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
db_summary_client = MysqlSummary(dbname)
@@ -34,24 +36,21 @@ class DBSummaryClient:
"embeddings": embeddings,
}
embedding = StringEmbedding(
- db_summary_client.get_summery(),
- LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
- vector_store_config,
+ file_path=db_summary_client.get_summery(),
+ vector_store_config=vector_store_config,
)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summery():
embedding = StringEmbedding(
vector_table_info,
- LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config,
)
embedding.source_embedding()
else:
embedding = StringEmbedding(
- db_summary_client.get_summery(),
- LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
- vector_store_config,
+ file_path=db_summary_client.get_summery(),
+ vector_store_config=vector_store_config,
)
embedding.source_embedding()
for (
@@ -59,32 +58,24 @@ class DBSummaryClient:
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
- "vector_store_name": table_name + "_ts",
+ "vector_store_name": dbname + "_" + table_name + "_ts",
"embeddings": embeddings,
}
embedding = StringEmbedding(
table_summary,
- LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
table_vector_store_config,
)
embedding.source_embedding()
logger.info("db summary embedding success")
- @staticmethod
- def get_similar_tables(dbname, query, topk):
+ def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
- embeddings = HuggingFaceEmbeddings(
- model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
- )
vector_store_config = {
"vector_store_name": dbname + "_profile",
- "embeddings": embeddings,
}
knowledge_embedding_client = KnowledgeEmbedding(
- file_path="",
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
- local_persist=False,
vector_store_config=vector_store_config,
)
if CFG.SUMMARY_CONFIG == "FAST":
@@ -104,19 +95,23 @@ class DBSummaryClient:
related_table_summaries = []
for table in related_tables:
vector_store_config = {
- "vector_store_name": table + "_ts",
- "embeddings": embeddings,
+ "vector_store_name": dbname + "_" + table + "_ts",
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
- local_persist=False,
vector_store_config=vector_store_config,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summery[0].page_content)
return related_table_summaries
+ def init_db_summary(self):
+ db = CFG.local_db
+ dbs = db.get_database_list()
+ for dbname in dbs:
+ self.db_summary_embedding(dbname)
+
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
@@ -132,30 +127,3 @@ def _get_llm_response(query, db_input, dbsummary):
)
res = chat.nostream_call()
return json.loads(res)["table"]
-
-
-# if __name__ == "__main__":
-# # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
-#
-# text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
-# Response:
-#
-# {
-# "table": ["chat_users"]
-# }"""
-# text = text.rstrip().replace("\n","")
-# start = text.find("{")
-# end = text.find("}") + 1
-#
-# # 从字符串中截取出JSON数据
-# json_str = text[start:end]
-#
-# # 将JSON数据转换为Python中的字典类型
-# data = json.loads(json_str)
-# # pattern = r'{s*"table"s*:s*[[^]]*]s*}'
-# # match = re.search(pattern, text)
-# # if match:
-# # json_string = match.group(0)
-# # # 将JSON字符串转换为Python对象
-# # json_obj = json.loads(json_string)
-# # print(summary)
diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py
index 3a9de6874..4949924d4 100644
--- a/pilot/vector_store/chroma_store.py
+++ b/pilot/vector_store/chroma_store.py
@@ -1,7 +1,6 @@
import os
from langchain.vectorstores import Chroma
-
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
index c42eda7a6..cca027324 100644
--- a/pilot/vector_store/file_loader.py
+++ b/pilot/vector_store/file_loader.py
@@ -17,7 +17,6 @@ from langchain.vectorstores import Chroma
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
- VECTOR_SEARCH_TOP_K,
VECTORE_PATH,
)
@@ -41,7 +40,6 @@ class KnownLedge2Vector:
embeddings: object = None
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
- top_k: int = VECTOR_SEARCH_TOP_K
def __init__(self, model_name=None) -> None:
if not model_name:
diff --git a/run.sh b/run.sh
index a8948ad0f..81d3ec22b 100644
--- a/run.sh
+++ b/run.sh
@@ -22,4 +22,4 @@ while [ `grep -c "Uvicorn running on" /root/server.log` -eq '0' ];do
done
echo "server running"
-PYTHONCMD pilot/server/webserver.py
\ No newline at end of file
+PYTHONCMD pilot/server/webserver.py
diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py
index 03c9633d3..ff13865b4 100644
--- a/tools/knowlege_init.py
+++ b/tools/knowlege_init.py
@@ -10,7 +10,6 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
- VECTOR_SEARCH_TOP_K,
)
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@@ -19,36 +18,30 @@ CFG = Config()
class LocalKnowledgeInit:
embeddings: object = None
- model_name = LLM_MODEL_CONFIG["text2vec"]
- top_k: int = VECTOR_SEARCH_TOP_K
def __init__(self, vector_store_config) -> None:
self.vector_store_config = vector_store_config
+ self.model_name = LLM_MODEL_CONFIG["text2vec"]
def knowledge_persist(self, file_path, append_mode):
"""knowledge persist"""
- kv = KnowledgeEmbedding(
- file_path=file_path,
- model_name=LLM_MODEL_CONFIG["text2vec"],
- vector_store_config=self.vector_store_config,
- )
- vector_store = kv.knowledge_persist_initialization(append_mode)
- return vector_store
-
- def query(self, q):
- """Query similar doc from Vector"""
- vector_store = self.init_vector_store()
- docs = vector_store.similarity_search_with_score(q, k=self.top_k)
- for doc in docs:
- dc, s = doc
- yield s, dc
+ for root, _, files in os.walk(file_path, topdown=False):
+ for file in files:
+ filename = os.path.join(root, file)
+ # docs = self._load_file(filename)
+ ke = KnowledgeEmbedding(
+ file_path=filename,
+ model_name=self.model_name,
+ vector_store_config=self.vector_store_config,
+ )
+ client = ke.init_knowledge_embedding()
+ client.source_embedding()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vector_name", type=str, default="default")
parser.add_argument("--append", type=bool, default=False)
- parser.add_argument("--store_type", type=str, default="Chroma")
args = parser.parse_args()
vector_name = args.vector_name
append_mode = args.append
@@ -56,5 +49,5 @@ if __name__ == "__main__":
vector_store_config = {"vector_store_name": vector_name}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
- vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+ kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
print("your knowledge embedding success...")