Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

This commit is contained in:
aries_ckt 2023-06-30 10:43:35 +08:00
commit 1de0542015
20 changed files with 79 additions and 42 deletions

View File

@ -2,7 +2,10 @@
const nextConfig = {
experimental: {
esmExternals: 'loose'
}
},
images: {
unoptimized: true
},
}
module.exports = nextConfig

View File

@ -6,7 +6,9 @@
"dev": "next dev",
"build": "next build",
"start": "next start",
"lint": "next lint"
"lint": "next lint",
"export": "next export",
"compile": "next build && next export"
},
"dependencies": {
"@ant-design/pro-components": "^2.6.2",

View File

@ -7,4 +7,4 @@ if __name__ == "__main__":
connect = CFG.local_db.get_session("gpt-user")
datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
print(datas)
print(datas)

View File

@ -9,6 +9,8 @@ if __name__ == "__main__":
if os.path.isfile("../../../message/chat_history.db"):
cursor = duckdb.connect("../../../message/chat_history.db").cursor()
# cursor.execute("SELECT * FROM chat_history limit 20")
cursor.execute("SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'")
cursor.execute(
"SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'"
)
data = cursor.fetchall()
print(data)

View File

@ -26,10 +26,9 @@ class BaseChatHistoryMemory(ABC):
"""Retrieve the messages from the local file"""
@abstractmethod
def create(self, user_name:str) -> None:
def create(self, user_name: str) -> None:
"""Append the message to the record in the local file"""
@abstractmethod
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""

View File

@ -36,7 +36,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if not result:
# 如果表不存在,则创建新表
self.connect.execute(
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)")
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
)
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
@ -59,7 +60,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO chat_history(conv_uid, chat_mode summary, user_name, messages)VALUES(?,?,?,?,?)",
[self.chat_seesion_id, chat_mode, summary, user_name, ""])
[self.chat_seesion_id, chat_mode, summary, user_name, ""],
)
cursor.commit()
self.connect.commit()
except Exception as e:
@ -80,7 +82,14 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
else:
cursor.execute(
"INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
[self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, "",json.dumps(conversations, ensure_ascii=False)])
[
self.chat_seesion_id,
once_message.chat_mode,
once_message.get_user_conv().content,
"",
json.dumps(conversations, ensure_ascii=False),
],
)
cursor.commit()
self.connect.commit()

Binary file not shown.

View File

@ -2,6 +2,7 @@ import uuid
import json
import asyncio
import time
import os
from fastapi import (
APIRouter,
Request,
@ -12,11 +13,11 @@ from fastapi import (
BackgroundTasks,
)
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.openapi.api_v1.api_view_model import (
@ -46,6 +47,7 @@ knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
static_file_path = os.path.join(os.getcwd(), "server/static")
async def validation_exception_handler(request: Request, exc: RequestValidationError):
@ -95,6 +97,10 @@ def knowledge_list():
return params
@router.get("/")
async def read_main():
return FileResponse(f"{static_file_path}/test.html")
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
@ -111,8 +117,6 @@ async def dialogue_list(response: Response, user_id: str = None):
summary = item.get("summary")
chat_mode = item.get("chat_mode")
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=summary,
@ -147,7 +151,6 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
@ -155,6 +158,7 @@ async def dialogue_new(
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatWithDbQA.value == chat_mode:
@ -274,15 +278,15 @@ async def stream_generator(chat):
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"

View File

@ -124,6 +124,7 @@ class KnowledgeDocumentDao:
updated_space = session.merge(document)
session.commit()
return updated_space.id
#
# def delete_knowledge_document(self, document_id: int):
# cursor = self.conn.cursor()

View File

@ -114,8 +114,13 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
):
raise Exception(
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
)
client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),

View File

@ -107,7 +107,9 @@ class BaseChat(ABC):
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.current_message.tokens = 0
if self.prompt_template.template:

View File

@ -65,12 +65,12 @@ class ChatDashboard(BaseChat):
try:
datas = self.database.run(self.db_connect, chart_item.sql)
chart_data: ChartData = ChartData()
chart_data.chart_sql = chart_item['sql']
chart_data.chart_type = chart_item['showcase']
chart_data.chart_name = chart_item['title']
chart_data.chart_desc = chart_item['thoughts']
chart_data.chart_sql = chart_item["sql"]
chart_data.chart_type = chart_item["showcase"]
chart_data.chart_name = chart_item["title"]
chart_data.chart_desc = chart_item["thoughts"]
chart_data.column_name = datas[0]
chart_data.values =datas
chart_data.values = datas
except Exception as e:
# TODO 修复流程
print(str(e))

View File

@ -1,5 +1,6 @@
from pilot.prompts.example_base import ExampleSelector
from pilot.common.schema import ExampleType
## Two examples are defined by default
EXAMPLES = [
{
@ -34,4 +35,6 @@ EXAMPLES = [
},
]
sql_data_example = ExampleSelector(examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value)
sql_data_example = ExampleSelector(
examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value
)

View File

@ -9,6 +9,8 @@ from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
@ -35,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):
df = pd.DataFrame(data[1:], columns=data[0])
if CFG.NEW_SERVER_MODE:
html = df.to_html(index=False, escape=False, sparsify=False)
html = ''.join(html.split())
html = "".join(html.split())
else:
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}

