style:format code style

aries_ckt 2023-06-29 13:52:53 +08:00
parent 359babecdc
commit 4029f48d5f
12 changed files with 205 additions and 109 deletions

View File

@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder
 
 default_db_path = os.path.join(os.getcwd(), "message")
 duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = 'chat_history'
-
+table_name = "chat_history"
 
 CFG = Config()
 
 
 class DuckdbHistoryMemory(BaseChatHistoryMemory):
     def __init__(self, chat_session_id: str):
         self.chat_seesion_id = chat_session_id
         os.makedirs(default_db_path, exist_ok=True)
@@ -29,15 +28,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         self.__init_chat_history_tables()
 
     def __init_chat_history_tables(self):
         # Check whether the table already exists
-        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
-                                      [table_name]).fetchall()
+        result = self.connect.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
+        ).fetchall()
 
         if not result:
             # If the table does not exist, create a new one
             self.connect.execute(
-                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
+            )
 
     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()
@@ -47,6 +47,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             return content[0]
         else:
             return None
+
     def messages(self) -> List[OnceConversation]:
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         if context:
@@ -62,23 +63,35 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
-            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
-                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
+            cursor.execute(
+                "UPDATE chat_history set messages=? where conv_uid=?",
+                [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
+            )
         else:
-            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
+            cursor.execute(
+                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+                [
+                    self.chat_seesion_id,
+                    "",
+                    json.dumps(conversations, ensure_ascii=False),
+                ],
+            )
         cursor.commit()
         self.connect.commit()
 
     def clear(self) -> None:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         self.connect.commit()
 
     def delete(self) -> bool:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         return True
@@ -87,7 +100,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
             if user_name:
-                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
+                cursor.execute(
+                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
+                )
             else:
                 cursor.execute("SELECT * FROM chat_history limit 20")
             # Get the column names of the query result
@@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         return []
 
-
-    def get_messages(self)-> List[OnceConversation]:
+    def get_messages(self) -> List[OnceConversation]:
         cursor = self.connect.cursor()
-        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         context = cursor.fetchone()
         if context:
             return json.loads(context[0])
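Aside: every query in this file follows DuckDB's parameterized `execute(sql, params)` API, which the reformatting above makes easier to scan. A minimal self-contained sketch of that pattern (the in-memory database and dummy values here are illustrative, not part of the commit):

```python
import duckdb
import json

# hypothetical stand-in for the on-disk chat_history.db the real class opens
connect = duckdb.connect(":memory:")
connect.execute(
    "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
)
# "?" placeholders keep values out of the SQL string itself
connect.execute(
    "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
    ["conv-1", "", json.dumps([{"chat_order": 1}], ensure_ascii=False)],
)
row = connect.execute(
    "SELECT messages FROM chat_history where conv_uid=?", ["conv-1"]
).fetchone()
print(json.loads(row[0]))  # [{'chat_order': 1}]
```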

View File

@@ -2,17 +2,29 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
+from fastapi import (
+    APIRouter,
+    Request,
+    Body,
+    status,
+    HTTPException,
+    Response,
+    BackgroundTasks,
+)
 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
+from sse_starlette.sse import EventSourceResponse
 from typing import List
-from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
+from pilot.openapi.api_v1.api_view_model import (
+    Result,
+    ConversationVo,
+    MessageVo,
+    ChatSceneVo,
+)
 from pilot.configs.config import Config
 from pilot.openapi.knowledge.knowledge_service import KnowledgeService
 from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -103,7 +115,7 @@ async def dialogue_scenes():
 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-        chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     }
     if not chat.prompt_template.stream_out:
-        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
-                                 background=background_tasks)
+        return StreamingResponse(
+            no_stream_generator(chat),
+            headers=headers,
+            media_type="text/event-stream",
+            background=background_tasks,
+        )
     else:
-        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
-                                 background=background_tasks)
+        return StreamingResponse(
+            stream_generator(chat),
+            headers=headers,
+            media_type="text/plain",
+            background=background_tasks,
+        )
 
 
 def release_model_semaphore():
@@ -236,12 +256,15 @@ async def no_stream_generator(chat):
         msg = msg.replace("\n", "\\n")
         yield f"data: {msg}\n\n"
 
+
 async def stream_generator(chat):
     model_response = chat.stream_call()
     if not CFG.NEW_SERVER_MODE:
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 chat.current_message.add_ai_message(msg)
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"
@@ -249,7 +272,9 @@ async def stream_generator(chat):
     else:
         for chunk in model_response:
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 chat.current_message.add_ai_message(msg)
                 msg = msg.replace("\n", "\\n")
@@ -259,4 +284,6 @@ async def stream_generator(chat):
 def message2Vo(message: dict, order) -> MessageVo:
-    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
+    return MessageVo(
+        role=message["type"], context=message["data"]["content"], order=order
+    )
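Aside: both branches of `chat_completions` rely on the same FastAPI streaming pattern, a `StreamingResponse` wrapping an async generator that yields SSE-style `data:` frames. A hedged, self-contained sketch (the route and token list are made up for illustration):

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_llm_chunks():
    # each yielded string is one SSE frame; newlines inside the payload
    # must be escaped so the frame stays a single "data:" line
    for token in ["Hello", " ", "world"]:
        yield f"data: {token}\n\n"


@app.get("/demo/stream")
async def demo_stream():
    return StreamingResponse(fake_llm_chunks(), media_type="text/event-stream")
```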

View File

@@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
 @router.post("/knowledge/{space_name}/document/upload")
-async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
+async def document_upload(
+    space_name: str,
+    doc_name: str = Form(...),
+    doc_type: str = Form(...),
+    doc_file: UploadFile = File(...),
+):
     print(f"/document/upload params: {space_name}")
     try:
         if doc_file:
-            with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
+            with NamedTemporaryFile(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
+            ) as tmp:
                 tmp.write(await doc_file.read())
                 tmp_path = tmp.name
-            shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
+            shutil.move(
+                tmp_path,
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             request = KnowledgeDocumentRequest()
             request.doc_name = doc_name
             request.doc_type = doc_type
-            request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
+            request.content = (
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             knowledge_space_service.create_knowledge_document(
                 space=space_name, request=request
             )
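Aside: the upload handler uses the usual temp-file-then-move pattern: write the payload to a `NamedTemporaryFile` created with `delete=False`, then `shutil.move` it to its final name. A standalone sketch under assumed paths (`upload_root` and the payload are placeholders):

```python
import os
import shutil
from tempfile import NamedTemporaryFile

upload_root = "/tmp/knowledge_upload"  # hypothetical KNOWLEDGE_UPLOAD_ROOT_PATH
os.makedirs(upload_root, exist_ok=True)

# delete=False keeps the file on disk after the with-block closes it
with NamedTemporaryFile(dir=upload_root, delete=False) as tmp:
    tmp.write(b"document body")
    tmp_path = tmp.name

# the move is cheap and atomic when source and target share a filesystem
shutil.move(tmp_path, os.path.join(upload_root, "example.txt"))
```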

View File

@@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
 )
 from enum import Enum
-from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
+from pilot.openapi.knowledge.request.knowledge_response import (
+    ChunkQueryResponse,
+    DocumentQueryResponse,
+)
 
 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
 
 class ChunkQueryResponse(BaseModel):
     """data: data"""
+
     data: List = None
     """total: total size"""
     total: int = None
@@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel):
 class DocumentQueryResponse(BaseModel):
     """data: data"""
+
     data: List = None
-
     """total: total size"""
     total: int = None
     """page: current page"""
     page: int = None

View File

@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1:], start=i + 1):
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i: j + 1]
+        return s[i : j + 1]
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
-                .replace("\n", " ")
-                .replace("\\n", " ")
-                .replace("\\", " ")
+            .replace("\n", " ")
+            .replace("\\n", " ")
+            .replace("\\", " ")
         )
         return cleaned_output
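Aside: `__extract_json` is a small brace-matching scanner: start at the first `{`, track nesting depth, and slice once the matching `}` closes it. The same logic as a standalone function (renamed here; the driver call is our example, not the commit's):

```python
def extract_json(s: str) -> str:
    i = s.index("{")
    count = 1  # current nesting depth, i.e. the number of '{' not yet closed
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # check that the matching '}' was found
    return s[i : j + 1]


print(extract_json('noise {"a": {"b": 1}} trailing'))  # {"a": {"b": 1}}
```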

View File

@@ -60,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True
 
     def __init__(
-            self,
-            chat_mode,
-            chat_session_id,
-            current_user_input,
+        self,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -102,7 +102,9 @@ class BaseChat(ABC):
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.current_message.start_date = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
 
         # TODO
         self.current_message.tokens = 0
@@ -168,11 +170,18 @@ class BaseChat(ABC):
             print("[TEST: output]:", rsp_str)
 
             ### output parse
-            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
-                                                                                            self.prompt_template.sep)
+            ai_response_text = (
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    rsp_str, self.prompt_template.sep
+                )
+            )
             ### model result deal
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
             result = self.do_action(prompt_define_response)
 
             if hasattr(prompt_define_response, "thoughts"):
@@ -232,7 +241,9 @@ class BaseChat(ABC):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
         for system_conv in system_convs:
-            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            system_text += (
+                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            )
         return system_text
 
     def __load_user_message(self):
@@ -246,13 +257,16 @@ class BaseChat(ABC):
         example_text = ""
         if self.prompt_template.example_selector:
             for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv['messages']:
-                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                for round_message in round_conv["messages"]:
+                    if not round_message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         example_text += (
-                            round_message['type']
+                            round_message["type"]
                             + ":"
-                            + round_message['data']['content']
+                            + round_message["data"]["content"]
                             + self.prompt_template.sep
                         )
         return example_text
@@ -264,37 +278,46 @@ class BaseChat(ABC):
                 f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
             )
             if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]['messages']:
-                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                for first_message in self.history_message[0]["messages"]:
+                    if not first_message["type"] in [
+                        ViewMessage.type,
+                        SystemMessage.type,
+                    ]:
                         history_text += (
-                            first_message['type']
+                            first_message["type"]
                             + ":"
-                            + first_message['data']['content']
+                            + first_message["data"]["content"]
                             + self.prompt_template.sep
                         )
 
                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
-                    for round_message in round_conv['messages']:
-                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            SystemMessage.type,
+                            ViewMessage.type,
+                        ]:
                             history_text += (
-                                round_message['type']
+                                round_message["type"]
                                 + ":"
-                                + round_message['data']['content']
+                                + round_message["data"]["content"]
                                 + self.prompt_template.sep
                             )
 
         else:
             ### use all history
             for conversation in self.history_message:
-                for message in conversation['messages']:
+                for message in conversation["messages"]:
                     ### history messages do not carry prompt and view info
-                    if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    if not message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         history_text += (
-                            message['type']
+                            message["type"]
                             + ":"
-                            + message['data']['content']
+                            + message["data"]["content"]
                             + self.prompt_template.sep
                         )
 
         return history_text

View File

@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
 
 ## Two examples are defined by default
 EXAMPLES = [
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
 ]
 
 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -98,9 +98,10 @@ class OnceConversation:
             system_convs.append(message)
         return system_convs
 
+
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if hasattr(once, 'start_date') and once.start_date:
+    if hasattr(once, "start_date") and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:

View File

@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
+from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+
 static_file_path = os.path.join(os.getcwd(), "server/static")
 
 CFG = Config()
@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")
 
 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
     )
@@ -55,14 +59,16 @@ app.add_middleware(
 )
 
 app.mount("/static", StaticFiles(directory=static_file_path), name="static")
 app.add_route("/test", "static/test.html")
+
+app.include_router(knowledge_router)
 app.include_router(api_v1)
 
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )
 
     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -75,4 +81,5 @@ if __name__ == "__main__":
     server_init(args)
     CFG.NEW_SERVER_MODE = True
     import uvicorn
+
     uvicorn.run(app, host="0.0.0.0", port=5000)

View File

@@ -9,7 +9,8 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+
+# from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
 global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_sequence_length"
+                self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_position_embeddings"
+                self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings
@@ -60,22 +61,22 @@ class ModelWorker:
     def get_queue_length(self):
         if (
-                model_semaphore is None
-                or model_semaphore._value is None
-                or model_semaphore._waiters is None
+            model_semaphore is None
+            or model_semaphore._value is None
+            or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
-                    CFG.LIMIT_MODEL_CONCURRENCY
-                    - model_semaphore._value
-                    + len(model_semaphore._waiters)
+                CFG.LIMIT_MODEL_CONCURRENCY
+                - model_semaphore._value
+                + len(model_semaphore._waiters)
             )
 
     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
-                    self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
+                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
 )
 
 app = FastAPI()
-from pilot.openapi.knowledge.knowledge_controller import router
-
-app.include_router(router)
-
-origins = [
-    "http://localhost",
-    "http://localhost:8000",
-    "http://localhost:3000",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# from pilot.openapi.knowledge.knowledge_controller import router
+#
+# app.include_router(router)
+#
+# origins = [
+#     "http://localhost",
+#     "http://localhost:8000",
+#     "http://localhost:3000",
+# ]
+#
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=origins,
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
 
 
 class PromptRequest(BaseModel):
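Aside: the CORS block commented out above is FastAPI's standard middleware wiring; if it were re-enabled it would look roughly like this (the origin is a placeholder, not a value from the commit):

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # hypothetical frontend origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
```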

View File

@@ -40,6 +40,7 @@ def server_init(args):
     cfg = Config()
     from pilot.server.llmserver import worker
+
     worker.start_check()
     load_native_plugins(cfg)
 
     signal.signal(signal.SIGINT, signal_handler)