Merge remote-tracking branch 'origin/main' into dev

# Conflicts:
#	pilot/configs/config.py
#	pilot/connections/mysql.py
#	pilot/conversation.py
#	pilot/server/webserver.py
yhjun1026
2023-05-25 10:22:38 +08:00
87 changed files with 2152 additions and 874 deletions

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import cache
from typing import List
from pilot.model.inference import generate_stream
class BaseChatAdpter:
"""The Base class for chat with llm models. it will match the model,
and fetch output from model"""
def match(self, model_path: str):
return True
def get_generate_stream_func(self):
"""Return the generate stream handler func"""
pass
llm_model_chat_adapters: List[BaseChatAdpter] = []
def register_llm_model_chat_adapter(cls):
"""Register a chat adapter"""
llm_model_chat_adapters.append(cls())
@cache
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
"""Get a chat generate func for a model"""
for adapter in llm_model_chat_adapters:
if adapter.match(model_path):
return adapter
raise ValueError(f"Invalid model for chat adapter {model_path}")
class VicunaChatAdapter(BaseChatAdpter):
"""Model chat Adapter for vicuna"""
def match(self, model_path: str):
return "vicuna" in model_path
def get_generate_stream_func(self):
return generate_stream
class ChatGLMChatAdapter(BaseChatAdpter):
"""Model chat Adapter for ChatGLM"""
def match(self, model_path: str):
return "chatglm" in model_path
def get_generate_stream_func(self):
from pilot.model.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
class CodeT5ChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeT5"""
def match(self, model_path: str):
return "codet5" in model_path
def get_generate_stream_func(self):
# TODO
pass
class CodeGenChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeGen"""
def match(self, model_path: str):
return "codegen" in model_path
def get_generate_stream_func(self):
# TODO
pass
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)
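A hedged usage sketch (not part of the commit; the model paths are illustrative): the registry above matches adapters by substring on the model path, and because BaseChatAdpter itself is registered with a match() that always returns True, unrecognized paths fall back to it instead of raising ValueError.

from pilot.server.chat_adapter import get_llm_chat_adapter

adapter = get_llm_chat_adapter("/models/vicuna-13b")  # resolved by VicunaChatAdapter
stream_func = adapter.get_generate_stream_func()  # pilot.model.inference.generate_stream

# Any other path is caught by the registered BaseChatAdpter fallback.
fallback = get_llm_chat_adapter("/models/some-other-model")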

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
code_highlight_css = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
@@ -71,6 +70,5 @@ code_highlight_css = (
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
""")
#.highlight { background: #f8f8f8; }
"""
# .highlight { background: #f8f8f8; }

View File

@@ -49,7 +49,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
warnings.warn(
"The 'color_map' parameter has been deprecated.",
)
# self.md = utils.get_markdown_parser()
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
self.select: EventListenerMethod
"""
@@ -112,7 +112,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
): # This happens for previously processed messages
return chat_message
elif isinstance(chat_message, str):
# return self.md.render(chat_message)
return str(self.md.convert(chat_message))
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -141,9 +141,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
processed_messages.append(
(
# self._process_chat_messages(message_pair[0]),
'<pre style="font-family: var(--font)">'
+ message_pair[0]
+ "</pre>",
self._process_chat_messages(message_pair[1]),
)
)
@@ -163,5 +164,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
**kwargs,
)
return self
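A minimal, standalone sketch of the rendering split this patch introduces (assumes only that markdown2 is installed; it is not part of the patched class): user turns are kept verbatim inside a <pre> tag, while bot turns are converted from Markdown, mirroring _process_chat_messages above.

from markdown2 import Markdown

md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])

def render_pair(user_msg: str, bot_msg: str):
    # User side: no Markdown rendering, keep the text verbatim in a <pre> block.
    user_html = '<pre style="font-family: var(--font)">' + user_msg + "</pre>"
    # Bot side: Markdown -> HTML, as the patched _process_chat_messages does.
    bot_html = str(md.convert(bot_msg))
    return user_html, bot_html

print(render_pair("SELECT COUNT(*) FROM t;", "**Result:** `42`"))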

View File

@@ -1,40 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import sys
import torch  # used below for the torch.cuda.CudaError handler
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
global_counter = 0
model_semaphore = None
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.inference import generate_output, generate_stream, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()
class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_name = model_name or model_path.split("/")[-1]
self.device = device
self.ml = ModelLoader(model_path=model_path)
self.model, self.tokenizer = self.ml.loader(
num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
)
if hasattr(self.model.config, "max_sequence_length"):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model.config, "max_position_embeddings"):
self.context_len = self.model.config.max_position_embeddings
else:
self.context_len = 2048
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
return (
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
print("output: ", output)
ret = {
"text": output,
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt):
return get_embeddings(self.model, self.tokenizer, prompt)
# TODO
app = FastAPI()
class PromptRequest(BaseModel):
prompt: str
temperature: float
@@ -42,6 +93,7 @@ class PromptRequest(BaseModel):
model: str
stop: str = None
class StreamRequest(BaseModel):
model: str
prompt: str
@@ -49,64 +101,43 @@ class StreamRequest(BaseModel):
max_new_tokens: int
stop: str
class EmbeddingRequest(BaseModel):
prompt: str
def release_model_semaphore():
model_semaphore.release()
@app.post("/generate_stream")
async def api_generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
print(worker.model, worker.tokenizer, params, DEVICE)
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)
@app.post("/generate")
def generate(prompt_request: PromptRequest):
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop
"stop": prompt_request.stop,
}
response = []
rsp_str = ""
output = worker.generate_stream_gate(params)
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp_str = str(rsp, "utf-8")
@@ -114,15 +145,22 @@ def generate(prompt_request: PromptRequest):
response.append(rsp_str)
return {"response": rsp_str}
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
params = {"prompt": prompt_request.prompt}
print("Received prompt: ", params["prompt"])
output = worker.get_embeddings(params["prompt"])
return {"response": [float(x) for x in output]}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", log_level="info")
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
print(model_path, DEVICE)
worker = ModelWorker(
model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
)
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")

View File

@@ -1,29 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
class KnownLedgeBaseQA:
def __init__(self) -> None:
k2v = KnownLedge2Vector()
self.vector_store = k2v.init_vector_store()
self.llm = VicunaLLM()
def get_similar_answer(self, query):
prompt = PromptTemplate(
template=conv_qa_prompt_template, input_variables=["context", "question"]
)
retriever = self.vector_store.as_retriever(
search_kwargs={"k": VECTOR_SEARCH_TOP_K}
)
docs = retriever.get_relevant_documents(query=query)
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
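A rough usage sketch (not part of the commit): get_similar_answer only retrieves the top-K documents and formats the QA prompt; producing the answer is left to the caller. The final call assumes VicunaLLM follows the LangChain LLM calling convention.

from pilot.server.vectordb_qa import KnownLedgeBaseQA

qa = KnownLedgeBaseQA()  # builds the vector store and a VicunaLLM instance
prompt = qa.get_similar_answer("How do I add an index to a large MySQL table?")
# Assumption: VicunaLLM is a LangChain LLM, so it can be called directly.
answer = qa.llm(prompt)
print(answer)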

View File

@@ -2,50 +2,63 @@
# -*- coding: utf-8 -*-
import argparse
import datetime
import json
import os
import shutil
import sys
import time
import uuid
from urllib.parse import urljoin
import gradio as gr
import requests
from langchain import PromptTemplate
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.commands.command import execute_ai_response_json
from pilot.commands.command_mange import CommandRegistry
from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_sql_mode,
conversation_types,
default_conversation,
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
@@ -66,9 +79,7 @@ autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
priority = {"vicuna-13b": "aaa"}
# Load plugins
CFG = Config()
@@ -76,10 +87,12 @@ CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT
"port": CFG.LOCAL_DB_PORT,
}
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
@@ -89,9 +102,7 @@ def get_simlar(q):
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(**DB_SETTINGS)
message = ""
@@ -334,8 +345,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -372,7 +383,7 @@ def build_single_model_ui():
notice_markdown = """
# DB-GPT
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全
"""
learn_more_markdown = """
### Licence
@@ -396,7 +407,7 @@ def build_single_model_ui():
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="最大输出Token数",
@@ -412,7 +423,8 @@ def build_single_model_ui():
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True,
).style(container=False)
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
@@ -420,7 +432,9 @@ def build_single_model_ui():
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
mode = gr.Radio(
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
)
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
@@ -429,18 +443,22 @@ def build_single_model_ui():
with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"):
files = gr.File(label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
show_label=False
)
files = gr.File(
label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件夹",
accept_multiple_files=True,
file_count="directory",
show_label=False)
folder_files = gr.File(
label="添加文件夹",
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button("上传并加载到知识库")
with gr.Blocks():
@@ -481,28 +499,32 @@ def build_single_model_ui():
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
vs_add.click(
fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
)
load_file_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name],
)
load_folder_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, folder_files],
outputs=[vs_name],
)
return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
@@ -544,15 +566,21 @@ def knowledge_embedding_store(vs_id, files):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
for file in files:
filename = os.path.split(file.name)[-1]
shutil.move(
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
)
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
@@ -596,5 +624,8 @@ if __name__ == "__main__":
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)