Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

aries_ckt 2023-06-29 13:35:03 +08:00
commit 87b1159ff6
21 changed files with 375 additions and 315 deletions

.gitignore
View File

@@ -7,6 +7,7 @@ __pycache__/
*.so
message/
+static/
.env
.idea

View File

@@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
    def __init__(self) -> None:
        """Initialize the Config class"""
+        self.NEW_SERVER_MODE = False
        # Gradio language version: en, zh
        self.LANGUAGE = os.getenv("LANGUAGE", "en")
        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

View File

@@ -8,6 +8,7 @@ duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
if __name__ == "__main__":
    if os.path.isfile(duckdb_path):
        cursor = duckdb.connect(duckdb_path).cursor()
+        # cursor.execute("SELECT * FROM chat_history limit 20")
        cursor.execute("SELECT * FROM chat_history limit 20")
        data = cursor.fetchall()
        print(data)

View File

@@ -8,18 +8,20 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
    OnceConversation,
    conversation_from_dict,
+    _conversation_to_dic,
    conversations_to_dict,
)
from pilot.common.formatting import MyEncoder

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = "chat_history"
+table_name = 'chat_history'

CFG = Config()

class DuckdbHistoryMemory(BaseChatHistoryMemory):
    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
@@ -27,26 +29,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        self.__init_chat_history_tables()

    def __init_chat_history_tables(self):
        # check whether the table already exists
-        result = self.connect.execute(
-            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
-        ).fetchall()
+        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+                                      [table_name]).fetchall()

        if not result:
            # create the table if it does not exist
-            self.connect.execute(
-                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
-            )
+            self.connect.execute(
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")

    def __get_messages_by_conv_uid(self, conv_uid: str):
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
-        return cursor.fetchone()
+        content = cursor.fetchone()
+        if content:
+            return content[0]
+        else:
+            return None

    def messages(self) -> List[OnceConversation]:
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        if context:
-            conversations: List[OnceConversation] = json.loads(context[0])
+            conversations: List[OnceConversation] = json.loads(context)
            return conversations
        return []
@@ -54,50 +58,27 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        conversations: List[OnceConversation] = []
        if context:
-            conversations = json.load(context)
-        conversations.append(once_message)
+            conversations = json.loads(context)
+        conversations.append(_conversation_to_dic(once_message))
        cursor = self.connect.cursor()
        if context:
-            cursor.execute(
-                "UPDATE chat_history set messages=? where conv_uid=?",
-                [
-                    json.dumps(
-                        conversations_to_dict(conversations),
-                        ensure_ascii=False,
-                        indent=4,
-                    ),
-                    self.chat_seesion_id,
-                ],
-            )
+            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
+                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
        else:
-            cursor.execute(
-                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                [
-                    self.chat_seesion_id,
-                    "",
-                    json.dumps(
-                        conversations_to_dict(conversations),
-                        ensure_ascii=False,
-                        indent=4,
-                    ),
-                ],
-            )
+            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        cursor = self.connect.cursor()
-        cursor.execute(
-            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        self.connect.commit()

    def delete(self) -> bool:
        cursor = self.connect.cursor()
-        cursor.execute(
-            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        cursor.commit()
        return True
@@ -106,9 +87,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
-                cursor.execute(
-                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
-                )
+                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
            else:
                cursor.execute("SELECT * FROM chat_history limit 20")
            # get the column names of the query result
@@ -124,11 +103,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        return []

-    def get_messages(self) -> List[OnceConversation]:
+    def get_messages(self)-> List[OnceConversation]:
        cursor = self.connect.cursor()
-        cursor.execute(
-            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
-        )
+        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
        context = cursor.fetchone()
        if context:
            return json.loads(context[0])
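The storage model above keeps one row per conversation id in `chat_history`, with the entire message list serialized into the `messages` TEXT column as JSON. A minimal sketch of that round trip, assuming an illustrative database path and payload shape (not taken from the repository):

```python
import json
import duckdb

# Mirrors the schema created by __init_chat_history_tables above.
con = duckdb.connect("/tmp/chat_history_demo.db")
con.execute(
    "CREATE TABLE IF NOT EXISTS chat_history "
    "(conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
)

conv_uid = "demo-conversation-1"
conversations = [
    {"chat_mode": "chat_normal",
     "messages": [{"type": "human", "data": {"content": "hi"}}]}
]

# The whole conversation history lives in a single JSON column.
con.execute(
    "INSERT INTO chat_history(conv_uid, user_name, messages) VALUES (?,?,?)",
    [conv_uid, "", json.dumps(conversations, ensure_ascii=False)],
)

row = con.execute(
    "SELECT messages FROM chat_history WHERE conv_uid=?", [conv_uid]
).fetchone()
print(json.loads(row[0]))
```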

View File

@@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
        elif "ai:" in message:
            history.append(
                {
-                    "role": "ai",
+                    "role": "assistant",
                    "content": message.split("ai:")[1],
                }
            )
@@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
+            break

    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

View File

@@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@@ -12,12 +12,7 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List

-from pilot.server.api_v1.api_view_model import (
-    Result,
-    ConversationVo,
-    MessageVo,
-    ChatSceneVo,
-)
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -37,6 +32,9 @@ CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
knowledge_service = KnowledgeService()

+model_semaphore = None
+global_counter = 0
+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    message = ""
@@ -76,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None):
        )
        dialogues.append(conv_vo)

-    return Result[ConversationVo].succ(dialogues)
+    return Result[ConversationVo].succ(dialogues[-10:][::-1])


@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    scene_vos: List[ChatSceneVo] = []
    new_modes: List[ChatScene] = [
-        ChatScene.ChatDb,
-        ChatScene.ChatData,
+        ChatScene.ChatWithDbExecute,
+        ChatScene.ChatWithDbQA,
        ChatScene.ChatDashboard,
        ChatScene.ChatKnowledge,
        ChatScene.ChatExecution,
@@ -105,7 +103,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
    unique_id = uuid.uuid1()
    return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -139,9 +137,9 @@ def knowledge_list():
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
-    if ChatScene.ChatDb.value == chat_mode:
+    if ChatScene.ChatWithDbQA.value == chat_mode:
        return Result.succ(get_db_list())
-    elif ChatScene.ChatData.value == chat_mode:
+    elif ChatScene.ChatWithDbExecute.value == chat_mode:
        return Result.succ(get_db_list())
    elif ChatScene.ChatDashboard.value == chat_mode:
        return Result.succ(get_db_list())
@@ -179,6 +177,16 @@ async def dialogue_history_messages(con_uid: str):
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    if not dialogue.chat_mode:
+        dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(
@@ -190,99 +198,65 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        "user_input": dialogue.user_input,
    }

-    if ChatScene.ChatDb == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatData == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatDashboard == dialogue.chat_mode:
-        chat_param.update("db_name", dialogue.select_param)
-    elif ChatScene.ChatExecution == dialogue.chat_mode:
-        chat_param.update("plugin_selector", dialogue.select_param)
-    elif ChatScene.ChatKnowledge == dialogue.chat_mode:
-        chat_param.update("knowledge_space", dialogue.select_param)
+    if ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatDashboard.value == dialogue.chat_mode:
+        chat_param.update({"db_name": dialogue.select_param})
+    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
+        chat_param.update({"plugin_selector": dialogue.select_param})
+    elif ChatScene.ChatKnowledge.value == dialogue.chat_mode:
+        chat_param.update({"knowledge_space": dialogue.select_param})

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }
    if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
    else:
-        # generator = stream_generator(chat)
-        # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain'))
-        # return result
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)


-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+def release_model_semaphore():
+    model_semaphore.release()


-def stream_generator(chat):
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"


+async def stream_generator(chat):
    model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                chunk, chat.skip_echo_len
-            )
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)

    chat.memory.append(chat.current_message)


-# def stream_response(chat):
-#     logger.info("stream out start!")
-#     api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
-#     return api_response


def message2Vo(message: dict, order) -> MessageVo:
-    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
-    return MessageVo(
-        role=message["type"], context=message["data"]["content"], order=order
-    )
+    return MessageVo(role=message['type'], context=message['data']['content'], order=order)


-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get("/v1/db/types", response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get("/v1/db/list", response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get("/v1/knowledge/list")
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/add")
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post("/v1/knowledge/delete")
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/types")
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get("/v1/knowledge/detail")
-async def knowledge_detail():
-    return ["test1", "test2"]
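The rewritten `/v1/chat/completions` endpoint above throttles generation with a module-level `asyncio.Semaphore` and releases it from a `BackgroundTasks` callback, so the slot is only freed after the `StreamingResponse` has finished sending. A self-contained sketch of just that pattern; the route name, concurrency limit, and generator body are illustrative:

```python
import asyncio
from typing import Optional

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
LIMIT_MODEL_CONCURRENCY = 5  # stand-in for CFG.LIMIT_MODEL_CONCURRENCY
model_semaphore: Optional[asyncio.Semaphore] = None


def release_model_semaphore():
    # Runs after the streamed response has been fully sent to the client.
    model_semaphore.release()


async def demo_generator():
    for token in ["hello", " ", "world"]:
        yield f"data:{token}\n\n"
        await asyncio.sleep(0.1)


@app.post("/v1/demo/completions")
async def demo_completions():
    global model_semaphore
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()  # wait here when too many generations are in flight

    background = BackgroundTasks()
    background.add_task(release_model_semaphore)
    return StreamingResponse(
        demo_generator(), media_type="text/event-stream", background=background
    )
```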

View File

@@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
        return code

    def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        if b"\0" in chunk:
+            chunk = chunk.replace(b"\0", b"")
        data = json.loads(chunk.decode())

        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
@@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
    def parse_model_nostream_resp(self, response, sep: str):
        text = response.text.strip()
        text = text.rstrip()
-        respObj = json.loads(text)
-
-        xx = respObj["response"]
-        xx = xx.strip(b"\x00".decode())
-        respObj_ex = json.loads(xx)
+        text = text.strip(b"\x00".decode())
+        respObj_ex = json.loads(text)
        if respObj_ex["error_code"] == 0:
            all_text = respObj_ex["text"]
            ### parse the returned text to extract the AI reply part
@@ -123,7 +122,7 @@
    def __extract_json(slef, s):
        i = s.index("{")
        count = 1  # current nesting depth, i.e. the number of unclosed '{'
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
+        for j, c in enumerate(s[i + 1:], start=i + 1):
            if c == "}":
                count -= 1
            elif c == "{":
@@ -131,7 +130,7 @@
            if count == 0:
                break
        assert count == 0  # check that the final '}' was found
-        return s[i : j + 1]
+        return s[i: j + 1]

    def parse_prompt_response(self, model_out_text) -> T:
        """
@@ -148,9 +147,9 @@
        # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
        if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
        if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]
        cleaned_output = cleaned_output.strip()
@@ -159,9 +158,9 @@
        cleaned_output = self.__extract_json(cleaned_output)
        cleaned_output = (
            cleaned_output.strip()
            .replace("\n", " ")
            .replace("\\n", " ")
            .replace("\\", " ")
        )
        return cleaned_output
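`parse_prompt_response` above strips markdown code fences and then relies on `__extract_json` to cut the first balanced `{...}` block out of the model output. A stand-alone sketch of the same fence-stripping and brace matching, with illustrative function names:

```python
import json


def extract_first_json_block(s: str) -> str:
    # walk from the first '{' until its matching '}' by tracking nesting depth
    i = s.index("{")
    depth = 1
    j = i
    for j, c in enumerate(s[i + 1:], start=i + 1):
        if c == "}":
            depth -= 1
        elif c == "{":
            depth += 1
        if depth == 0:
            break
    assert depth == 0  # the object must be closed
    return s[i: j + 1]


def clean_model_output(text: str) -> dict:
    cleaned = text.strip()
    # drop surrounding markdown fences such as ```json ... ```
    if cleaned.startswith("```json"):
        cleaned = cleaned[len("```json"):]
    if cleaned.startswith("```"):
        cleaned = cleaned[len("```"):]
    if cleaned.endswith("```"):
        cleaned = cleaned[: -len("```")]
    return json.loads(extract_first_json_block(cleaned))


print(clean_model_output('```json\n{"speak": "hi", "command": {"name": "noop"}}\n```'))
```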

View File

@@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
-    examples: List[List]
+    examples_record: List[List]
    use_example: bool = False
    type: str = ExampleType.ONE_SHOT.value
@@ -22,7 +22,7 @@ class ExampleSelector(BaseModel, ABC):
        Returns: example text
        """
        if self.use_example:
-            need_use = self.examples[:count]
+            need_use = self.examples_record[:count]
            return need_use
        return None
@@ -33,7 +33,7 @@
        """
        if self.use_example:
-            need_use = self.examples[:1]
+            need_use = self.examples_record[:1]
            return need_use
        return None

View File

@@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC):
    output_parser: BaseOutputParser = None
    """"""
    sep: str = SeparatorStyle.SINGLE.value
-    example: ExampleSelector = None
+    example_selector: ExampleSelector = None
+    need_historical_messages: bool = False

    class Config:
        """Configuration for this pydantic object."""

View File

@@ -20,8 +20,8 @@ class ChatScene(Enum):
    ChatNormal = "chat_normal"
    ChatDashboard = "chat_dashboard"
    ChatKnowledge = "chat_knowledge"
-    ChatDb = "chat_db"
-    ChatData = "chat_data"
+    # ChatDb = "chat_db"
+    # ChatData= "chat_data"

    @staticmethod
    def is_valid_mode(mode):

View File

@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
    ViewMessage,
)
from pilot.configs.config import Config
+from pilot.server.llmserver import worker

logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -51,7 +52,7 @@ class BaseChat(ABC):
    temperature: float = 0.6
    max_new_tokens: int = 1024
    # By default, keep the last two rounds of conversation records as the context
-    chat_retention_rounds: int = 2
+    chat_retention_rounds: int = 1

    class Config:
        """Configuration for this pydantic object."""
@@ -59,10 +60,10 @@
        arbitrary_types_allowed = True

    def __init__(
        self,
        chat_mode,
        chat_session_id,
        current_user_input,
    ):
        self.chat_session_id = chat_session_id
        self.chat_mode = chat_mode
@@ -75,11 +76,9 @@
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
            self.chat_mode.value
        ]
-        self.history_message: List[OnceConversation] = []
+        self.history_message: List[OnceConversation] = self.memory.messages()
        self.current_message: OnceConversation = OnceConversation(chat_mode.value)
        self.current_tokens_used: int = 0
-        ### load chat_session_id's chat historys
-        self._load_history(self.chat_session_id)

    class Config:
        """Configuration for this pydantic object."""
@@ -95,7 +94,6 @@
    def generate_input_values(self):
        pass

-    @abstractmethod
    def do_action(self, prompt_response):
        return prompt_response
@@ -104,23 +102,12 @@
        ### Chat sequence advance
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime(
-            "%Y-%m-%d %H:%M:%S"
-        )
+        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # TODO
        self.current_message.tokens = 0
-        current_prompt = None

        if self.prompt_template.template:
            current_prompt = self.prompt_template.format(**input_values)
-
-        ### build the current conversation: decide whether to build the prompt as in the first round and whether to handle a database switch
-        if self.history_message:
-            ## TODO scenes with conversation history need to decide how to handle switching databases
-            logger.info(
-                f"There are already {len(self.history_message)} rounds of conversations!"
-            )
-        if current_prompt:
            self.current_message.add_system_message(current_prompt)

        payload = {
@@ -140,31 +127,24 @@
        logger.info(f"Requert: \n{payload}")
        ai_response_text = ""
        try:
-            show_info = ""
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=120,
-            )
-            return response
-
-            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-
-            self.current_message.add_ai_message(show_info)
+            if not CFG.NEW_SERVER_MODE:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=120,
+                )
+                return response
+            else:
+                return worker.generate_stream_gate(payload)

        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parase faild" + str(e))
            self.current_message.add_view_message(
                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
            )
-        ### store the conversation record
+        ### store current conversation
        self.memory.append(self.current_message)

    def nostream_call(self):
@@ -172,42 +152,27 @@
        logger.info(f"Requert: \n{payload}")
        ai_response_text = ""
        try:
-            ### call the non-streaming model service interface
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate"),
-                headers=headers,
-                json=payload,
-                timeout=120,
-            )
+            rsp_str = ""
+            if not CFG.NEW_SERVER_MODE:
+                rsp_str = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+            else:
+                ###TODO no stream mode need independent
+                output = worker.generate_stream_gate(payload)
+                for rsp in output:
+                    rsp_str = str(rsp, "utf-8")
+                    print("[TEST: output]:", rsp_str)

            ### output parse
-            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(
-                    response, self.prompt_template.sep
-                )
-            )
-            # ### MOCK
-            # ai_response_text = """{
-            #     "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。",
-            #     "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。",
-            #     "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
-            #     "command": {
-            #         "name": "histogram-executor",
-            #         "args": {
-            #             "title": "订单城市分布柱状图",
-            #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-            #         }
-            #     }
-            # }"""
+            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
+                                                                                            self.prompt_template.sep)
+            ### model result deal
            self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = (
-                self.prompt_template.output_parser.parse_prompt_response(
-                    ai_response_text
-                )
-            )
+            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)

            result = self.do_action(prompt_define_response)

            if hasattr(prompt_define_response, "thoughts"):
@@ -235,7 +200,7 @@
            self.current_message.add_view_message(
                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
            )
-        ### store the conversation record
+        ### store dialogue
        self.memory.append(self.current_message)
        return self.current_ai_response()
@@ -247,67 +212,99 @@
    def generate_llm_text(self) -> str:
        text = ""
+        ### Load scene setting or character definition
        if self.prompt_template.template_define:
-            text = self.prompt_template.template_define + self.prompt_template.sep
+            text += self.prompt_template.template_define + self.prompt_template.sep
+        ### Load prompt
+        text += self.__load_system_message()

-        ### handle history messages
-        if len(self.history_message) > self.chat_retention_rounds:
-            ### merge the first round and the last n rounds of history into the context; user-facing view messages must be filtered out
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
-                    text += (
-                        first_message.type
-                        + ":"
-                        + first_message.content
-                        + self.prompt_template.sep
-                    )
-
-            index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += (
-                        last_message.type
-                        + ":"
-                        + last_message.content
-                        + self.prompt_template.sep
-                    )
-        else:
-            ### concatenate the full history directly
-            for conversation in self.history_message:
-                for message in conversation.messages:
-                    if not isinstance(message, ViewMessage):
-                        text += (
-                            message.type
-                            + ":"
-                            + message.content
-                            + self.prompt_template.sep
-                        )
-
-        ### current conversation
-        for now_message in self.current_message.messages:
-            text += (
-                now_message.type + ":" + now_message.content + self.prompt_template.sep
-            )
+        ### Load examples
+        text += self.__load_example_messages()
+
+        ### Load History
+        text += self.__load_histroy_messages()
+
+        ### Load User Input
+        text += self.__load_user_message()
        return text

-    # kept temporarily for frontend compatibility
+    def __load_system_message(self):
+        system_convs = self.current_message.get_system_conv()
+        system_text = ""
+        for system_conv in system_convs:
+            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+        return system_text
+
+    def __load_user_message(self):
+        user_conv = self.current_message.get_user_conv()
+        if user_conv:
+            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+        else:
+            raise ValueError("Hi! What do you want to talk about")
+
+    def __load_example_messages(self):
+        example_text = ""
+        if self.prompt_template.example_selector:
+            for round_conv in self.prompt_template.example_selector.examples():
+                for round_message in round_conv['messages']:
+                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                        example_text += (
+                            round_message['type']
+                            + ":"
+                            + round_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+        return example_text
+
+    def __load_histroy_messages(self):
+        history_text = ""
+        if self.prompt_template.need_historical_messages:
+            if self.history_message:
+                logger.info(
+                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
+                )
+            if len(self.history_message) > self.chat_retention_rounds:
+                for first_message in self.history_message[0]['messages']:
+                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                        history_text += (
+                            first_message['type']
+                            + ":"
+                            + first_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+
+                index = self.chat_retention_rounds - 1
+                for round_conv in self.history_message[-index:]:
+                    for round_message in round_conv['messages']:
+                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                round_message['type']
+                                + ":"
+                                + round_message['data']['content']
+                                + self.prompt_template.sep
+                            )
+            else:
+                ### user all history
+                for conversation in self.history_message:
+                    for message in conversation['messages']:
+                        ### histroy message not have promot and view info
+                        if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                message['type']
+                                + ":"
+                                + message['data']['content']
+                                + self.prompt_template.sep
+                            )
+        return history_text

    def current_ai_response(self) -> str:
        for message in self.current_message.messages:
            if message.type == "view":
                return message.content
        return None

-    def _load_history(self, session_id: str) -> List[OnceConversation]:
-        """
-        load chat history by session_id
-        Args:
-            session_id:
-        Returns:
-        """
-        return self.memory.messages()
-
    def generate(self, p) -> str:
        """
        generate context for LLM input
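The refactored `generate_llm_text` above now builds the final prompt from the scene definition, the current system messages, optional few-shot examples, retained history rounds, and finally the user input, each rendered as `type:content` and joined with the template separator. A toy sketch of that concatenation, using the message layout seen in the history helpers; the names and separator are illustrative:

```python
SEP = "\n###\n"  # stand-in for prompt_template.sep


def render(messages, skip_types=("view", "system")):
    # view/system entries are filtered out of history, the rest become "type:content"
    return "".join(
        m["type"] + ":" + m["data"]["content"] + SEP
        for m in messages
        if m["type"] not in skip_types
    )


system_part = "system:You answer database questions." + SEP
history = [
    {"type": "human", "data": {"content": "what is duckdb?"}},
    {"type": "ai", "data": {"content": "an in-process analytics database."}},
]
user_input = "human:show me the tables" + SEP

prompt = system_part + render(history) + user_input
print(prompt)
```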

View File

@@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
+PROMPT_SCENE_DEFINE = None

_DEFAULT_TEMPLATE = """

View File

@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector

## Two examples are defined by default
EXAMPLES = [
-    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
-    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
]

-example = ExampleSelector(examples=EXAMPLES, use_example=True)
+plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -5,14 +5,12 @@ from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
-from pilot.scene.chat_execution.example import example
+from pilot.scene.chat_execution.example import plugin_example

CFG = Config()

-# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

_DEFAULT_TEMPLATE = """
Goals:
{input}
@@ -51,7 +49,7 @@ prompt = PromptTemplate(
    output_parser=PluginChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
-    example=example,
+    example_selector=plugin_example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser

-PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
-The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+PROMPT_SCENE_DEFINE = None

CFG = Config()

View File

@@ -85,16 +85,22 @@ class OnceConversation:
        self.messages.clear()
        self.session_id = None

-    def get_user_message(self):
-        for once in self.messages:
-            if isinstance(once, HumanMessage):
-                return once.content
-        return ""
+    def get_user_conv(self):
+        for message in self.messages:
+            if isinstance(message, HumanMessage):
+                return message
+        return None
+
+    def get_system_conv(self):
+        system_convs = []
+        for message in self.messages:
+            if isinstance(message, SystemMessage):
+                system_convs.append(message)
+        return system_convs


def _conversation_to_dic(once: OnceConversation) -> dict:
    start_str: str = ""
-    if once.start_date:
+    if hasattr(once, 'start_date') and once.start_date:
        if isinstance(once.start_date, datetime):
            start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
        else:

View File

@@ -0,0 +1,78 @@
+import traceback
+import os
+import shutil
+import argparse
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.utils import build_logger
+from pilot.server.webserver_base import server_init
+
+from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+static_file_path = os.path.join(os.getcwd(), "server/static")
+
+CFG = Config()
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+origins = ["*"]
+
+# add the CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+app.mount("/static", StaticFiles(directory=static_file_path), name="static")
+app.add_route("/test", "static/test.html")
+
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+
+    # old version server config
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--share", default=False, action="store_true")
+
+    # init server config
+    args = parser.parse_args()
+    server_init(args)
+
+    CFG.NEW_SERVER_MODE = True
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5000)
View File

@@ -33,7 +33,7 @@ class ModelWorker:
            model_path = model_path[:-1]
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device
+        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
        self.ml = ModelLoader(model_path=model_path)
        self.model, self.tokenizer = self.ml.loader(
            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
@@ -41,11 +41,11 @@
        if not isinstance(self.model, str):
            if hasattr(self.model, "config") and hasattr(
                self.model.config, "max_sequence_length"
            ):
                self.context_len = self.model.config.max_sequence_length
            elif hasattr(self.model, "config") and hasattr(
                self.model.config, "max_position_embeddings"
            ):
                self.context_len = self.model.config.max_position_embeddings
@@ -55,29 +55,32 @@
        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

+    def start_check(self):
+        print("LLM Model Loading Success")
+
    def get_queue_length(self):
        if (
            model_semaphore is None
            or model_semaphore._value is None
            or model_semaphore._waiters is None
        ):
            return 0
        else:
            (
                CFG.LIMIT_MODEL_CONCURRENCY
                - model_semaphore._value
                + len(model_semaphore._waiters)
            )

    def generate_stream_gate(self, params):
        try:
            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
                # Please do not open the output in production!
                # The gpt4all thread shares stdout with the parent process,
                # and opening it may affect the frontend output.
-                # print("output: ", output)
+                print("output: ", output)
                ret = {
                    "text": output,
                    "error_code": 0,
@@ -178,10 +181,9 @@ def generate(prompt_request: PromptRequest):
    for rsp in output:
        # rsp = rsp.decode("utf-8")
        rsp_str = str(rsp, "utf-8")
+        print("[TEST: output]:", rsp_str)
        response.append(rsp_str)

-    return {"response": rsp_str}
+    return rsp_str


@app.post("/embedding")
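The worker above streams its output as JSON fragments terminated by null bytes, which is why `parse_model_stream_resp_ex` strips `b"\0"` and why the non-stream `generate` endpoint ends up returning the last decoded chunk. A rough sketch of consuming that chunk shape; the format is inferred from the `ret` dict and the delimiter handling in this commit:

```python
import json


def iter_worker_texts(chunks):
    """Yield the 'text' field from null-byte-terminated JSON chunks."""
    for chunk in chunks:
        chunk = chunk.replace(b"\0", b"")
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) == 0:
            yield data["text"]


fake_stream = [
    b'{"text": "Hel", "error_code": 0}\x00',
    b'{"text": "Hello!", "error_code": 0}\x00',
]
for text in iter_worker_texts(fake_stream):
    print(text)
```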

View File

@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="UTF-8">
+    <title>Streaming Demo</title>
+    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
+</head>
+<body>
+    <div id="output"></div>
+    <script>
+        $(document).ready(function() {
+            var source = new EventSource("/v1/chat/completions");
+            source.onmessage = function(event) {
+                $("#output").append(event.data);
+            }
+        });
+    </script>
+</body>
+</html>

View File

@@ -60,7 +60,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

-from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler

# load plugins
CFG = Config()

View File

@@ -39,6 +39,8 @@ def server_init(args):
    # init config
    cfg = Config()
+    from pilot.server.llmserver import worker
+    worker.start_check()
    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()