From b2d2828b4eb7e5eb343034fa54086285e75a5e53 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Wed, 28 Jun 2023 11:34:40 +0800 Subject: [PATCH] WEB API independent --- pilot/configs/config.py | 2 + pilot/memory/chat_history/duckdb_history.py | 4 +- pilot/model/llm_out/proxy_llm.py | 3 +- pilot/openapi/api_v1/api_v1.py | 89 +++++--------- pilot/out_parser/base.py | 23 ++-- pilot/prompts/example_base.py | 4 + pilot/scene/base_chat.py | 127 +++++++++----------- pilot/scene/chat_db/auto_execute/prompt.py | 2 +- pilot/scene/chat_execution/example.py | 4 +- pilot/scene/chat_execution/prompt.py | 2 - pilot/scene/chat_normal/prompt.py | 3 +- pilot/scene/message.py | 2 +- pilot/server/dbgpt_server.py | 74 ++++++++++++ pilot/server/llmserver.py | 31 ++--- pilot/server/webserver_base.py | 2 + 15 files changed, 201 insertions(+), 171 deletions(-) create mode 100644 pilot/server/dbgpt_server.py diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 94ff19e21..7bf6d97e6 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -17,6 +17,8 @@ class Config(metaclass=Singleton): def __init__(self) -> None: """Initialize the Config class""" + self.NEW_SERVER_MODE = False + # Gradio language version: en, zh self.LANGUAGE = os.getenv("LANGUAGE", "en") self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860)) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 42177be75..de80a5bc2 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) content = cursor.fetchone() if content: - return cursor.fetchone()[0] + return content[0] else: return None def messages(self) -> List[OnceConversation]: @@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) else: cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)]) + [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)]) cursor.commit() self.connect.commit() diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 4336d43e3..c4423a1a6 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) elif "ai:" in message: history.append( { - "role": "ai", + "role": "assistant", "content": message.split("ai:")[1], } ) @@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) for m in temp_his: if m["role"] == "user": last_user_input = m + break if last_user_input: history.remove(last_user_input) history.append(last_user_input) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 832c6785c..d11ba0976 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,7 +2,7 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response +from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse @@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory() logger = build_logger("api_v1", LOGDIR + "api_v1.log") knowledge_service = KnowledgeService() +model_semaphore = None +global_counter = 0 async def validation_exception_handler(request: Request, exc: RequestValidationError): message = "" @@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str): @router.post('/v1/chat/completions') async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") + global model_semaphore, global_counter + global_counter += 1 + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY) + await model_semaphore.acquire() if not ChatScene.is_valid_mode(dialogue.chat_mode): raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")) @@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()): chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) if not chat.prompt_template.stream_out: - return non_stream_response(chat) + return chat.nostream_call() else: - return StreamingResponse(stream_generator(chat), media_type="text/plain") - - -def stream_test(): - for message in ["Hello", "world", "how", "are", "you"]: - yield message - # yield json.dumps(Result.succ(message).__dict__).encode("utf-8") + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + return StreamingResponse(stream_generator(chat), background=background_tasks) +def release_model_semaphore(): + model_semaphore.release() def stream_generator(chat): model_response = chat.stream_call() - for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) - chat.current_message.add_ai_message(msg) - yield msg - # chat.current_message.add_ai_message(msg) - # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order) - # json_text = json.dumps(vo.__dict__) - # yield json_text.encode('utf-8') + if not CFG.NEW_SERVER_MODE: + for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + yield msg + else: + for chunk in model_response: + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + yield msg chat.memory.append(chat.current_message) def message2Vo(message: dict, order) -> MessageVo: - # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 - return MessageVo(role=message['type'], context=message['data']['content'], order=order) - - -def non_stream_response(chat): - logger.info("not stream out, wait model response!") - return chat.nostream_call() - - -@router.get('/v1/db/types', response_model=Result[str]) -async def db_types(): - return Result.succ(["mysql", "duckdb"]) - - -@router.get('/v1/db/list', response_model=Result[str]) -async def db_list(): - db = CFG.local_db - dbs = db.get_database_list() - return Result.succ(dbs) - - -@router.get('/v1/knowledge/list') -async def knowledge_list(): - return ["test1", "test2"] - - -@router.post('/v1/knowledge/add') -async def knowledge_add(): - return ["test1", "test2"] - - -@router.post('/v1/knowledge/delete') -async def knowledge_delete(): - return ["test1", "test2"] - - -@router.get('/v1/knowledge/types') -async def knowledge_types(): - return ["test1", "test2"] - - -@router.get('/v1/knowledge/detail') -async def knowledge_detail(): - return ["test1", "test2"] + return MessageVo(role=message['type'], context=message['data']['content'], order=order) \ No newline at end of file diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 4476a68fc..ca308e92f 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -47,6 +47,8 @@ class BaseOutputParser(ABC): return code def parse_model_stream_resp_ex(self, chunk, skip_echo_len): + if b"\0" in chunk: + chunk = chunk.replace(b"\0", b"") data = json.loads(chunk.decode()) """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. @@ -95,11 +97,8 @@ class BaseOutputParser(ABC): def parse_model_nostream_resp(self, response, sep: str): text = response.text.strip() text = text.rstrip() - respObj = json.loads(text) - - xx = respObj["response"] - xx = xx.strip(b"\x00".decode()) - respObj_ex = json.loads(xx) + text = text.strip(b"\x00".decode()) + respObj_ex = json.loads(text) if respObj_ex["error_code"] == 0: all_text = respObj_ex["text"] ### 解析返回文本,获取AI回复部分 @@ -123,7 +122,7 @@ class BaseOutputParser(ABC): def __extract_json(slef, s): i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): + for j, c in enumerate(s[i + 1:], start=i + 1): if c == "}": count -= 1 elif c == "{": @@ -131,7 +130,7 @@ class BaseOutputParser(ABC): if count == 0: break assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] + return s[i: j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -148,9 +147,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] + cleaned_output = cleaned_output[len("```json"):] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] + cleaned_output = cleaned_output[len("```"):] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -159,9 +158,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py index 4d876aa51..3372c2b1d 100644 --- a/pilot/prompts/example_base.py +++ b/pilot/prompts/example_base.py @@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC): else: return self.__few_shot_context(count) + def __examples_text(self, used_examples): + + + def __few_shot_context(self, count: int = 2) -> List[List]: """ Use 2 or more examples, default 2 diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 042d6af59..560ce7039 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -39,6 +39,7 @@ from pilot.scene.base_message import ( ViewMessage, ) from pilot.configs.config import Config +from pilot.server.llmserver import worker logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") headers = {"User-Agent": "dbgpt Client"} @@ -59,10 +60,10 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -95,7 +96,6 @@ class BaseChat(ABC): def generate_input_values(self): pass - def do_action(self, prompt_response): return prompt_response @@ -138,24 +138,17 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - show_info = "" - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=120, - ) - return response - - # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len) - - # for resp_text_trunck in ai_response_text: - # show_info = resp_text_trunck - # yield resp_text_trunck + "▌" - - self.current_message.add_ai_message(show_info) - + if not CFG.NEW_SERVER_MODE: + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=120, + ) + return response + else: + return worker.generate_stream_gate(payload) except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" + str(e)) @@ -170,39 +163,28 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - ### 走非流式的模型服务接口 - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, - json=payload, - timeout=120, - ) + rsp_str = "" + if not CFG.NEW_SERVER_MODE: + ### 走非流式的模型服务接口 + rsp_str = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) + else: + ###TODO no stream mode need independent + output = worker.generate_stream_gate(payload) + for rsp in output: + rsp_str = str(rsp, "utf-8") + print("[TEST: output]:", rsp_str) ### output parse - ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp( - response, self.prompt_template.sep - ) - ) - - # ### MOCK - # ai_response_text = """{ - # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", - # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", - # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", - # "command": { - # "name": "histogram-executor", - # "args": { - # "title": "订单城市分布柱状图", - # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" - # } - # } - # }""" - + ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, + self.prompt_template.sep) + ### model result deal self.current_message.add_ai_message(ai_response_text) prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - - result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): @@ -248,41 +230,42 @@ class BaseChat(ABC): ### 处理历史信息 if len(self.history_message) > self.chat_retention_rounds: ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 - for first_message in self.history_message[0].messages: - if not isinstance(first_message, ViewMessage): + for first_message in self.history_message[0]['messages']: + if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep + first_message['type'] + + ":" + + first_message['data']['content'] + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 - for last_message in self.history_message[-index:].messages: - if not isinstance(last_message, ViewMessage): - text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep - ) + for round_conv in self.history_message[-index:]: + for round_message in round_conv['messages']: + if not isinstance(round_message, ViewMessage): + text += ( + round_message['type'] + + ":" + + round_message['data']['content'] + + self.prompt_template.sep + ) else: ### 直接历史记录拼接 for conversation in self.history_message: - for message in conversation.messages: + for message in conversation['messages']: if not isinstance(message, ViewMessage): text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep + message['type'] + + ":" + + message['data']['content'] + + self.prompt_template.sep ) ### current conversation for now_message in self.current_message.messages: text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep + now_message.type + ":" + now_message.content + self.prompt_template.sep ) return text diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 938860aaf..cc4878c70 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle CFG = Config() -PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" +PROMPT_SCENE_DEFINE = None _DEFAULT_TEMPLATE = """ diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index 6cd71b39c..0008ee9ec 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector ## Two examples are defined by default EXAMPLES = [ - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}], - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}] + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}], + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}] ] example = ExampleSelector(examples=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index b5bc38bb2..fef2a13b3 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example import example CFG = Config() -# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers." - _DEFAULT_TEMPLATE = """ Goals: {input} diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index 55514dffa..1dc455255 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle from pilot.scene.chat_normal.out_parser import NormalChatOutputParser -PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. - The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ +PROMPT_SCENE_DEFINE = None CFG = Config() diff --git a/pilot/scene/message.py b/pilot/scene/message.py index a2a894fe8..f2bd1fa56 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -95,7 +95,7 @@ class OnceConversation: def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" - if once.start_date: + if hasattr(once, 'start_date') and once.start_date: if isinstance(once.start_date, datetime): start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") else: diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py new file mode 100644 index 000000000..7032982d9 --- /dev/null +++ b/pilot/server/dbgpt_server.py @@ -0,0 +1,74 @@ +import traceback +import os +import shutil +import argparse +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +from pilot.configs.config import Config +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, +) +from pilot.utils import build_logger + +from pilot.server.webserver_base import server_init + +from fastapi import FastAPI, applications +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware + +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler + + +CFG = Config() +logger = build_logger("webserver", LOGDIR + "webserver.log") + + +def swagger_monkey_patch(*args, **kwargs): + return get_swagger_ui_html( + *args, **kwargs, + swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', + swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + ) + + +applications.get_swagger_ui_html = swagger_monkey_patch + +app = FastAPI() +origins = ["*"] + +# 添加跨域中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + +# app.mount("static", StaticFiles(directory="static"), name="static") +app.include_router(api_v1) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + + # old version server config + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--share", default=False, action="store_true") + + # init server config + args = parser.parse_args() + server_init(args) + CFG.NEW_SERVER_MODE = True + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 3030d1fdc..0fb4c5ae6 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter CFG = Config() - class ModelWorker: def __init__(self, model_path, model_name, device, num_gpus=1): if model_path.endswith("/"): model_path = model_path[:-1] self.model_name = model_name or model_path.split("/")[-1] self.device = device - + print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......") self.ml = ModelLoader(model_path=model_path) self.model, self.tokenizer = self.ml.loader( num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG @@ -42,11 +41,11 @@ class ModelWorker: if not isinstance(self.model, str): if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" + self.model.config, "max_sequence_length" ): self.context_len = self.model.config.max_sequence_length elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" + self.model.config, "max_position_embeddings" ): self.context_len = self.model.config.max_position_embeddings @@ -56,29 +55,32 @@ class ModelWorker: self.llm_chat_adapter = get_llm_chat_adapter(model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func() + def start_check(self): + print("LLM Model Loading Success!") + def get_queue_length(self): if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None ): return 0 else: ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, # and opening it may affect the frontend output. - # print("output: ", output) + print("output: ", output) ret = { "text": output, "error_code": 0, @@ -106,6 +108,7 @@ worker = ModelWorker( app = FastAPI() from pilot.openapi.knowledge.knowledge_controller import router + app.include_router(router) origins = [ @@ -122,6 +125,7 @@ app.add_middleware( allow_headers=["*"] ) + class PromptRequest(BaseModel): prompt: str temperature: float @@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest): for rsp in output: # rsp = rsp.decode("utf-8") rsp_str = str(rsp, "utf-8") - print("[TEST: output]:", rsp_str) response.append(rsp_str) - return {"response": rsp_str} + return rsp_str @app.post("/embedding") diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py index 0aa2ac3f9..c76984c37 100644 --- a/pilot/server/webserver_base.py +++ b/pilot/server/webserver_base.py @@ -39,6 +39,8 @@ def server_init(args): # init config cfg = Config() + from pilot.server.llmserver import worker + worker.start_check() load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) async_db_summery()