Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-25 13:06:53 +00:00)
Commit 3db8e33f20
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec
-    docsearch = knownledge_tovec("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
+    docsearch = load_knownledge_from_doc()
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     for doc in docs:
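
For orientation, a minimal sketch of what the new call path does end to end. Everything here comes from the hunk above except the sample question, which is invented; it assumes the Chroma store built by load_knownledge_from_doc() (defined later in this diff) is available.

    # Sketch: query the document-level vector store the way get_similar() does.
    from pilot.vector_store.extract_tovec import load_knownledge_from_doc

    docsearch = load_knownledge_from_doc()
    # similarity_search_with_score returns (Document, score) pairs;
    # in Chroma, lower scores mean closer matches.
    docs = docsearch.similarity_search_with_score("How do I add a knowledge base?", k=1)
    for doc, score in docs:
        print(score, doc.page_content[:200])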
@@ -4,34 +4,34 @@
 import torch
 import os
 
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
+DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 
 
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
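
As a quick check on the renamed constants, a hedged sketch of how a consumer module resolves them (the same LLM_MODEL_CONFIG[LLM_MODEL] lookup appears in the server hunk further down; the printing is illustrative only):

    from pilot.configs.model_config import (
        DEVICE, LLM_MODEL, LLM_MODEL_CONFIG, VECTORE_PATH, DATASETS_DIR
    )

    model_path = LLM_MODEL_CONFIG[LLM_MODEL]  # e.g. <ROOT_PATH>/models/vicuna-13b
    print(f"device={DEVICE}, model weights at {model_path}")
    print(f"vector stores: {VECTORE_PATH}, datasets: {DATASETS_DIR}")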
@@ -8,7 +8,7 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from configs.model_config import *
+from pilot.configs.model_config import *
 
 class VicunaRequestLLM(LLM):
 
@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
         print("Sending prompt ", p)
 
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
             json={
                 "prompt": p
             }
@@ -14,7 +14,7 @@ from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
-model_path = llm_model_config[LLM_MODEL]
+model_path = LLM_MODEL_CONFIG[LLM_MODEL]
 
 
 global_counter = 0
@@ -12,9 +12,9 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
 
 from pilot.conversation import (
     default_conversation,
@@ -42,11 +42,22 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
+vs_list = ["新建知识库"] + get_vector_storelist()
 
 priority = {
     "vicuna-13b": "aaa"
 }
 
+def get_simlar(q):
+    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
+    docs = docsearch.similarity_search_with_score(q, k=1)
+
+    contents = [dc.page_content for dc, _ in docs]
+    return "\n".join(contents)
+
+
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS
@@ -149,6 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    query = state.messages[-2][1]
     if len(state.messages) == state.offset + 2:
         # 第一轮对话需要加入提示Prompt
 
@@ -157,11 +169,23 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.conv_id = uuid.uuid4().hex
 
         # prompt 中添加上下文提示
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
+        if db_selector:
+            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+        else:
+            new_state.append_message(new_state.roles[0], query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+
+        try:
+            if not db_selector:
+                sim_q = get_simlar(query)
+                print("********vector similar info*************: ", sim_q)
+                state.append_message(new_state.roles[0], sim_q + query)
+        except Exception as e:
+            print(e)
 
     prompt = state.get_prompt()
 
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
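
Distilled from the hunk above, a standalone sketch of the first-round routing: with a database selected, the schema context from gen_sqlgen_conversation() is prepended to the user's question; otherwise the nearest document found by get_simlar() is. The helper name and signature are invented for illustration, and it is meant to live alongside those two functions in webserver.py:

    def build_first_round_prompt(query: str, db_selector: bool, dbname: str) -> str:
        # SQL mode: prepend table/schema context for the chosen database.
        if db_selector:
            return gen_sqlgen_conversation(dbname) + query
        # Knowledge mode: prepend the most similar document; fall back to
        # the bare question if retrieval fails (mirrors the try/except above).
        try:
            return get_simlar(query) + query
        except Exception as e:
            print(e)
            return query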
@@ -181,7 +205,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
 
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
             headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
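
For reference, a minimal client for this endpoint. The b"\0" delimiter comes from the hunk above; the payload fields and the "text" key follow FastChat's generate_stream convention and are assumptions here:

    import json
    import requests
    from urllib.parse import urljoin
    from pilot.configs.model_config import VICUNA_MODEL_SERVER, LLM_MODEL

    payload = {"model": LLM_MODEL, "prompt": "Hello", "temperature": 0.7,
               "max_new_tokens": 512, "stop": None}
    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                             json=payload, stream=True, timeout=20)
    # Each NUL-delimited chunk is a JSON object carrying the output so far.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            print(data.get("text", ""))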
@@ -236,6 +260,15 @@ pre {
 """
 )
 
+def change_tab(tab):
+    pass
+
+def change_mode(mode):
+    if mode == "默认知识库对话":
+        return gr.update(visible=False)
+    else:
+        return gr.update(visible=True)
+
 
 def build_single_model_ui():
 
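
The new change_mode handler implements a show/hide toggle that the QA tab below wires to a radio control. A self-contained sketch of the same pattern, with invented Blocks scaffolding around labels taken from the diff:

    import gradio as gr

    def change_mode(mode):
        # Hide the knowledge-base settings in default mode, show them otherwise.
        return gr.update(visible=(mode != "默认知识库对话"))

    with gr.Blocks() as demo:
        mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False,
                        value="默认知识库对话")
        vs_setting = gr.Accordion("配置知识库", open=False, visible=False)
        mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

    demo.launch()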
@@ -270,12 +303,10 @@ def build_single_model_ui():
             interactive=True,
             label="最大输出Token数",
         )
-        with gr.Tabs():
-            with gr.TabItem("知识问答", elem_id="QA"):
-                pass
+        tabs = gr.Tabs()
+        with tabs:
             with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
                 # TODO A selector to choose database
                 with gr.Row(elem_id="db_selector"):
                     db_selector = gr.Dropdown(
                         label="请选择数据库",
@@ -283,6 +314,30 @@ def build_single_model_ui():
                         value=dbs[0] if len(models) > 0 else "",
                         interactive=True,
                         show_label=True).style(container=False)
 
+            with gr.TabItem("知识问答", elem_id="QA"):
+                mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+                vs_setting = gr.Accordion("配置知识库", open=False)
+                mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+                with vs_setting:
+                    vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
+                    vs_add = gr.Button("添加为新知识库")
+                    with gr.Column() as doc2vec:
+                        gr.Markdown("向知识库中添加文件")
+                        with gr.Tab("上传文件"):
+                            files = gr.File(label="添加文件",
+                                            file_types=[".txt", ".md", ".docx", ".pdf"],
+                                            file_count="multiple",
+                                            show_label=False
+                                            )
+
+                            load_file_button = gr.Button("上传并加载到知识库")
+                        with gr.Tab("上传文件夹"):
+                            folder_files = gr.File(label="添加文件",
+                                                   file_count="directory",
+                                                   show_label=False)
+                            load_folder_button = gr.Button("上传并加载到知识库")
+
+
     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@@ -300,6 +355,7 @@ def build_single_model_ui():
         regenerate_btn = gr.Button(value="重新生成", interactive=False)
         clear_btn = gr.Button(value="清理", interactive=False)
 
+
     gr.Markdown(learn_more_markdown)
 
     btn_list = [regenerate_btn, clear_btn]
@@ -1,19 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import os
+
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pilot.model.vicuna_llm import VicunaEmbeddingLLM
-# from langchain.embeddings import SentenceTransformerEmbeddings
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR
+from langchain.embeddings import HuggingFaceEmbeddings
 
 embeddings = VicunaEmbeddingLLM()
 
 def knownledge_tovec(filename):
     with open(filename, "r") as f:
         knownledge = f.read()
 
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(knownledge)
     docsearch = Chroma.from_texts(
@@ -21,18 +22,48 @@ def knownledge_tovec(filename):
     )
     return docsearch
 
-# def knownledge_tovec_st(filename):
-#     """ Use sentence transformers to embedding the document.
-#     https://github.com/UKPLab/sentence-transformers
-#     """
-#     from pilot.configs.model_config import llm_model_config
-#     embeddings = SentenceTransformerEmbeddings(model=llm_model_config["sentence-transforms"])
-
-#     with open(filename, "r") as f:
-#         knownledge = f.read()
-
-#     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-#     texts = text_splitter(knownledge)
-#     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
-#     return docsearch
+def knownledge_tovec_st(filename):
+    """ Use sentence transformers to embedding the document.
+        https://github.com/UKPLab/sentence-transformers
+    """
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    with open(filename, "r") as f:
+        knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))])
+    return docsearch
+
+
+def load_knownledge_from_doc():
+    """从数据集当中加载知识
+        # TODO 如果向量存储已经存在, 则无需初始化
+    """
+
+    if not os.path.exists(DATASETS_DIR):
+        print("Not Exists Local DataSets, We will answers the Question use model default.")
+
+    from pilot.configs.model_config import LLM_MODEL_CONFIG
+    embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
+
+    files = os.listdir(DATASETS_DIR)
+    for file in files:
+        if not os.path.isdir(file):
+            filename = os.path.join(DATASETS_DIR, file)
+            with open(filename, "r") as f:
+                knownledge = f.read()
+
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(knownledge)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
+                                  persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
+    return docsearch
+
+
+def get_vector_storelist():
+    if not os.path.exists(VECTORE_PATH):
+        return []
+    return os.listdir(VECTORE_PATH)
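
To make the new vector-store API concrete, a hedged usage sketch. The question string is invented, and it assumes the all-MiniLM-L6-v2 weights referenced by LLM_MODEL_CONFIG["sentence-transforms"] are present under MODEL_PATH:

    from pilot.vector_store.extract_tovec import (
        get_vector_storelist, load_knownledge_from_doc
    )

    # Build the store from every file under pilot/datasets; it persists
    # beneath VECTORE_PATH/.vectore, so later runs could reuse it (the
    # docstring's TODO notes that reuse is not implemented yet).
    docsearch = load_knownledge_from_doc()

    for doc, score in docsearch.similarity_search_with_score("deployment plan", k=1):
        print(score, doc.page_content[:120])

    print("existing stores:", get_vector_storelist())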
@@ -50,4 +50,6 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
 fschat=0.1.10
+llama-index=0.5.27
+pymysql