style: format code style

aries_ckt 2023-06-29 13:52:53 +08:00
parent 359babecdc
commit 4029f48d5f
12 changed files with 205 additions and 109 deletions

View File

@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history'
table_name = "chat_history"
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@@ -29,15 +28,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
self.__init_chat_history_tables()
def __init_chat_history_tables(self):
# Check whether the table already exists
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
[table_name]).fetchall()
result = self.connect.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
).fetchall()
if not result:
# If the table does not exist, create it
self.connect.execute(
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
)
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
@@ -47,6 +47,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return content[0]
else:
return None
def messages(self) -> List[OnceConversation]:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
if context:
@@ -62,23 +63,35 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
conversations.append(_conversation_to_dic(once_message))
cursor = self.connect.cursor()
if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
)
else:
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
cursor.execute(
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[
self.chat_seesion_id,
"",
json.dumps(conversations, ensure_ascii=False),
],
)
cursor.commit()
self.connect.commit()
def clear(self) -> None:
cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
self.connect.commit()
def delete(self) -> bool:
cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
return True
@@ -87,7 +100,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
cursor.execute(
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
)
else:
cursor.execute("SELECT * FROM chat_history limit 20")
# Get the column names of the query result
@@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return []
def get_messages(self)-> List[OnceConversation]:
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
context = cursor.fetchone()
if context:
return json.loads(context[0])
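For reference, a minimal standalone sketch of the check-then-upsert pattern this file applies to DuckDB; the database path, conversation id, and payload below are illustrative, not part of this commit.

import json
import duckdb  # assumes the duckdb package is available

conn = duckdb.connect("chat_history.db")  # illustrative path

# Create the table only if it is missing, mirroring __init_chat_history_tables above.
exists = conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name=?", ["chat_history"]
).fetchall()
if not exists:
    conn.execute(
        "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
    )

# Store one conversation as a JSON blob keyed by conv_uid (update if present, else insert).
conv_uid = "demo-conversation"  # illustrative id
payload = json.dumps([{"type": "human", "data": {"content": "hello"}}], ensure_ascii=False)
if conn.execute("SELECT messages FROM chat_history WHERE conv_uid=?", [conv_uid]).fetchone():
    conn.execute("UPDATE chat_history SET messages=? WHERE conv_uid=?", [payload, conv_uid])
else:
    conn.execute(
        "INSERT INTO chat_history(conv_uid, user_name, messages) VALUES (?,?,?)",
        [conv_uid, "", payload],
    )
conn.close()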

View File

@@ -2,17 +2,29 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi import (
APIRouter,
Request,
Body,
status,
HTTPException,
Response,
BackgroundTasks,
)
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.openapi.api_v1.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -103,7 +115,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()):
}
if not chat.prompt_template.stream_out:
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
background=background_tasks)
return StreamingResponse(
no_stream_generator(chat),
headers=headers,
media_type="text/event-stream",
background=background_tasks,
)
else:
return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
background=background_tasks)
return StreamingResponse(
stream_generator(chat),
headers=headers,
media_type="text/plain",
background=background_tasks,
)
def release_model_semaphore():
@@ -236,12 +256,15 @@ async def no_stream_generator(chat):
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
@@ -249,7 +272,9 @@ async def stream_generator(chat):
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n")
@@ -259,4 +284,6 @@ async def stream_generator(chat):
def message2Vo(message: dict, order) -> MessageVo:
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
return MessageVo(
role=message["type"], context=message["data"]["content"], order=order
)
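As background on the StreamingResponse calls reformatted above, a minimal FastAPI sketch of the same server-sent-event framing; the route path and yielded tokens are illustrative.

import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

demo_app = FastAPI()

async def demo_generator():
    # Each chunk is framed as a server-sent event: "data: <payload>\n\n".
    for token in ["hel", "hello", "hello wor", "hello world"]:
        await asyncio.sleep(0.05)
        yield f"data: {token}\n\n"

@demo_app.get("/v1/demo/stream")
async def demo_stream():
    return StreamingResponse(demo_generator(), media_type="text/event-stream")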

View File

@@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload")
async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
async def document_upload(
space_name: str,
doc_name: str = Form(...),
doc_type: str = Form(...),
doc_file: UploadFile = File(...),
):
print(f"/document/upload params: {space_name}")
try:
if doc_file:
with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
) as tmp:
tmp.write(await doc_file.read())
tmp_path = tmp.name
shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
shutil.move(
tmp_path,
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
request = KnowledgeDocumentRequest()
request.doc_name = doc_name
request.doc_type = doc_type
request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
request.content = (
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
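For context, a minimal sketch of the write-to-temp-file-then-move upload pattern used in document_upload above; the upload root and route are illustrative stand-ins for KNOWLEDGE_UPLOAD_ROOT_PATH.

import os
import shutil
from tempfile import NamedTemporaryFile
from fastapi import FastAPI, File, UploadFile

upload_app = FastAPI()
UPLOAD_ROOT = "/tmp/knowledge_uploads"  # stand-in for KNOWLEDGE_UPLOAD_ROOT_PATH

@upload_app.post("/demo/document/upload")
async def demo_upload(doc_file: UploadFile = File(...)):
    os.makedirs(UPLOAD_ROOT, exist_ok=True)
    # Write the upload to a temp file first, then move it into its final location.
    with NamedTemporaryFile(dir=UPLOAD_ROOT, delete=False) as tmp:
        tmp.write(await doc_file.read())
        tmp_path = tmp.name
    shutil.move(tmp_path, os.path.join(UPLOAD_ROOT, doc_file.filename))
    return {"success": True}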

View File

@@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
)
from enum import Enum
from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
from pilot.openapi.knowledge.request.knowledge_response import (
ChunkQueryResponse,
DocumentQueryResponse,
)
knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
"""data: data"""
data: List = None
"""total: total size"""
total: int = None
@@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel):
class DocumentQueryResponse(BaseModel):
"""data: data"""
data: List = None
"""total: total size"""
total: int = None
"""page: current page"""
page: int = None

View File

@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
def __extract_json(slef, s):
i = s.index("{")
count = 1  # current nesting depth, i.e. the number of '{' not yet closed
for j, c in enumerate(s[i + 1:], start=i + 1):
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
if count == 0:
break
assert count == 0  # make sure the final matching '}' was found
return s[i: j + 1]
return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):]
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):]
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
return cleaned_output
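A standalone sketch of the brace-matching extraction reformatted above, with a small usage example; the input string is illustrative.

def extract_json(s: str) -> str:
    i = s.index("{")
    depth = 1  # current nesting depth, i.e. the number of '{' not yet closed
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            depth -= 1
        elif c == "{":
            depth += 1
        if depth == 0:
            break
    assert depth == 0  # make sure the matching closing '}' was found
    return s[i : j + 1]

print(extract_json('Answer: {"sql": "SELECT 1", "thoughts": {"plan": "demo"}} -- done'))
# -> {"sql": "SELECT 1", "thoughts": {"plan": "demo"}}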

View File

@@ -60,10 +60,10 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self,
chat_mode,
chat_session_id,
current_user_input,
self,
chat_mode,
chat_session_id,
current_user_input,
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
@@ -102,7 +102,9 @@ class BaseChat(ABC):
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
# TODO
self.current_message.tokens = 0
@@ -168,11 +170,18 @@ class BaseChat(ABC):
print("[TEST: output]:", rsp_str)
### output parse
ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
self.prompt_template.sep)
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
rsp_str, self.prompt_template.sep
)
)
### model result deal
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
@@ -232,7 +241,9 @@ class BaseChat(ABC):
system_convs = self.current_message.get_system_conv()
system_text = ""
for system_conv in system_convs:
system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
return system_text
def __load_user_message(self):
@@ -246,13 +257,16 @@ class BaseChat(ABC):
example_text = ""
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv['messages']:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
example_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
round_message["type"]
+ ":"
+ round_message["data"]["content"]
+ self.prompt_template.sep
)
return example_text
@@ -264,37 +278,46 @@ class BaseChat(ABC):
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
)
if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]['messages']:
if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [
ViewMessage.type,
SystemMessage.type,
]:
history_text += (
first_message['type']
+ ":"
+ first_message['data']['content']
+ self.prompt_template.sep
first_message["type"]
+ ":"
+ first_message["data"]["content"]
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
for round_message in round_conv['messages']:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
round_message["type"]
+ ":"
+ round_message["data"]["content"]
+ self.prompt_template.sep
)
else:
### use all history
for conversation in self.history_message:
for message in conversation['messages']:
for message in conversation["messages"]:
### history messages do not include prompt and view info
if not message['type'] in [SystemMessage.type, ViewMessage.type]:
if not message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += (
message['type']
+ ":"
+ message['data']['content']
+ self.prompt_template.sep
message["type"]
+ ":"
+ message["data"]["content"]
+ self.prompt_template.sep
)
return history_text
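To make the reformatted loops above easier to follow, a standalone sketch of how one conversation round is flattened into prompt text while system and view messages are skipped; the message data and separator are illustrative.

SEP = "\n"  # stand-in for prompt_template.sep
SKIP_TYPES = {"system", "view"}  # stand-ins for SystemMessage.type / ViewMessage.type

history = [
    {
        "messages": [
            {"type": "human", "data": {"content": "show me the user table"}},
            {"type": "view", "data": {"content": "<chart>...</chart>"}},
            {"type": "ai", "data": {"content": "SELECT * FROM user LIMIT 10"}},
        ]
    }
]

history_text = ""
for round_conv in history:
    for message in round_conv["messages"]:
        if message["type"] not in SKIP_TYPES:
            history_text += message["type"] + ":" + message["data"]["content"] + SEP

print(history_text)
# human:show me the user table
# ai:SELECT * FROM user LIMIT 10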

View File

@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
[{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
[{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
[{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -98,9 +98,10 @@ class OnceConversation:
system_convs.append(message)
return system_convs
def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = ""
if hasattr(once, 'start_date') and once.start_date:
if hasattr(once, "start_date") and once.start_date:
if isinstance(once.start_date, datetime):
start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
else:

View File

@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
*args,
**kwargs,
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
)
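For context, a minimal sketch of how such a Swagger UI monkey patch is typically wired into a FastAPI app; the assignment to applications.get_swagger_ui_html is an assumption for illustration and is not shown in this hunk.

from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html

def patched_swagger_ui(*args, **kwargs):
    # Serve Swagger UI assets from an alternative CDN instead of the defaults.
    return get_swagger_ui_html(
        *args,
        **kwargs,
        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
    )

applications.get_swagger_ui_html = patched_swagger_ui  # assumed wiring, for illustration
web_app = FastAPI()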
@@ -55,14 +59,16 @@ app.add_middleware(
)
app.mount("/static", StaticFiles(directory=static_file_path), name="static")
app.add_route("/test", "static/test.html")
app.add_route("/test", "static/test.html")
app.include_router(knowledge_router)
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -75,4 +81,5 @@ if __name__ == "__main__":
server_init(args)
CFG.NEW_SERVER_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)

View File

@@ -9,7 +9,8 @@ import sys
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
@@ -60,22 +61,22 @@ def get_queue_length(self):
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
(
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
)
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
"http://localhost",
"http://localhost:8000",
"http://localhost:3000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "http://localhost:3000",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
class PromptRequest(BaseModel):

View File

@@ -40,6 +40,7 @@ def server_init(args):
cfg = Config()
from pilot.server.llmserver import worker
worker.start_check()
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)