WEB API independent

tuyang.yhj 2023-06-28 11:34:40 +08:00
parent 1d3d6cb23c
commit b2d2828b4e
15 changed files with 201 additions and 171 deletions

View File

@@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
+        self.NEW_SERVER_MODE = False
+
         # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
         self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
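
Note: a toy, self-contained sketch (not DB-GPT code) of what this flag is for. `NEW_SERVER_MODE` defaults to False; the new standalone API server added later in this commit flips it to True before starting uvicorn, and chat code branches on it to use the in-process model worker instead of the HTTP model server. All names in the snippet besides the flag are invented for illustration.

    # Illustrative sketch of the singleton-config flag pattern used by this commit.
    class Config:
        _instance = None

        def __new__(cls):
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.NEW_SERVER_MODE = False
            return cls._instance


    def call_model(payload: dict) -> str:
        cfg = Config()
        if cfg.NEW_SERVER_MODE:
            return f"in-process worker handles {payload}"
        return f"HTTP model server handles {payload}"


    Config().NEW_SERVER_MODE = True       # what the new server does before uvicorn.run
    print(call_model({"prompt": "hi"}))   # -> in-process worker handles ...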

View File

@@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
         content = cursor.fetchone()
         if content:
-            return cursor.fetchone()[0]
+            return content[0]
         else:
             return None

     def messages(self) -> List[OnceConversation]:
@@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
                 [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
         else:
             cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
+                [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
         cursor.commit()
         self.connect.commit()
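
Note: a quick, self-contained illustration of the bug the first hunk fixes. DB-API style cursors consume a row on each fetchone() call, so calling it a second time returns the next row (or None), not the row just tested. The table and values below are made up for the demo.

    import duckdb

    con = duckdb.connect()  # in-memory database for the demo
    con.execute("CREATE TABLE chat_history(conv_uid VARCHAR, messages VARCHAR)")
    con.execute("INSERT INTO chat_history VALUES ('uid-1', '[]')")

    cur = con.execute("SELECT messages FROM chat_history WHERE conv_uid = ?", ["uid-1"])
    content = cur.fetchone()
    if content:
        print(content[0])      # fixed code: reuse the row already fetched -> '[]'
        print(cur.fetchone())  # what the old code called again: None, so [0] would raise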

View File

@@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         elif "ai:" in message:
             history.append(
                 {
-                    "role": "ai",
+                    "role": "assistant",
                     "content": message.split("ai:")[1],
                 }
             )
@@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     for m in temp_his:
         if m["role"] == "user":
             last_user_input = m
+            break

     if last_user_input:
         history.remove(last_user_input)
         history.append(last_user_input)
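
Note: a self-contained sketch of what the added break changes, assuming temp_his is the history iterated newest-first as in the surrounding proxy-LLM code. Without the break the loop keeps overwriting last_user_input and ends up with the oldest user turn; with it, the most recent user turn wins and is moved to the end of the prompt. The sample messages are invented.

    history = [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "latest question"},
    ]

    temp_his = history[::-1]      # newest first
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break                 # without this, the oldest user turn would win

    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    print(history[-1]["content"])  # -> "latest question"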

View File

@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
 knowledge_service = KnowledgeService()
+model_semaphore = None
+global_counter = 0


 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
@@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str):
 @router.post('/v1/chat/completions')
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
     if not ChatScene.is_valid_mode(dialogue.chat_mode):
         raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
@@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
     if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return chat.nostream_call()
     else:
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
-
-
-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(release_model_semaphore)
+        return StreamingResponse(stream_generator(chat), background=background_tasks)
+
+
+def release_model_semaphore():
+    model_semaphore.release()


 def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
     chat.memory.append(chat.current_message)


 def message2Vo(message: dict, order) -> MessageVo:
-    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
     return MessageVo(role=message['type'], context=message['data']['content'], order=order)
-
-
-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get('/v1/db/types', response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get('/v1/db/list', response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get('/v1/knowledge/list')
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/add')
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/delete')
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/types')
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/detail')
-async def knowledge_detail():
-    return ["test1", "test2"]

View File

@@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
         return code

     def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        if b"\0" in chunk:
+            chunk = chunk.replace(b"\0", b"")
         data = json.loads(chunk.decode())

         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
@@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
     def parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
-        respObj = json.loads(text)
-        xx = respObj["response"]
-        xx = xx.strip(b"\x00".decode())
-        respObj_ex = json.loads(xx)
+        text = text.strip(b"\x00".decode())
+        respObj_ex = json.loads(text)
         if respObj_ex["error_code"] == 0:
             all_text = respObj_ex["text"]
             ### Parse the returned text to extract the AI reply part
@@ -123,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
+        for j, c in enumerate(s[i + 1:], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -131,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i : j + 1]
+        return s[i: j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -148,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -159,9 +158,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
             .replace("\n", " ")
             .replace("\\n", " ")
             .replace("\\", " ")
         )
         return cleaned_output

View File

@@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC):
         else:
             return self.__few_shot_context(count)

+    def __examples_text(self, used_examples):
+
     def __few_shot_context(self, count: int = 2) -> List[List]:
         """
         Use 2 or more examples, default 2

View File

@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
     ViewMessage,
 )
 from pilot.configs.config import Config
+from pilot.server.llmserver import worker

 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -59,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
         self,
         chat_mode,
         chat_session_id,
         current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -95,7 +96,6 @@ class BaseChat(ABC):
     def generate_input_values(self):
         pass

     def do_action(self, prompt_response):
         return prompt_response
@@ -138,24 +138,17 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            show_info = ""
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=120,
-            )
-            return response
-
-            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-            self.current_message.add_ai_message(show_info)
+            if not CFG.NEW_SERVER_MODE:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=120,
+                )
+                return response
+            else:
+                return worker.generate_stream_gate(payload)
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild" + str(e))
@@ -170,39 +163,28 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            ### call the non-streaming model-service endpoint
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate"),
-                headers=headers,
-                json=payload,
-                timeout=120,
-            )
+            rsp_str = ""
+            if not CFG.NEW_SERVER_MODE:
+                ### call the non-streaming model-service endpoint
+                rsp_str = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+            else:
+                ### TODO: the no-stream mode still needs to be made independent
+                output = worker.generate_stream_gate(payload)
+                for rsp in output:
+                    rsp_str = str(rsp, "utf-8")
+                    print("[TEST: output]:", rsp_str)

             ### output parse
-            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(
-                    response, self.prompt_template.sep
-                )
-            )
-            # ### MOCK
-            # ai_response_text = """{
-            #     "thoughts": "The users table and the tran_order table can be joined, grouped by city with order counts, and displayed as a bar chart.",
-            #     "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables and combine them with a LEFT JOIN, then group by city and count the orders per city. A bar chart shows each city's order count at a glance and makes comparison easy.",
-            #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders per city, and generated a bar chart.",
-            #     "command": {
-            #         "name": "histogram-executor",
-            #         "args": {
-            #             "title": "Bar chart of orders by city",
-            #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-            #         }
-            #     }
-            # }"""
+            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
+                                                                                             self.prompt_template.sep)
+            ### model result deal
             self.current_message.add_ai_message(ai_response_text)
             prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
             result = self.do_action(prompt_define_response)
             if hasattr(prompt_define_response, "thoughts"):
@@ -248,41 +230,42 @@ class BaseChat(ABC):
         ### handle history messages
         if len(self.history_message) > self.chat_retention_rounds:
             ### Merge the first round and the last n rounds of the history into the conversation record; view messages shown to the user must be filtered out of the context prompt
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
+            for first_message in self.history_message[0]['messages']:
+                if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
                     text += (
-                        first_message.type
+                        first_message['type']
                         + ":"
-                        + first_message.content
+                        + first_message['data']['content']
                         + self.prompt_template.sep
                     )

             index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += (
-                        last_message.type
-                        + ":"
-                        + last_message.content
-                        + self.prompt_template.sep
-                    )
+            for round_conv in self.history_message[-index:]:
+                for round_message in round_conv['messages']:
+                    if not isinstance(round_message, ViewMessage):
+                        text += (
+                            round_message['type']
+                            + ":"
+                            + round_message['data']['content']
+                            + self.prompt_template.sep
+                        )

         else:
             ### concatenate the history directly
             for conversation in self.history_message:
-                for message in conversation.messages:
+                for message in conversation['messages']:
                     if not isinstance(message, ViewMessage):
                         text += (
-                            message.type
+                            message['type']
                             + ":"
-                            + message.content
+                            + message['data']['content']
                             + self.prompt_template.sep
                         )

         ### current conversation
         for now_message in self.current_message.messages:
             text += (
                 now_message.type + ":" + now_message.content + self.prompt_template.sep
             )
         return text
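
Note: a small, self-contained sketch of the serialized message layout the rewritten history code assumes: self.history_message now holds plain dicts (one per past round) whose "messages" list contains {"type": ..., "data": {"content": ...}} entries, the same shape message2Vo reads in api_v1. The sample values are invented.

    # Invented sample data in the dict shape the new code indexes into.
    history_message = [
        {
            "messages": [
                {"type": "human", "data": {"content": "show me the top 10 users"}},
                {"type": "ai", "data": {"content": "SELECT * FROM users LIMIT 10"}},
                {"type": "view", "data": {"content": "rendered result table"}},
            ]
        }
    ]

    sep = "\n"
    text = ""
    for conversation in history_message:
        for message in conversation["messages"]:
            if message["type"] not in ("view", "system"):  # view/system turns stay out of the prompt
                text += message["type"] + ":" + message["data"]["content"] + sep

    print(text)
    # human:show me the top 10 users
    # ai:SELECT * FROM users LIMIT 10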

View File

@@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle

 CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
+PROMPT_SCENE_DEFINE = None

 _DEFAULT_TEMPLATE = """

View File

@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector

 ## Two examples are defined by default
 EXAMPLES = [
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]

 example = ExampleSelector(examples=EXAMPLES, use_example=True)

View File

@@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example import example

 CFG = Config()

-# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
 PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

 _DEFAULT_TEMPLATE = """
 Goals:
     {input}

View File

@@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle

 from pilot.scene.chat_normal.out_parser import NormalChatOutputParser

-PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
-The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+PROMPT_SCENE_DEFINE = None

 CFG = Config()

View File

@@ -95,7 +95,7 @@ class OnceConversation:
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if once.start_date:
+    if hasattr(once, 'start_date') and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:

View File

@@ -0,0 +1,74 @@
+import traceback
+import os
+import shutil
+import argparse
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.utils import build_logger
+from pilot.server.webserver_base import server_init
+
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+CFG = Config()
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+origins = ["*"]
+
+# add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+# app.mount("static", StaticFiles(directory="static"), name="static")
+
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+
+    # old version server config
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--share", default=False, action="store_true")
+
+    # init server config
+    args = parser.parse_args()
+    server_init(args)
+
+    CFG.NEW_SERVER_MODE = True
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=5000)

View File

@@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter

 CFG = Config()


 class ModelWorker:
     def __init__(self, model_path, model_name, device, num_gpus=1):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
-        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
         self.ml = ModelLoader(model_path=model_path)
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
@@ -42,11 +41,11 @@ class ModelWorker:
         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings
@@ -56,29 +55,32 @@ class ModelWorker:
         self.llm_chat_adapter = get_llm_chat_adapter(model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

+    def start_check(self):
+        print("LLM Model Loading Success")
+
     def get_queue_length(self):
         if (
             model_semaphore is None
             or model_semaphore._value is None
             or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
                 CFG.LIMIT_MODEL_CONCURRENCY
                 - model_semaphore._value
                 + len(model_semaphore._waiters)
             )

     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
                 # and opening it may affect the frontend output.
-                # print("output: ", output)
+                print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,
@@ -106,6 +108,7 @@ worker = ModelWorker(

 app = FastAPI()
 from pilot.openapi.knowledge.knowledge_controller import router

 app.include_router(router)

 origins = [
@@ -122,6 +125,7 @@ app.add_middleware(
     allow_headers=["*"]
 )


 class PromptRequest(BaseModel):
     prompt: str
     temperature: float
@@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest):
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
-        print("[TEST: output]:", rsp_str)
         response.append(rsp_str)

-    return {"response": rsp_str}
+    return rsp_str


 @app.post("/embedding")
@app.post("/embedding") @app.post("/embedding")

View File

@@ -39,6 +39,8 @@ def server_init(args):
     # init config
     cfg = Config()
+    from pilot.server.llmserver import worker
+    worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
     async_db_summery()