Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-19 00:37:34 +00:00)

Commit 87b1159ff6
Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

.gitignore (vendored)
@@ -7,6 +7,7 @@ __pycache__/
 *.so

 message/
+static/

 .env
 .idea
@@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""

+        self.NEW_SERVER_MODE = False
+
         # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
         self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
@@ -8,6 +8,7 @@ duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
 if __name__ == "__main__":
     if os.path.isfile(duckdb_path):
         cursor = duckdb.connect(duckdb_path).cursor()
+        # cursor.execute("SELECT * FROM chat_history limit 20")
         cursor.execute("SELECT * FROM chat_history limit 20")
         data = cursor.fetchall()
         print(data)
@@ -8,18 +8,20 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.scene.message import (
     OnceConversation,
     conversation_from_dict,
+    _conversation_to_dic,
     conversations_to_dict,
 )
 from pilot.common.formatting import MyEncoder

 default_db_path = os.path.join(os.getcwd(), "message")
 duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = "chat_history"
+table_name = 'chat_history'

 CFG = Config()


 class DuckdbHistoryMemory(BaseChatHistoryMemory):
     def __init__(self, chat_session_id: str):
         self.chat_seesion_id = chat_session_id
         os.makedirs(default_db_path, exist_ok=True)
@@ -27,26 +29,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         self.__init_chat_history_tables()

     def __init_chat_history_tables(self):
         # Check whether the table already exists
-        result = self.connect.execute(
-            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
-        ).fetchall()
+        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+                                      [table_name]).fetchall()

         if not result:
             # If the table does not exist, create it
             self.connect.execute(
-                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
-            )
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")

     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()
         cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
-        return cursor.fetchone()
+        content = cursor.fetchone()
+        if content:
+            return content[0]
+        else:
+            return None

     def messages(self) -> List[OnceConversation]:
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         if context:
-            conversations: List[OnceConversation] = json.loads(context[0])
+            conversations: List[OnceConversation] = json.loads(context)
             return conversations
         return []

@@ -54,50 +58,27 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         conversations: List[OnceConversation] = []
         if context:
-            conversations = json.load(context)
-        conversations.append(once_message)
+            conversations = json.loads(context)
+        conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
-            cursor.execute(
-                "UPDATE chat_history set messages=? where conv_uid=?",
-                [
-                    json.dumps(
-                        conversations_to_dict(conversations),
-                        ensure_ascii=False,
-                        indent=4,
-                    ),
-                    self.chat_seesion_id,
-                ],
-            )
+            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
+                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
         else:
-            cursor.execute(
-                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                [
-                    self.chat_seesion_id,
-                    "",
-                    json.dumps(
-                        conversations_to_dict(conversations),
-                        ensure_ascii=False,
-                        indent=4,
-                    ),
-                ],
-            )
+            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
         cursor.commit()
         self.connect.commit()

     def clear(self) -> None:
         cursor = self.connect.cursor()
-        cursor.execute(
-            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
         cursor.commit()
         self.connect.commit()

     def delete(self) -> bool:
         cursor = self.connect.cursor()
-        cursor.execute(
-            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
         cursor.commit()
         return True

@@ -106,9 +87,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
             if user_name:
-                cursor.execute(
-                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
-                )
+                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
             else:
                 cursor.execute("SELECT * FROM chat_history limit 20")
             # Get the column names of the query result
@@ -124,11 +103,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):

         return []

-    def get_messages(self) -> List[OnceConversation]:
+    def get_messages(self)-> List[OnceConversation]:
         cursor = self.connect.cursor()
-        cursor.execute(
-            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
         context = cursor.fetchone()
         if context:
             return json.loads(context[0])
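The hunks above make DuckdbHistoryMemory keep one row per conversation, with the message list serialized to JSON in the messages column, and make __get_messages_by_conv_uid return that JSON string rather than the raw row tuple. A minimal standalone sketch of the same storage scheme, using duckdb directly; the database path and session id here are made-up illustration values, not anything from the repository:

# Standalone sketch of the chat-history storage scheme above (assumes duckdb is installed).
import json
import duckdb

conn = duckdb.connect("chat_history.db")  # illustrative path
conn.execute(
    "CREATE TABLE IF NOT EXISTS chat_history "
    "(conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
)

conv_uid = "demo-session-1"  # illustrative session id
row = conn.execute("SELECT messages FROM chat_history WHERE conv_uid=?", [conv_uid]).fetchone()
conversations = json.loads(row[0]) if row else []

# Append one serialized conversation round, then upsert the JSON blob.
conversations.append({"chat_mode": "chat_normal", "messages": []})
if row:
    conn.execute(
        "UPDATE chat_history SET messages=? WHERE conv_uid=?",
        [json.dumps(conversations, ensure_ascii=False), conv_uid],
    )
else:
    conn.execute(
        "INSERT INTO chat_history(conv_uid, user_name, messages) VALUES (?,?,?)",
        [conv_uid, "", json.dumps(conversations, ensure_ascii=False)],
    )
conn.commit()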
@@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         elif "ai:" in message:
             history.append(
                 {
-                    "role": "ai",
+                    "role": "assistant",
                     "content": message.split("ai:")[1],
                 }
             )
@@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     for m in temp_his:
         if m["role"] == "user":
             last_user_input = m
+            break
     if last_user_input:
         history.remove(last_user_input)
         history.append(last_user_input)
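These two hunks switch the assistant role label from "ai" to "assistant" and stop at the first matching user message when looking for the last user input. A simplified, self-contained sketch of that pattern (not the actual proxyllm code; the function name and the assumption that history is a list of role/content dicts scanned from the end are illustrative):

# Simplified sketch: map "human:"/"ai:" prefixed lines to OpenAI-style messages and
# move the most recent user turn to the end, mirroring the remove/append in the hunk above.
def build_history(raw_messages):
    history = []
    for message in raw_messages:
        if "human:" in message:
            history.append({"role": "user", "content": message.split("human:")[1]})
        elif "ai:" in message:
            history.append({"role": "assistant", "content": message.split("ai:")[1]})

    last_user_input = None
    for m in reversed(history):      # scan from the end, stop at the first user turn
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)
    return history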
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks

 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -12,12 +12,7 @@ from fastapi.responses import JSONResponse
 from sse_starlette.sse import EventSourceResponse
 from typing import List

-from pilot.server.api_v1.api_view_model import (
-    Result,
-    ConversationVo,
-    MessageVo,
-    ChatSceneVo,
-)
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
 from pilot.configs.config import Config
 from pilot.openapi.knowledge.knowledge_service import KnowledgeService
 from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -37,6 +32,9 @@ CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
 knowledge_service = KnowledgeService()

+model_semaphore = None
+global_counter = 0
+

 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
@@ -76,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None):
         )
         dialogues.append(conv_vo)

-    return Result[ConversationVo].succ(dialogues)
+    return Result[ConversationVo].succ(dialogues[-10:][::-1])


 @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
 async def dialogue_scenes():
     scene_vos: List[ChatSceneVo] = []
     new_modes: List[ChatScene] = [
-        ChatScene.ChatDb,
-        ChatScene.ChatData,
+        ChatScene.ChatWithDbExecute,
+        ChatScene.ChatWithDbQA,
         ChatScene.ChatDashboard,
         ChatScene.ChatKnowledge,
         ChatScene.ChatExecution,
@@ -105,7 +103,7 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
     chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -139,9 +137,9 @@ def knowledge_list():

 @router.post("/v1/chat/mode/params/list", response_model=Result[dict])
 async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
-    if ChatScene.ChatDb.value == chat_mode:
+    if ChatScene.ChatWithDbQA.value == chat_mode:
         return Result.succ(get_db_list())
-    elif ChatScene.ChatData.value == chat_mode:
+    elif ChatScene.ChatWithDbExecute.value == chat_mode:
         return Result.succ(get_db_list())
     elif ChatScene.ChatDashboard.value == chat_mode:
         return Result.succ(get_db_list())
@@ -179,6 +177,16 @@ async def dialogue_history_messages(con_uid: str):
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    if not dialogue.chat_mode:
+        dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
     if not ChatScene.is_valid_mode(dialogue.chat_mode):
         raise StopAsyncIteration(
@@ -190,99 +198,65 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         "user_input": dialogue.user_input,
     }

-    if ChatScene.ChatDb == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatData == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatDashboard == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatExecution == dialogue.chat_mode:
-        chat_param.update("plugin_selector", dialogue.select_param)
-    elif ChatScene.ChatKnowledge == dialogue.chat_mode:
-        chat_param.update("knowledge_space", dialogue.select_param)
+    if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
+        chat_param.update({"plugin_selector": dialogue.select_param})
+    elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
+        chat_param.update({"knowledge_space": dialogue.select_param})

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }
+
     if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
     else:
-        # generator = stream_generator(chat)
-        # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
-        # return result
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)


-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+def release_model_semaphore():
+    model_semaphore.release()


-def stream_generator(chat):
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"
+
+
+async def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                chunk, chat.skip_echo_len
-            )
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
     chat.memory.append(chat.current_message)


-# def stream_response(chat):
-#     logger.info("stream out start!")
-#     api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
-#     return api_response


 def message2Vo(message: dict, order) -> MessageVo:
-    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
-    return MessageVo(
-        role=message["type"], context=message["data"]["content"], order=order
-    )
-
-
-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get("/v1/db/types", response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get("/v1/db/list", response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get("/v1/knowledge/list")
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/add")
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/delete")
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/types")
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/detail")
-async def knowledge_detail():
-    return ["test1", "test2"]
+    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
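The chat_completions changes above combine three pieces: an asyncio.Semaphore capped at CFG.LIMIT_MODEL_CONCURRENCY, a StreamingResponse whose generator yields "data:..." lines, and a BackgroundTasks entry that releases the semaphore only after the response has finished streaming. A stripped-down sketch of the same pattern; the route name, limit value, and generator contents are invented for illustration and are not the project's actual endpoint:

# Stripped-down sketch of the concurrency-limit + streaming pattern above.
import asyncio
from fastapi import FastAPI, BackgroundTasks
from fastapi.responses import StreamingResponse

app = FastAPI()
model_semaphore = None
LIMIT_MODEL_CONCURRENCY = 5  # assumption: stands in for CFG.LIMIT_MODEL_CONCURRENCY


def release_model_semaphore():
    model_semaphore.release()


async def stream_generator():
    for token in ["hello", " ", "world"]:   # placeholder chunks
        yield f"data:{token}\n\n"
        await asyncio.sleep(0.1)


@app.post("/demo/completions")
async def demo_completions():
    global model_semaphore
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    # Runs after the streamed response has been fully sent, freeing the slot.
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return StreamingResponse(
        stream_generator(),
        headers=headers,
        media_type="text/event-stream",
        background=background_tasks,
    )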
@@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
         return code

     def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        if b"\0" in chunk:
+            chunk = chunk.replace(b"\0", b"")
         data = json.loads(chunk.decode())

         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
@@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
     def parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
-        respObj = json.loads(text)
-        xx = respObj["response"]
-        xx = xx.strip(b"\x00".decode())
-        respObj_ex = json.loads(xx)
+        text = text.strip(b"\x00".decode())
+        respObj_ex = json.loads(text)
         if respObj_ex["error_code"] == 0:
             all_text = respObj_ex["text"]
             ### Parse the returned text and extract the AI reply
@@ -123,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
+        for j, c in enumerate(s[i + 1:], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -131,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the matching final '}' was found
-        return s[i : j + 1]
+        return s[i: j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -148,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -159,9 +158,9 @@ class BaseOutputParser(ABC):
         cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
             .replace("\n", " ")
             .replace("\\n", " ")
             .replace("\\", " ")
         )
         return cleaned_output
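__extract_json above finds the first "{" and walks forward while tracking brace depth until the object closes, so any prose the model wraps around its JSON payload is discarded. A small standalone copy of that balanced-brace scan with a worked example (the input string is made up):

# Standalone copy of the balanced-brace extraction idea used by __extract_json.
def extract_json(s: str) -> str:
    i = s.index("{")
    count = 1  # number of '{' not yet closed
    for j, c in enumerate(s[i + 1:], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # the matching '}' was found
    return s[i: j + 1]


print(extract_json('reply: {"thoughts": "ok", "command": {"name": "x"}} -- end'))
# prints: {"thoughts": "ok", "command": {"name": "x"}}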
@@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType


 class ExampleSelector(BaseModel, ABC):
-    examples: List[List]
+    examples_record: List[List]
     use_example: bool = False
     type: str = ExampleType.ONE_SHOT.value

@@ -22,7 +22,7 @@ class ExampleSelector(BaseModel, ABC):
         Returns: example text
         """
         if self.use_example:
-            need_use = self.examples[:count]
+            need_use = self.examples_record[:count]
             return need_use
         return None

@@ -33,7 +33,7 @@ class ExampleSelector(BaseModel, ABC):

         """
         if self.use_example:
-            need_use = self.examples[:1]
+            need_use = self.examples_record[:1]
             return need_use

         return None
@@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC):
     output_parser: BaseOutputParser = None
     """"""
     sep: str = SeparatorStyle.SINGLE.value
-    example: ExampleSelector = None
+
+    example_selector: ExampleSelector = None
+
+    need_historical_messages: bool = False

     class Config:
         """Configuration for this pydantic object."""
@@ -20,8 +20,8 @@ class ChatScene(Enum):
     ChatNormal = "chat_normal"
     ChatDashboard = "chat_dashboard"
     ChatKnowledge = "chat_knowledge"
-    ChatDb = "chat_db"
-    ChatData = "chat_data"
+    # ChatDb = "chat_db"
+    # ChatData= "chat_data"

     @staticmethod
     def is_valid_mode(mode):
@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
     ViewMessage,
 )
 from pilot.configs.config import Config
+from pilot.server.llmserver import worker

 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -51,7 +52,7 @@ class BaseChat(ABC):
     temperature: float = 0.6
     max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
-    chat_retention_rounds: int = 2
+    chat_retention_rounds: int = 1

     class Config:
         """Configuration for this pydantic object."""
@@ -59,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
         self,
         chat_mode,
         chat_session_id,
         current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -75,11 +76,9 @@ class BaseChat(ABC):
         self.prompt_template: PromptTemplate = CFG.prompt_templates[
             self.chat_mode.value
         ]
-        self.history_message: List[OnceConversation] = []
+        self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(chat_mode.value)
         self.current_tokens_used: int = 0
-        ### load chat_session_id's chat historys
-        self._load_history(self.chat_session_id)

     class Config:
         """Configuration for this pydantic object."""
@@ -95,7 +94,6 @@ class BaseChat(ABC):
     def generate_input_values(self):
         pass

-    @abstractmethod
     def do_action(self, prompt_response):
         return prompt_response

@@ -104,23 +102,12 @@ class BaseChat(ABC):
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime(
-            "%Y-%m-%d %H:%M:%S"
-        )
+        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         # TODO
         self.current_message.tokens = 0
-        current_prompt = None

         if self.prompt_template.template:
             current_prompt = self.prompt_template.format(**input_values)
-
-        ### Build the current conversation: should the prompt be built as in the first round, and should a database switch be considered?
-        if self.history_message:
-            ## TODO: for scenes that carry chat history, decide how to handle a database switch
-            logger.info(
-                f"There are already {len(self.history_message)} rounds of conversations!"
-            )
-        if current_prompt:
             self.current_message.add_system_message(current_prompt)

         payload = {
@@ -140,31 +127,24 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            show_info = ""
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=120,
-            )
-            return response
-            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
-
-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-
-            self.current_message.add_ai_message(show_info)
-
+            if not CFG.NEW_SERVER_MODE:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=120,
+                )
+                return response
+            else:
+                return worker.generate_stream_gate(payload)
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))
             self.current_message.add_view_message(
                 f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
             )
-        ### 对话记录存储
+        ### store current conversation
         self.memory.append(self.current_message)

     def nostream_call(self):
@@ -172,42 +152,27 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            ### call the non-streaming model service interface
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate"),
-                headers=headers,
-                json=payload,
-                timeout=120,
-            )
+            rsp_str = ""
+            if not CFG.NEW_SERVER_MODE:
+                rsp_str = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+            else:
+                ###TODO  no stream mode need independent
+                output = worker.generate_stream_gate(payload)
+                for rsp in output:
+                    rsp_str = str(rsp, "utf-8")
+                    print("[TEST: output]:", rsp_str)

             ### output parse
-            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(
-                    response, self.prompt_template.sep
-                )
-            )
-
-            # ### MOCK
-            # ai_response_text = """{
-            #   "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
-            #   "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
-            #   "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
-            #   "command": {
-            #     "name": "histogram-executor",
-            #     "args": {
-            #       "title": "订单城市分布柱状图",
-            #       "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-            #     }
-            #   }
-            # }"""
-
+            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
+                                                                                            self.prompt_template.sep)
+            ### model result deal
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = (
-                self.prompt_template.output_parser.parse_prompt_response(
-                    ai_response_text
-                )
-            )
+            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)

             result = self.do_action(prompt_define_response)

             if hasattr(prompt_define_response, "thoughts"):
@@ -235,7 +200,7 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
             )
-        ### 对话记录存储
+        ### store dialogue
         self.memory.append(self.current_message)
         return self.current_ai_response()

@@ -247,67 +212,99 @@ class BaseChat(ABC):

     def generate_llm_text(self) -> str:
         text = ""
+        ### Load scene setting or character definition
         if self.prompt_template.template_define:
-            text = self.prompt_template.template_define + self.prompt_template.sep
+            text += self.prompt_template.template_define + self.prompt_template.sep
+        ### Load prompt
+        text += self.__load_system_message()

-        ### Handle history messages
-        if len(self.history_message) > self.chat_retention_rounds:
-            ### Merge the first round and the last n rounds of history into the conversation record; view messages shown to the user are filtered out of the context prompt
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
-                    text += (
-                        first_message.type
-                        + ":"
-                        + first_message.content
-                        + self.prompt_template.sep
-                    )
-
-            index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += (
-                        last_message.type
-                        + ":"
-                        + last_message.content
-                        + self.prompt_template.sep
-                    )
-
-        else:
-            ### Concatenate the history directly
-            for conversation in self.history_message:
-                for message in conversation.messages:
-                    if not isinstance(message, ViewMessage):
-                        text += (
-                            message.type
-                            + ":"
-                            + message.content
-                            + self.prompt_template.sep
-                        )
-        ### current conversation
-
-        for now_message in self.current_message.messages:
-            text += (
-                now_message.type + ":" + now_message.content + self.prompt_template.sep
-            )
-
+        ### Load examples
+        text += self.__load_example_messages()
+
+        ### Load History
+        text += self.__load_histroy_messages()
+
+        ### Load User Input
+        text += self.__load_user_message()
         return text

-    # Kept temporarily for frontend compatibility
+    def __load_system_message(self):
+        system_convs = self.current_message.get_system_conv()
+        system_text = ""
+        for system_conv in system_convs:
+            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+        return system_text
+
+    def __load_user_message(self):
+        user_conv = self.current_message.get_user_conv()
+        if user_conv:
+            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+        else:
+            raise ValueError("Hi! What do you want to talk about?")
+
+    def __load_example_messages(self):
+        example_text = ""
+        if self.prompt_template.example_selector:
+            for round_conv in self.prompt_template.example_selector.examples():
+                for round_message in round_conv['messages']:
+                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                        example_text += (
+                            round_message['type']
+                            + ":"
+                            + round_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+        return example_text
+
+    def __load_histroy_messages(self):
+        history_text = ""
+        if self.prompt_template.need_historical_messages:
+            if self.history_message:
+                logger.info(
+                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
+                )
+            if len(self.history_message) > self.chat_retention_rounds:
+                for first_message in self.history_message[0]['messages']:
+                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                        history_text += (
+                            first_message['type']
+                            + ":"
+                            + first_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+
+                index = self.chat_retention_rounds - 1
+                for round_conv in self.history_message[-index:]:
+                    for round_message in round_conv['messages']:
+                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                round_message['type']
+                                + ":"
+                                + round_message['data']['content']
+                                + self.prompt_template.sep
+                            )
+            else:
+                ### use all history
+                for conversation in self.history_message:
+                    for message in conversation['messages']:
+                        ### history messages do not carry prompt and view info
+                        if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                message['type']
+                                + ":"
+                                + message['data']['content']
+                                + self.prompt_template.sep
+                            )
+
+        return history_text
+
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
             if message.type == "view":
                 return message.content
         return None

-    def _load_history(self, session_id: str) -> List[OnceConversation]:
-        """
-        load chat history by session_id
-        Args:
-            session_id:
-        Returns:
-        """
-        return self.memory.messages()
-
     def generate(self, p) -> str:
         """
         generate context for LLM input
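After this change generate_llm_text builds the prompt in a fixed order from helper methods: scene definition, system messages, few-shot examples, retained history (only when need_historical_messages is set), then the user input, each segment terminated by prompt_template.sep. A compact sketch of that assembly order using placeholder data rather than the real PromptTemplate and OnceConversation objects:

# Compact sketch of the prompt assembly order above; segment contents are placeholders.
sep = "\n"  # assumption: stands in for prompt_template.sep

def assemble_prompt(template_define, system_msgs, example_msgs, history_msgs, user_msg,
                    need_historical_messages=False):
    text = ""
    if template_define:                       # scene setting / character definition
        text += template_define + sep
    for m in system_msgs:                     # system messages (the formatted prompt)
        text += "system:" + m + sep
    for m in example_msgs:                    # few-shot examples, if an example_selector is set
        text += m["type"] + ":" + m["content"] + sep
    if need_historical_messages:              # retained history rounds
        for m in history_msgs:
            text += m["type"] + ":" + m["content"] + sep
    text += "human:" + user_msg + sep         # current user input
    return text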
@@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle

 CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
+PROMPT_SCENE_DEFINE = None


 _DEFAULT_TEMPLATE = """
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector

 ## Two examples are defined by default
 EXAMPLES = [
-    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
-    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]

-example = ExampleSelector(examples=EXAMPLES, use_example=True)
+plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
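Since the field was renamed, callers now construct the selector with examples_record instead of examples. A brief usage sketch under the assumption (based on the earlier ExampleSelector hunks and its use in base_chat) that examples() returns the first recorded rounds when use_example is True:

# Usage sketch; the printed result shape is an assumption, not confirmed by the diff.
from pilot.prompts.example_base import ExampleSelector

EXAMPLES = [
    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
]

plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
print(plugin_example.examples())  # expected: the recorded example rounds, since use_example=True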
@@ -5,14 +5,12 @@ from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle, ExampleType

 from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
-from pilot.scene.chat_execution.example import example
+from pilot.scene.chat_execution.example import plugin_example

 CFG = Config()

-# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
 PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."


 _DEFAULT_TEMPLATE = """
 Goals:
 {input}
@@ -51,7 +49,7 @@ prompt = PromptTemplate(
     output_parser=PluginChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
     ),
-    example=example,
+    example_selector=plugin_example,
 )

 CFG.prompt_templates.update({prompt.template_scene: prompt})
@@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle

 from pilot.scene.chat_normal.out_parser import NormalChatOutputParser

-PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
-The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+PROMPT_SCENE_DEFINE = None

 CFG = Config()

@@ -85,16 +85,22 @@ class OnceConversation:
         self.messages.clear()
         self.session_id = None

-    def get_user_message(self):
-        for once in self.messages:
-            if isinstance(once, HumanMessage):
-                return once.content
-        return ""
+    def get_user_conv(self):
+        for message in self.messages:
+            if isinstance(message, HumanMessage):
+                return message
+        return None
+
+    def get_system_conv(self):
+        system_convs = []
+        for message in self.messages:
+            if isinstance(message, SystemMessage):
+                system_convs.append(message)
+        return system_convs


 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if once.start_date:
+    if hasattr(once, 'start_date') and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:
pilot/server/dbgpt_server.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import traceback
+import os
+import shutil
+import argparse
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.utils import build_logger
+
+from pilot.server.webserver_base import server_init
+
+from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+static_file_path = os.path.join(os.getcwd(), "server/static")
+
+CFG = Config()
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+origins = ["*"]
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+app.mount("/static", StaticFiles(directory=static_file_path), name="static")
+app.add_route("/test", "static/test.html")
+
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+
+    # old version server config
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--share", default=False, action="store_true")
+
+    # init server config
+    args = parser.parse_args()
+    server_init(args)
+    CFG.NEW_SERVER_MODE = True
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5000)
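With this new server mounting api_v1, /v1/chat/completions streams "data:"-prefixed chunks (see stream_generator above and the test.html page below). A rough client sketch using requests, assuming the server is running on localhost:5000 as started by this file; the conv_uid and user_input values are illustrative, and the field names follow the ConversationVo usage in api_v1:

# Rough streaming client sketch for the new server (values are illustrative).
import requests

payload = {
    "conv_uid": "test-uid-1",
    "chat_mode": "chat_normal",
    "user_input": "hello",
}
resp = requests.post(
    "http://localhost:5000/v1/chat/completions", json=payload, stream=True, timeout=120
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        print(line[len("data:"):])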
@@ -33,7 +33,7 @@ class ModelWorker:
             model_path = model_path[:-1]
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
+        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
         self.ml = ModelLoader(model_path=model_path)
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
@@ -41,11 +41,11 @@ class ModelWorker:

         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings

@@ -55,29 +55,32 @@ class ModelWorker:
         self.llm_chat_adapter = get_llm_chat_adapter(model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

+    def start_check(self):
+        print("LLM Model Loading Success!")
+
     def get_queue_length(self):
         if (
             model_semaphore is None
             or model_semaphore._value is None
             or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
                 CFG.LIMIT_MODEL_CONCURRENCY
                 - model_semaphore._value
                 + len(model_semaphore._waiters)
             )

     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
                 # and opening it may affect the frontend output.
-                # print("output: ", output)
+                print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,
@@ -178,10 +181,9 @@ def generate(prompt_request: PromptRequest):
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
-        print("[TEST: output]:", rsp_str)
         response.append(rsp_str)

-    return {"response": rsp_str}
+    return rsp_str


 @app.post("/embedding")
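generate_stream_gate above yields JSON payloads of the form {"text": ..., "error_code": ...} that the streaming HTTP layer separates with b"\0", which is exactly the delimiter parse_model_stream_resp_ex now strips before json.loads. A small self-contained sketch of producing and consuming such a stream, with invented chunk contents:

# Sketch of the b"\0"-delimited JSON chunk format passed between the model worker
# and the output parser; the chunk texts are invented for illustration.
import json

def fake_stream():
    for text in ["Hel", "Hello", "Hello world"]:
        ret = {"text": text, "error_code": 0}
        yield json.dumps(ret).encode() + b"\0"

for chunk in fake_stream():
    if b"\0" in chunk:
        chunk = chunk.replace(b"\0", b"")
    data = json.loads(chunk.decode())
    if data["error_code"] == 0:
        print(data["text"])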
pilot/server/static/test.html (new file, 19 lines)
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="UTF-8">
+    <title>Streaming Demo</title>
+    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
+</head>
+<body>
+    <div id="output"></div>
+    <script>
+        $(document).ready(function() {
+            var source = new EventSource("/v1/chat/completions");
+            source.onmessage = function(event) {
+                $("#output").append(event.data);
+            }
+        });
+    </script>
+</body>
+</html>
@@ -60,7 +60,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles

-from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler

 # Load plugins
 CFG = Config()
@@ -39,6 +39,8 @@ def server_init(args):
     # init config
     cfg = Config()

+    from pilot.server.llmserver import worker
+    worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
     async_db_summery()