View File

@ -46,6 +46,6 @@ prompt = PromptTemplate(
output_parser=DbChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
example_selector=sql_data_example
example_selector=sql_data_example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -8,7 +8,9 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
PROMPT_SCENE_DEFINE = (
"""You are an assistant that answers user specialized database questions. """
)
# PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
# {table_info}
@ -27,21 +29,24 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
# """
_DEFAULT_TEMPLATE_EN = """
You are a database expert. you will be given metadata information about a database or table, and then provide a brief summary and answer to the question. For example, question: "How many tables are there in database 'db_gpt'?" , answer: "There are 5 tables in database 'db_gpt', which are 'book', 'book_category', 'borrower', 'borrowing', and 'category'.
Based on the database metadata information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
database metadata information:
Provide professional answers to requests and questions. If you can't get an answer from what you've provided, say: "Insufficient information in the knowledge base is available to answer this question." It is forbidden to fabricate information.
Use the following table structure information to generate SQL, if any is provided:
{table_info}
question:
user question:
{input}
think step by step.
"""
_DEFAULT_TEMPLATE_ZH = """
你是一位数据库专家你将获得有关数据库或表的元数据信息然后提供简要的总结和回答例如问题数据库 'db_gpt' 中有多少个表 答案数据库 'db_gpt' 中有 5 个表分别是 'book''book_category''borrower''borrowing' 'category'
根据以下数据库元数据信息为用户提供专业简洁的答案如果无法从提供的内容中获取答案请说知识库中提供的信息不足以回答此问题 禁止随意捏造信息
数据库元数据信息:
根据要求和问题提供专业的答案如果无法从提供的内容中获取答案请说知识库中提供的信息不足以回答此问题 禁止随意捏造信息
使用以下表结构信息:
{table_info}
问题:
{input}
一步步思考
"""
_DEFAULT_TEMPLATE = (

View File

@ -12,7 +12,6 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -1,4 +1,3 @@
import signal
import traceback
import os
import shutil
@ -10,7 +9,10 @@ sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import (
LOGDIR
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.utils import build_logger
@ -62,7 +64,6 @@ app.add_middleware(
)
app.mount("/static", StaticFiles(directory=static_file_path), name="static")
app.add_route("/test", "static/test.html")
app.include_router(knowledge_router)
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)

View File

@ -6,7 +6,7 @@
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
</head>
<body>
<div id="output"></div>
<div id="output">Hello World! I'm DB-GPT!</div>
<script>
$(document).ready(function() {
var source = new EventSource("/v1/chat/completions");