Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-14 06:26:18 +00:00)

Commit 4029f48d5f: style: format code style
Parent: 359babecdc
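Every hunk below is a pure style change: single-quoted strings become double-quoted, long calls and imports are wrapped one argument per line with a trailing comma, and slice bounds around compound expressions gain spaces (s[i + 1 :]). That combination is characteristic of the Black formatter; the commit does not name the tool, so treat that as an assumption. A minimal sketch of the quote normalization through Black's Python API:

    # A minimal sketch, assuming Black (`pip install black`) produced this style;
    # the commit itself does not say which formatter was used.
    import black

    src = "table_name = 'chat_history'"
    print(black.format_str(src, mode=black.Mode()), end="")
    # -> table_name = "chat_history"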
@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder

 default_db_path = os.path.join(os.getcwd(), "message")
 duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = 'chat_history'
+table_name = "chat_history"

 CFG = Config()


 class DuckdbHistoryMemory(BaseChatHistoryMemory):

     def __init__(self, chat_session_id: str):
         self.chat_seesion_id = chat_session_id
         os.makedirs(default_db_path, exist_ok=True)

@@ -29,15 +28,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         self.__init_chat_history_tables()

     def __init_chat_history_tables(self):
         # Check whether the table exists
-        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
-                                      [table_name]).fetchall()
+        result = self.connect.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
+        ).fetchall()

         if not result:
             # Create a new table if it does not exist
             self.connect.execute(
-                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
+            )

     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()

@@ -47,6 +47,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             return content[0]
         else:
             return None

     def messages(self) -> List[OnceConversation]:
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         if context:

@@ -62,23 +63,35 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
-            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
-                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
+            cursor.execute(
+                "UPDATE chat_history set messages=? where conv_uid=?",
+                [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
+            )
         else:
-            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
+            cursor.execute(
+                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+                [
+                    self.chat_seesion_id,
+                    "",
+                    json.dumps(conversations, ensure_ascii=False),
+                ],
+            )
         cursor.commit()
         self.connect.commit()

     def clear(self) -> None:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         self.connect.commit()

     def delete(self) -> bool:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         return True

@@ -87,7 +100,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
             if user_name:
-                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
+                cursor.execute(
+                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
+                )
             else:
                 cursor.execute("SELECT * FROM chat_history limit 20")
             # Get the column names of the query result

@@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):

         return []

-    def get_messages(self)-> List[OnceConversation]:
+    def get_messages(self) -> List[OnceConversation]:
         cursor = self.connect.cursor()
-        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         context = cursor.fetchone()
         if context:
             return json.loads(context[0])

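Aside: the class above stores each conversation as a JSON blob keyed by conv_uid. A standalone sketch of that access pattern, assuming the duckdb package; the file name and sample data are illustrative:

    import json
    import duckdb

    connect = duckdb.connect("chat_history.db")
    # Same schema as the CREATE TABLE statement in the diff above
    connect.execute(
        "CREATE TABLE IF NOT EXISTS chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
    )
    connect.execute(
        "INSERT INTO chat_history(conv_uid, user_name, messages) VALUES (?,?,?)",
        ["demo-uid", "", json.dumps([{"type": "human", "data": {"content": "hi"}}])],
    )
    row = connect.execute(
        "SELECT messages FROM chat_history where conv_uid=?", ["demo-uid"]
    ).fetchone()
    print(json.loads(row[0]))  # [{'type': 'human', 'data': {'content': 'hi'}}]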
@@ -2,17 +2,29 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
+from fastapi import (
+    APIRouter,
+    Request,
+    Body,
+    status,
+    HTTPException,
+    Response,
+    BackgroundTasks,
+)

 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
-from sse_starlette.sse import EventSourceResponse
 from typing import List

-from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
+from pilot.openapi.api_v1.api_view_model import (
+    Result,
+    ConversationVo,
+    MessageVo,
+    ChatSceneVo,
+)
 from pilot.configs.config import Config
 from pilot.openapi.knowledge.knowledge_service import KnowledgeService
 from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest

@@ -103,7 +115,7 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
     chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))

@@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     }

     if not chat.prompt_template.stream_out:
-        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
-                                 background=background_tasks)
+        return StreamingResponse(
+            no_stream_generator(chat),
+            headers=headers,
+            media_type="text/event-stream",
+            background=background_tasks,
+        )
     else:
-        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
-                                 background=background_tasks)
+        return StreamingResponse(
+            stream_generator(chat),
+            headers=headers,
+            media_type="text/plain",
+            background=background_tasks,
+        )


 def release_model_semaphore():

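The endpoint above streams model output back to the client. A self-contained sketch of the same StreamingResponse pattern, assuming FastAPI; the route and generator names here are illustrative:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    async def demo_generator():
        # One SSE frame per message: a "data:" line followed by a blank line
        for msg in ["hello", "world"]:
            yield f"data: {msg}\n\n"

    @app.get("/demo/stream")
    async def demo_stream():
        return StreamingResponse(demo_generator(), media_type="text/event-stream")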
@@ -236,12 +256,15 @@ async def no_stream_generator(chat):
     msg = msg.replace("\n", "\\n")
     yield f"data: {msg}\n\n"


 async def stream_generator(chat):
     model_response = chat.stream_call()
     if not CFG.NEW_SERVER_MODE:
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 chat.current_message.add_ai_message(msg)
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"

@@ -249,7 +272,9 @@ async def stream_generator(chat):
     else:
         for chunk in model_response:
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 chat.current_message.add_ai_message(msg)

                 msg = msg.replace("\n", "\\n")

@@ -259,4 +284,6 @@ async def stream_generator(chat):


 def message2Vo(message: dict, order) -> MessageVo:
-    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
+    return MessageVo(
+        role=message["type"], context=message["data"]["content"], order=order
+    )

@@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):


 @router.post("/knowledge/{space_name}/document/upload")
-async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
+async def document_upload(
+    space_name: str,
+    doc_name: str = Form(...),
+    doc_type: str = Form(...),
+    doc_file: UploadFile = File(...),
+):
     print(f"/document/upload params: {space_name}")
     try:
         if doc_file:
-            with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
+            with NamedTemporaryFile(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
+            ) as tmp:
                 tmp.write(await doc_file.read())
                 tmp_path = tmp.name
-            shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
+            shutil.move(
+                tmp_path,
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             request = KnowledgeDocumentRequest()
             request.doc_name = doc_name
             request.doc_type = doc_type
-            request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
+            request.content = (
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             knowledge_space_service.create_knowledge_document(
                 space=space_name, request=request
             )

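Note that the pre-existing trailing comma after os.path.join(...) makes request.content a one-element tuple; the reformatting wraps it in parentheses, making that explicit rather than changing it. A hypothetical client call for the upload route above, assuming the requests package and a server on localhost:5000 (both assumptions, as are the space name and doc_type value):

    import requests

    resp = requests.post(
        "http://localhost:5000/knowledge/default/document/upload",
        data={"doc_name": "demo.txt", "doc_type": "TEXT"},
        files={"doc_file": ("demo.txt", b"hello world")},
    )
    print(resp.status_code)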
@@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
 )
 from enum import Enum

-from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
+from pilot.openapi.knowledge.request.knowledge_response import (
+    ChunkQueryResponse,
+    DocumentQueryResponse,
+)

 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()

@@ -5,6 +5,7 @@ from pydantic import BaseModel

 class ChunkQueryResponse(BaseModel):
     """data: data"""

     data: List = None
     """total: total size"""
     total: int = None

@@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel):

 class DocumentQueryResponse(BaseModel):
     """data: data"""

     data: List = None
     """total: total size"""
     total: int = None
     """page: current page"""
     page: int = None

@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of unclosed '{'
-        for j, c in enumerate(s[i + 1:], start=i + 1):
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":

@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i: j + 1]
+        return s[i : j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """

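Pieced together from the two hunks above, the brace-matching scan is easy to test in isolation. A standalone sketch (renamed, since the original is a private method with a misspelled slef parameter; the count += 1 branch falls between the two hunks and is reconstructed here):

    def extract_json(s):
        i = s.index("{")
        count = 1  # current nesting depth: number of unclosed '{'
        for j, c in enumerate(s[i + 1 :], start=i + 1):
            if c == "}":
                count -= 1
            elif c == "{":
                count += 1
            if count == 0:
                break
        assert count == 0  # the matching '}' was found
        return s[i : j + 1]

    print(extract_json('noise {"a": {"b": 1}} trailing'))  # {"a": {"b": 1}}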
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()

@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
             .replace("\n", " ")
             .replace("\\n", " ")
             .replace("\\", " ")
         )
         return cleaned_output

@@ -60,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
         self,
         chat_mode,
         chat_session_id,
         current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode

@@ -102,7 +102,9 @@ class BaseChat(ABC):
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.current_message.start_date = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
         # TODO
         self.current_message.tokens = 0

@@ -168,11 +170,18 @@ class BaseChat(ABC):
             print("[TEST: output]:", rsp_str)

             ### output parse
-            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
-                                                                                            self.prompt_template.sep)
+            ai_response_text = (
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    rsp_str, self.prompt_template.sep
+                )
+            )
             ### model result deal
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
             result = self.do_action(prompt_define_response)

             if hasattr(prompt_define_response, "thoughts"):

@@ -232,7 +241,9 @@ class BaseChat(ABC):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
         for system_conv in system_convs:
-            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            system_text += (
+                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            )
         return system_text

     def __load_user_message(self):

@@ -246,13 +257,16 @@ class BaseChat(ABC):
         example_text = ""
         if self.prompt_template.example_selector:
             for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv['messages']:
-                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                for round_message in round_conv["messages"]:
+                    if not round_message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         example_text += (
-                            round_message['type']
+                            round_message["type"]
                             + ":"
-                            + round_message['data']['content']
+                            + round_message["data"]["content"]
                             + self.prompt_template.sep
                         )
         return example_text

@@ -264,37 +278,46 @@ class BaseChat(ABC):
                 f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
             )
             if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]['messages']:
-                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                for first_message in self.history_message[0]["messages"]:
+                    if not first_message["type"] in [
+                        ViewMessage.type,
+                        SystemMessage.type,
+                    ]:
                         history_text += (
-                            first_message['type']
+                            first_message["type"]
                             + ":"
-                            + first_message['data']['content']
+                            + first_message["data"]["content"]
                             + self.prompt_template.sep
                         )

                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
-                    for round_message in round_conv['messages']:
-                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            SystemMessage.type,
+                            ViewMessage.type,
+                        ]:
                             history_text += (
-                                round_message['type']
+                                round_message["type"]
                                 + ":"
-                                + round_message['data']['content']
+                                + round_message["data"]["content"]
                                 + self.prompt_template.sep
                             )

         else:
             ### user all history
             for conversation in self.history_message:
-                for message in conversation['messages']:
+                for message in conversation["messages"]:
                     ### histroy message not have promot and view info
-                    if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    if not message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         history_text += (
-                            message['type']
+                            message["type"]
                             + ":"
-                            + message['data']['content']
+                            + message["data"]["content"]
                             + self.prompt_template.sep
                         )

         return history_text

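The three loops above share one filter-and-concatenate pattern. A standalone sketch of it on sample data, with plain strings standing in for SystemMessage.type / ViewMessage.type and for prompt_template.sep (all assumptions):

    sep = "\n"
    history_message = [
        {
            "messages": [
                {"type": "human", "data": {"content": "hi"}},
                {"type": "view", "data": {"content": "rendered"}},  # filtered out
            ]
        }
    ]
    history_text = ""
    for conversation in history_message:
        for message in conversation["messages"]:
            if not message["type"] in ["system", "view"]:
                history_text += message["type"] + ":" + message["data"]["content"] + sep
    print(repr(history_text))  # 'human:hi\n'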
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector

 ## Two examples are defined by default
 EXAMPLES = [
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
 ]

 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

@@ -98,9 +98,10 @@ class OnceConversation:
             system_convs.append(message)
         return system_convs


 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if hasattr(once, 'start_date') and once.start_date:
+    if hasattr(once, "start_date") and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:

@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
+from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router


 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler


 static_file_path = os.path.join(os.getcwd(), "server/static")

 CFG = Config()

@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")

 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
     )

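The patch above swaps the Swagger UI assets for a CDN copy. A minimal sketch of how such a patch is typically activated, assuming the common FastAPI idiom (the hunk header shows the file already imports applications):

    from fastapi import applications

    applications.get_swagger_ui_html = swagger_monkey_patch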
@@ -55,14 +59,16 @@ app.add_middleware(
 )

 app.mount("/static", StaticFiles(directory=static_file_path), name="static")
 app.add_route("/test", "static/test.html")
+app.include_router(knowledge_router)
 app.include_router(api_v1)
 app.add_exception_handler(RequestValidationError, validation_exception_handler)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )

     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")

@@ -75,4 +81,5 @@ if __name__ == "__main__":
     server_init(args)
     CFG.NEW_SERVER_MODE = True
     import uvicorn

     uvicorn.run(app, host="0.0.0.0", port=5000)

@@ -9,7 +9,8 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+
+# from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel

 global_counter = 0

@@ -41,11 +42,11 @@ class ModelWorker:

         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings

@@ -60,22 +61,22 @@ class ModelWorker:

     def get_queue_length(self):
         if (
             model_semaphore is None
             or model_semaphore._value is None
             or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
                 CFG.LIMIT_MODEL_CONCURRENCY
                 - model_semaphore._value
                 + len(model_semaphore._waiters)
             )

     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,

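The queue length above is derived from asyncio.Semaphore internals: the concurrency limit minus the available permits gives the requests in flight, plus the coroutines still waiting. A standalone sketch of the arithmetic (the private _value / _waiters attributes are implementation details, used here only because the source does the same):

    import asyncio

    LIMIT_MODEL_CONCURRENCY = 5
    model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    waiters = model_semaphore._waiters or []
    # limit - available permits = requests in flight; waiters are still queued
    queue_length = LIMIT_MODEL_CONCURRENCY - model_semaphore._value + len(waiters)
    print(queue_length)  # 0 when idle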
@@ -107,23 +108,23 @@ worker = ModelWorker(
 )

 app = FastAPI()
-from pilot.openapi.knowledge.knowledge_controller import router
-
-app.include_router(router)
-
-origins = [
-    "http://localhost",
-    "http://localhost:8000",
-    "http://localhost:3000",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# from pilot.openapi.knowledge.knowledge_controller import router
+#
+# app.include_router(router)
+#
+# origins = [
+#     "http://localhost",
+#     "http://localhost:8000",
+#     "http://localhost:3000",
+# ]
+#
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=origins,
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )


 class PromptRequest(BaseModel):

@@ -40,6 +40,7 @@ def server_init(args):
    cfg = Config()

    from pilot.server.llmserver import worker

    worker.start_check()
    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)