Merge pull request #16 from csunny/dev

Dev
Authored by magic.chen on 2023-05-07 17:59:49 +08:00, committed by GitHub

7 changed files with 82 additions and 67 deletions


@@ -1,10 +1,10 @@
 # DB-GPT
 A Open Database-GPT Experiment, A fully localized project.
-A database-oriented GPT experiment: the model and the data are deployed entirely locally, fully safeguarding data privacy and security. The project can also be deployed locally and connected directly to a private database to work with private data.
 ![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
+A database-oriented GPT experiment: the model and the data are deployed entirely locally, fully safeguarding data privacy and security. The project can also be deployed locally and connected directly to a private database to work with private data.
 [DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), using [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) as the base model. It additionally combines [langchain](https://github.com/hwchase17/langchain) and [llama-index](https://github.com/jerryjliu/llama_index) to augment the model with database-related knowledge through [In-Context Learning](https://arxiv.org/abs/2301.00234) over an existing knowledge base. It can perform SQL generation, SQL diagnosis, database knowledge Q&A, and a range of related tasks.


@@ -22,7 +22,7 @@ LLM_MODEL_CONFIG = {
 }
-VECTOR_SEARCH_TOP_K = 5
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
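The lower default mainly protects the prompt budget: with MAX_POSITION_EMBEDDINGS = 2048, every retrieved chunk competes with the question and the reply for context-window space. A rough sketch of that arithmetic (the chunk, question and answer sizes below are illustrative assumptions, not values from the repo):

    # Rough token-budget check for the new default; chunk/answer sizes are assumptions.
    MAX_POSITION_EMBEDDINGS = 2048   # model context window, from model_config
    CHUNK_TOKENS = 400               # assumed average size of a retrieved chunk
    ANSWER_BUDGET = 512              # tokens reserved for the model's reply
    QUESTION_TOKENS = 64             # assumed size of the question plus template text

    available = MAX_POSITION_EMBEDDINGS - ANSWER_BUDGET - QUESTION_TOKENS
    max_safe_top_k = available // CHUNK_TOKENS
    print(max_safe_top_k)            # -> 3 with these assumptions, matching the new default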


@@ -147,7 +147,7 @@ conv_vicuna_v1 = Conversation(
 )
-conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
 已知内容:
@@ -158,6 +158,12 @@ conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 default_conversation = conv_one_shot

+conversation_types = {
+    "native": "LLM原生对话",
+    "default_knownledge": "默认知识库对话",
+    "custome": "新增知识库对话",
+}
+
 conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1,


@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *

-class VicunaRequestLLM(LLM):
-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+class VicunaLLM(LLM):
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output

     @property
     def _llm_type(self) -> str:
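Because _call now yields chunks instead of returning one string, callers must drain it as a generator. A usage sketch, assuming a vicuna worker is reachable at VICUNA_MODEL_SERVER (from model_config) and serves the generate_stream endpoint:

    # Usage sketch for the streaming VicunaLLM above (assumes a running worker).
    from pilot.model.vicuna_llm import VicunaLLM

    llm = VicunaLLM()
    final = ""
    # Each yielded value is the completion so far, with the echoed prompt stripped.
    for partial in llm._call("Explain what a covering index is.", temperature=0.7, max_new_tokens=256):
        final = partial
    print(final)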


@@ -4,29 +4,44 @@
 import requests
 import json
 import time
+import uuid
 from urllib.parse import urljoin

 import gradio as gr
 from pilot.configs.model_config import *
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate

-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+vicuna_stream_path = "generate_stream"

-def generate(prompt):
+def generate(query):
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }

-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )

     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + ""
                 yield(output)
-            time.sleep(0.02)

 if __name__ == "__main__":
     print(LLM_MODEL)
     with gr.Blocks() as demo:
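The skip_echo_len expression, which now appears both here and in VicunaLLM, is the subtlest part of the change: the worker streams back the echoed prompt followed by the completion, and the client slices the echo off by length. A self-contained illustration (the strings are made up, and it assumes, as the inherited FastChat code does, that each "</s>" separator in the echoed prompt collapses to a single character):

    # Why skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3 (illustrative only).
    prompt = "A chat.</s>USER: hello</s>ASSISTANT:"
    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3

    # What the worker streams back: the echoed prompt (each 4-char "</s>" rendered
    # as one character) followed by the newly generated text.
    streamed = prompt.replace("</s>", " ") + " Hi there!"
    print(streamed[skip_echo_len:].strip())   # -> "Hi there!"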


@@ -3,31 +3,27 @@
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
-from pilot.conversation import conv_qk_prompt_template
-from langchain.chains import RetrievalQA
+from pilot.conversation import conv_qa_prompt_template
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.model.vicuna_llm import VicunaLLM

 class KnownLedgeBaseQA:
-    llm: object = None

     def __init__(self) -> None:
         k2v = KnownLedge2Vector()
         self.vector_store = k2v.init_vector_store()
+        self.llm = VicunaLLM()

-    def get_answer(self, query):
-        prompt_template = conv_qk_prompt_template
+    def get_similar_answer(self, query):
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=conv_qa_prompt_template,
             input_variables=["context", "question"]
         )

-        knownledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
-            prompt=prompt
-        )
-        knownledge_chain.return_source_documents = True
-        result = knownledge_chain({"query": query})
-        yield result
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
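With RetrievalQA removed, KnownLedgeBaseQA no longer generates an answer itself; it only builds a retrieval-grounded prompt that http_bot (below) forwards to the model. A usage sketch, assuming the vector store behind KnownLedge2Vector has already been populated:

    # Usage sketch for the reworked class (assumes an initialized knowledge vector store).
    from pilot.server.vectordb_qa import KnownLedgeBaseQA

    qa = KnownLedgeBaseQA()
    grounded_prompt = qa.get_similar_answer("How do I configure the Chroma wrapper?")
    # conv_qa_prompt_template filled with the top-k retrieved chunks and the question:
    print(grounded_prompt)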


@@ -11,6 +11,7 @@ import datetime
 import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql_conn import MySQLOperator
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
 from pilot.conversation import (
     default_conversation,
     conv_templates,
+    conversation_types,
     SeparatorStyle
 )
@@ -149,7 +151,7 @@ def post_process_code(code):
     code = sep.join(blocks)
     return code

-def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
     model_name = LLM_MODEL
@@ -170,7 +172,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     query = state.messages[-2][1]

-    # Add a context hint to the prompt
+    # Add a context hint to the prompt. When the conversation is grounded in existing knowledge, should the hint only go into the first turn, or be added on every turn?
+    # If the user's questions vary widely between turns, the hint should be added on every turn.
     if db_selector:
         new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
         new_state.append_message(new_state.roles[1], None)
@@ -180,13 +183,11 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.append_message(new_state.roles[1], None)
         state = new_state

-    # try:
-    #     if not db_selector:
-    #         sim_q = get_simlar(query)
-    #         print("********vector similar info*************: ", sim_q)
-    #         state.append_message(new_state.roles[0], sim_q + query)
-    # except Exception as e:
-    #     print(e)
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        query = state.messages[-2][1]
+        knqa = KnownLedgeBaseQA()
+        state.messages[-2][1] = knqa.get_similar_answer(query)

     prompt = state.get_prompt()
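Pulled out of the surrounding function, the new branch reduces to the sketch below (a simplified restatement, not the full http_bot): the SQL path driven by db_selector keeps priority, and only the default-knowledge mode rewrites the last user turn into a grounded prompt.

    # Simplified restatement of the new routing in http_bot (not the full function).
    from pilot.conversation import conversation_types
    from pilot.server.vectordb_qa import KnownLedgeBaseQA

    def ground_last_user_turn(state, mode, db_selector):
        if db_selector:
            return state                                   # SQL-generation path, handled earlier
        if mode == conversation_types["default_knownledge"]:
            query = state.messages[-2][1]                  # most recent user message
            state.messages[-2][1] = KnownLedgeBaseQA().get_similar_answer(query)
        return state

In the "LLM原生对话" mode nothing is rewritten, so the raw conversation goes straight to the model.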
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
-                time.sleep(0.02)

     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

+    # Log this run
     finish_tstamp = time.time()
     logger.info(f"{output}")
@@ -266,7 +268,7 @@ def change_tab(tab):
     pass

 def change_mode(mode):
-    if mode == "默认知识库对话":
+    if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
@@ -318,7 +320,8 @@ def build_single_model_ui():
                 show_label=True).style(container=False)

         with gr.TabItem("知识问答", elem_id="QA"):
-            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
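Together with the change_mode fix above, the knowledge-base accordion is now hidden for both built-in modes and shown only when the user picks the new-knowledge-base mode. A standalone sketch of just this wiring (the UI labels are copied from the diff, everything else is simplified):

    # Standalone sketch of the mode radio + accordion wiring from this hunk.
    import gradio as gr

    def change_mode(mode):
        if mode in ["默认知识库对话", "LLM原生对话"]:
            return gr.update(visible=False)   # built-in modes need no knowledge-base setup
        else:
            return gr.update(visible=True)    # "新增知识库对话" opens the configuration panel

    with gr.Blocks() as demo:
        mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"],
                        show_label=False, value="LLM原生对话")
        vs_setting = gr.Accordion("配置知识库", open=False)
        mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

    demo.launch()  # hypothetical entry point for trying the snippet in isolation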
@@ -363,7 +366,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -372,7 +375,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
@@ -380,7 +383,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )