WEB API independent

tuyang.yhj 2023-06-30 10:02:03 +08:00
parent 6f8f182d1d
commit a330714c61
15 changed files with 88 additions and 36 deletions

View File

@@ -9,6 +9,8 @@ if __name__ == "__main__":
     if os.path.isfile("../../../message/chat_history.db"):
         cursor = duckdb.connect("../../../message/chat_history.db").cursor()
         # cursor.execute("SELECT * FROM chat_history limit 20")
-        cursor.execute("SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'")
+        cursor.execute(
+            "SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'"
+        )
         data = cursor.fetchall()
         print(data)
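
For reference, the test script above hard-codes both the database path and a conversation UID inside the SQL string. A minimal sketch (not part of this commit; the helper name and sample values are illustrative) of the same lookup with the UID bound as a DuckDB query parameter:

    import duckdb

    # Hypothetical helper, not in the commit: fetch one conversation's rows from the
    # chat_history DuckDB file, binding conv_uid instead of embedding it in the SQL.
    def fetch_conversation(db_path: str, conv_uid: str):
        cursor = duckdb.connect(db_path).cursor()
        cursor.execute("SELECT * FROM chat_history WHERE conv_uid = ?", [conv_uid])
        return cursor.fetchall()

    if __name__ == "__main__":
        rows = fetch_conversation(
            "../../../message/chat_history.db",
            "b54ae5fe-1624-11ee-a271-b26789cc3e58",
        )
        print(rows)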

View File

@@ -29,7 +29,6 @@ class BaseChatHistoryMemory(ABC):
     def create(self, user_name: str) -> None:
         """Append the message to the record in the local file"""

-
     @abstractmethod
     def append(self, message: OnceConversation) -> None:
         """Append the message to the record in the local file"""

View File

@@ -36,7 +36,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if not result:
             # Create a new table if it does not exist
-            self.connect.execute("CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)")
+            self.connect.execute(
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
+            )

-
     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()
@@ -59,7 +60,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             cursor = self.connect.cursor()
             cursor.execute(
                 "INSERT INTO chat_history(conv_uid, chat_mode summary, user_name, messages)VALUES(?,?,?,?,?)",
-                [self.chat_seesion_id, chat_mode, summary, user_name, ""])
+                [self.chat_seesion_id, chat_mode, summary, user_name, ""],
+            )
             cursor.commit()
             self.connect.commit()
         except Exception as e:
@@ -80,7 +82,14 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         else:
             cursor.execute(
                 "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
-                [self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, "",json.dumps(conversations, ensure_ascii=False)])
+                [
+                    self.chat_seesion_id,
+                    once_message.chat_mode,
+                    once_message.get_user_conv().content,
+                    "",
+                    json.dumps(conversations, ensure_ascii=False),
+                ],
+            )
         cursor.commit()
         self.connect.commit()

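
Note that the chat_history table created above has five columns and the second INSERT lists all five, while the first INSERT names only four ("chat_mode summary" is missing a comma) against five placeholders, which DuckDB would reject at execution time. A self-contained sketch of the intended statements (the in-memory connection and sample values are illustrative; only the schema and column list come from this diff):

    import json
    import duckdb

    con = duckdb.connect(":memory:")  # illustrative; DuckdbHistoryMemory uses a file-backed DB
    con.execute(
        "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), "
        "summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
    )
    # Five named columns matched to five placeholders (note the comma after chat_mode).
    con.execute(
        "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages) VALUES (?,?,?,?,?)",
        ["b54ae5fe-1624-11ee-a271-b26789cc3e58", "chat_normal", "hello", "", json.dumps([])],
    )
    print(con.execute("SELECT conv_uid, summary FROM chat_history").fetchall())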

View File

@@ -3,7 +3,15 @@ import json
 import asyncio
 import time
 import os

-from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
+from fastapi import (
+    APIRouter,
+    Request,
+    Body,
+    status,
+    HTTPException,
+    Response,
+    BackgroundTasks,
+)
 from fastapi.responses import JSONResponse, HTMLResponse
 from fastapi.responses import StreamingResponse, FileResponse
@@ -12,7 +20,12 @@ from fastapi.exceptions import RequestValidationError
 from sse_starlette.sse import EventSourceResponse
 from typing import List

-from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
+from pilot.openapi.api_v1.api_view_model import (
+    Result,
+    ConversationVo,
+    MessageVo,
+    ChatSceneVo,
+)
 from pilot.configs.config import Config
 from pilot.openapi.knowledge.knowledge_service import KnowledgeService
 from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -36,6 +49,7 @@ model_semaphore = None
 global_counter = 0
 static_file_path = os.path.join(os.getcwd(), "server/static")

+
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
     for error in exc.errors():
@@ -82,6 +96,7 @@ def knowledge_list():
         params.update({space.name: space.name})
     return params

+
 @router.get("/")
 async def read_main():
     return FileResponse(f"{static_file_path}/test.html")
@@ -102,8 +117,6 @@ async def dialogue_list(response: Response, user_id: str = None):
         summary = item.get("summary")
         chat_mode = item.get("chat_mode")

-
-
         conv_vo: ConversationVo = ConversationVo(
             conv_uid=conv_uid,
             user_input=summary,
@@ -138,7 +151,6 @@ async def dialogue_scenes():
     return Result.succ(scene_vos)


-
 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
     chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
@@ -146,6 +158,7 @@ async def dialogue_new(
     conv_vo = __new_conversation(chat_mode, user_id)
     return Result.succ(conv_vo)

+
 @router.post("/v1/chat/mode/params/list", response_model=Result[dict])
 async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
     if ChatScene.ChatWithDbQA.value == chat_mode:
@@ -232,11 +245,19 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     }

     if not chat.prompt_template.stream_out:
-        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
-                                 background=background_tasks)
+        return StreamingResponse(
+            no_stream_generator(chat),
+            headers=headers,
+            media_type="text/event-stream",
+            background=background_tasks,
+        )
     else:
-        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
-                                 background=background_tasks)
+        return StreamingResponse(
+            stream_generator(chat),
+            headers=headers,
+            media_type="text/plain",
+            background=background_tasks,
+        )


 def release_model_semaphore():
@@ -254,14 +275,18 @@ async def stream_generator(chat):
     if not CFG.NEW_SERVER_MODE:
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"
                 await asyncio.sleep(0.1)
     else:
         for chunk in model_response:
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"

@@ -273,4 +298,6 @@ async def stream_generator(chat):


 def message2Vo(message: dict, order) -> MessageVo:
-    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
+    return MessageVo(
+        role=message["type"], context=message["data"]["content"], order=order
+    )
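
Because stream_generator above frames each chunk as "data:{msg}\n\n", a client can consume the response line by line. A rough client sketch (the mount path, port, and payload shape are assumptions; only the framing, the "\\n" escaping, and the ConversationVo field names visible in this diff come from the source):

    import requests

    # Hypothetical client; point the URL at wherever chat_completions is actually mounted.
    payload = {
        "conv_uid": "b54ae5fe-1624-11ee-a271-b26789cc3e58",
        "chat_mode": "chat_normal",
        "user_input": "hello",
    }
    with requests.post(
        "http://localhost:5000/api/v1/chat/completions", json=payload, stream=True
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data:"):
                # stream_generator escapes newlines as "\\n"; undo that for display.
                print(line[len("data:"):].replace("\\n", "\n"))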

View File

@@ -123,6 +123,7 @@ class KnowledgeDocumentDao:
             updated_space = session.merge(document)
             session.commit()
             return updated_space.id
+
     #
     # def delete_knowledge_document(self, document_id: int):
     #     cursor = self.conn.cursor()

View File

@@ -115,8 +115,13 @@ class KnowledgeService:
             space=space_name,
         )
         doc = knowledge_document_dao.get_knowledge_documents(query)[0]
-        if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
-            raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
+        if (
+            doc.status == SyncStatus.RUNNING.name
+            or doc.status == SyncStatus.FINISHED.name
+        ):
+            raise Exception(
+                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
+            )
         client = KnowledgeEmbedding(
             knowledge_source=doc.content,
             knowledge_type=doc.doc_type.upper(),

View File

@@ -107,7 +107,9 @@ class BaseChat(ABC):
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.current_message.start_date = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
         self.current_message.tokens = 0

         if self.prompt_template.template:

View File

@@ -65,10 +65,10 @@ class ChatDashboard(BaseChat):
             try:
                 datas = self.database.run(self.db_connect, chart_item.sql)
                 chart_data: ChartData = ChartData()
-                chart_data.chart_sql = chart_item['sql']
-                chart_data.chart_type = chart_item['showcase']
-                chart_data.chart_name = chart_item['title']
-                chart_data.chart_desc = chart_item['thoughts']
+                chart_data.chart_sql = chart_item["sql"]
+                chart_data.chart_type = chart_item["showcase"]
+                chart_data.chart_name = chart_item["title"]
+                chart_data.chart_desc = chart_item["thoughts"]
                 chart_data.column_name = datas[0]
                 chart_data.values = datas
             except Exception as e:

View File

@@ -1,5 +1,6 @@
 from pilot.prompts.example_base import ExampleSelector
 from pilot.common.schema import ExampleType
+
 ## Two examples are defined by default
 EXAMPLES = [
     {
@@ -34,4 +35,6 @@ EXAMPLES = [
     },
 ]

-sql_data_example = ExampleSelector(examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value)
+sql_data_example = ExampleSelector(
+    examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value
+)

View File

@@ -9,6 +9,8 @@ from pilot.configs.model_config import LOGDIR
 from pilot.configs.config import Config

 CFG = Config()
+
+
 class SqlAction(NamedTuple):
     sql: str
     thoughts: Dict
@@ -35,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):
         df = pd.DataFrame(data[1:], columns=data[0])
         if CFG.NEW_SERVER_MODE:
             html = df.to_html(index=False, escape=False, sparsify=False)
-            html = ''.join(html.split())
+            html = "".join(html.split())
         else:
             table_style = """<style>
             table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
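
The NEW_SERVER_MODE branch above renders the result DataFrame to HTML and then strips every whitespace character so the table ships as a single line. A tiny standalone sketch of that transformation (the sample DataFrame is illustrative, not from the project):

    import pandas as pd

    df = pd.DataFrame([["user_1", 3], ["user_2", 5]], columns=["name", "count"])  # illustrative data
    html = df.to_html(index=False, escape=False, sparsify=False)
    html = "".join(html.split())  # removes all whitespace, including any inside cell text
    print(html)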

View File

@@ -46,6 +46,6 @@ prompt = PromptTemplate(
     output_parser=DbChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
-    example_selector=sql_data_example
+    example_selector=sql_data_example,
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -8,7 +8,9 @@ from pilot.common.schema import SeparatorStyle

 CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an assistant that answers user specialized database questions. """
+PROMPT_SCENE_DEFINE = (
+    """You are an assistant that answers user specialized database questions. """
+)

 # PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
 # {table_info}

View File

@@ -12,7 +12,6 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


-
 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text


View File

@@ -42,9 +42,10 @@ def signal_handler(sig, frame):


 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
     )
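
swagger_monkey_patch above only builds the docs page with Swagger UI assets served from the bootcdn mirror; how it is activated is not shown in this hunk. The usual FastAPI pattern, offered here as an assumption about the surrounding server setup rather than as this file's code, is to reassign the helper that the built-in /docs route calls before creating the app:

    from fastapi import FastAPI, applications

    # Assumed wiring (not visible in this diff): make the default /docs route go
    # through swagger_monkey_patch so the UI assets load from the mirrored CDN.
    applications.get_swagger_ui_html = swagger_monkey_patch

    app = FastAPI()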