Mirror of https://github.com/csunny/DB-GPT.git
@@ -1,10 +1,10 @@
 # DB-GPT
 An Open Database-GPT Experiment, a fully localized project.
 
-A database-related GPT experiment: the model and the data are deployed entirely locally, which fully safeguards data privacy. The project can also be deployed on-premises and connected directly to a private database to work with private data.
 
 
 
+A database-related GPT experiment: the model and the data are deployed entirely locally, which fully safeguards data privacy. The project can also be deployed on-premises and connected directly to a private database to work with private data.
 
 [DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), with [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) as the base model. It also combines [langchain](https://github.com/hwchase17/langchain) and [llama-index](https://github.com/jerryjliu/llama_index) to enhance the model with database-related knowledge through [In-Context Learning](https://arxiv.org/abs/2301.00234) over an existing knowledge base. It can handle tasks such as SQL generation, SQL diagnosis, and database knowledge Q&A.
 
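The rest of this change wires up exactly that in-context learning loop: retrieve passages from the local knowledge base, splice them into a prompt, and stream the completion from the vicuna worker. Below is a minimal sketch of the pattern, assuming any langchain-style retriever (`retriever` and the template text here are placeholders, not objects defined in this change):

```python
# Sketch of prompt-side in-context learning with langchain's PromptTemplate.
from langchain.prompts import PromptTemplate

qa_template = PromptTemplate(
    template="Answer the question using only the context below.\n\nContext:\n{context}\n\nQuestion: {question}",
    input_variables=["context", "question"],
)

def build_grounded_prompt(retriever, question: str) -> str:
    # Pull the most similar knowledge-base chunks and format them into the prompt.
    docs = retriever.get_relevant_documents(question)
    context = "\n".join(d.page_content for d in docs)
    return qa_template.format(context=context, question=question)
```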
@@ -22,7 +22,7 @@ LLM_MODEL_CONFIG = {
 }
 
 
-VECTOR_SEARCH_TOP_K = 5
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
@@ -147,7 +147,7 @@ conv_vicuna_v1 = Conversation(
 )
 
 
-conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
 
 已知内容:
@@ -158,6 +158,12 @@ conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 
 default_conversation = conv_one_shot
 
+conversation_types = {
+    "native": "LLM原生对话",
+    "default_knownledge": "默认知识库对话",
+    "custome": "新增知识库对话",
+}
+
 conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1,
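The new `conversation_types` labels are what the Gradio radio button returns, and `http_bot` later compares the selected label against this dict to decide whether to route a question through the knowledge base. For reference, the conversation templates themselves are used elsewhere in this change roughly as follows (a condensed sketch; the question text is arbitrary):

```python
# Condensed sketch of the Conversation-template flow used later in this change.
from pilot.conversation import conv_templates

state = conv_templates["conv_one_shot"].copy()                 # fresh conversation state
state.append_message(state.roles[0], "example user question")  # user turn
state.append_message(state.roles[1], None)                     # placeholder for the model's reply
prompt = state.get_prompt()                                    # flattened prompt sent to the vicuna worker
```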
@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *
 
-class VicunaRequestLLM(LLM):
+class VicunaLLM(LLM):
 
-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output
 
     @property
     def _llm_type(self) -> str:
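With this change `_call` streams: it posts to the worker's `generate_stream` endpoint and yields the accumulated completion chunk by chunk instead of returning a single string. A minimal consumption sketch (the prompt text and parameter values are illustrative):

```python
# Illustrative: stream a completion from the local vicuna worker.
from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()
for partial in llm._call("What is DB-GPT?", temperature=0.7, max_new_tokens=512, stop=["###"]):
    # Each yielded value is the full generation so far, with the echoed prompt stripped.
    print(partial)
```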
@@ -4,29 +4,44 @@
 import requests
 import json
 import time
+import uuid
 from urllib.parse import urljoin
 import gradio as gr
 from pilot.configs.model_config import *
-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate
 
-def generate(prompt):
+vicuna_stream_path = "generate_stream"
+
+def generate(query):
+
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }
 
-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
    )
 
    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + "▌"
                 yield(output)
 
-    time.sleep(0.02)
-
 if __name__ == "__main__":
     print(LLM_MODEL)
     with gr.Blocks() as demo:
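On the wire, the worker answers with NUL-delimited JSON chunks, and `skip_echo_len` strips the echoed prompt from each cumulative `text` field before display. A self-contained sketch of that decoding step (the two sample payloads are invented for illustration, but shaped the way the loop above expects):

```python
import json

prompt = "USER: What is Chroma? ASSISTANT:"
# Invented sample chunks; the worker streams such JSON objects separated by b"\0".
chunks = [
    json.dumps({"error_code": 0, "text": prompt + " Chroma is a vector"}).encode(),
    json.dumps({"error_code": 0, "text": prompt + " Chroma is a vector database."}).encode(),
]

skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
for chunk in chunks:
    data = json.loads(chunk.decode())
    if data["error_code"] == 0:
        # Cumulative text minus the echoed prompt: "Chroma is a vector", then "... database."
        print(data["text"][skip_echo_len:].strip())
```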
@@ -3,31 +3,27 @@
 
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
-from pilot.conversation import conv_qk_prompt_template
-from langchain.chains import RetrievalQA
+from pilot.conversation import conv_qa_prompt_template
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.model.vicuna_llm import VicunaLLM
 
 class KnownLedgeBaseQA:
 
-    llm: object = None
-
     def __init__(self) -> None:
         k2v = KnownLedge2Vector()
         self.vector_store = k2v.init_vector_store()
+        self.llm = VicunaLLM()
 
-    def get_answer(self, query):
-        prompt_template = conv_qk_prompt_template
-
+    def get_similar_answer(self, query):
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=conv_qa_prompt_template,
             input_variables=["context", "question"]
         )
 
-        knownledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
-            prompt=prompt
-        )
-        knownledge_chain.return_source_documents = True
-        result = knownledge_chain({"query": query})
-        yield result
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
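Note that `get_similar_answer` does not call the model itself: it retrieves the top-k chunks and returns the filled-in QA prompt, which `http_bot` then substitutes for the user's message before the request is sent. Illustrative usage (the question text is arbitrary):

```python
# Illustrative: build a knowledge-grounded prompt for a user question.
from pilot.server.vectordb_qa import KnownLedgeBaseQA

knqa = KnownLedgeBaseQA()
grounded = knqa.get_similar_answer("How do I use Chroma inside LangChain?")
print(grounded)  # conv_qa_prompt_template with the retrieved chunks filled into {context}
```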
@@ -11,6 +11,7 @@ import datetime
 import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql_conn import MySQLOperator
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
 from pilot.conversation import (
     default_conversation,
     conv_templates,
+    conversation_types,
     SeparatorStyle
 )
 
@@ -149,7 +151,7 @@ def post_process_code(code):
         code = sep.join(blocks)
     return code
 
-def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
     model_name = LLM_MODEL
 
@@ -170,7 +172,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
 
         query = state.messages[-2][1]
 
-        # Add a context hint to the prompt
+        # Add a context hint to the prompt. When chatting over existing knowledge, should the hint go only into the first turn, or be added every turn?
+        # If the user's questions range widely, the hint should be added every turn.
         if db_selector:
             new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
             new_state.append_message(new_state.roles[1], None)
@@ -180,13 +183,11 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
             new_state.append_message(new_state.roles[1], None)
         state = new_state
 
-    # try:
-    #     if not db_selector:
-    #         sim_q = get_simlar(query)
-    #         print("********vector similar info*************: ", sim_q)
-    #         state.append_message(new_state.roles[0], sim_q + query)
-    # except Exception as e:
-    #     print(e)
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        query = state.messages[-2][1]
+        knqa = KnownLedgeBaseQA()
+        state.messages[-2][1] = knqa.get_similar_answer(query)
+
 
     prompt = state.get_prompt()
 
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
-                time.sleep(0.02)
+
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
+    # Log the run
     finish_tstamp = time.time()
     logger.info(f"{output}")
 
@@ -266,7 +268,7 @@ def change_tab(tab):
     pass
 
 def change_mode(mode):
-    if mode == "默认知识库对话":
+    if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
@@ -318,7 +320,8 @@ def build_single_model_ui():
                 show_label=True).style(container=False)
 
         with gr.TabItem("知识问答", elem_id="QA"):
-            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+
+            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
@@ -363,7 +366,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -372,7 +375,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
 
@@ -380,7 +383,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )
 