mirror of https://github.com/csunny/DB-GPT.git

WEB API independent

parent 1d3d6cb23c
commit b2d2828b4e
@@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""

+        self.NEW_SERVER_MODE = False
+
         # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
         self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
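
Note: NEW_SERVER_MODE starts out False and is only flipped to True by the new pilot/server/dbgpt_server.py entry point added later in this commit, so existing code paths keep calling the HTTP model server unless the standalone API server is running. A minimal sketch of the two new settings, using a cut-down stand-in for the Config singleton rather than the project's full class:

    import os

    class Config:
        def __init__(self):
            self.NEW_SERVER_MODE = False  # flipped to True by dbgpt_server.py after server_init()
            self.LANGUAGE = os.getenv("LANGUAGE", "en")
            self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

    cfg = Config()
    print(cfg.NEW_SERVER_MODE, cfg.LANGUAGE, cfg.WEB_SERVER_PORT)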
@@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
         content = cursor.fetchone()
         if content:
-            return cursor.fetchone()[0]
+            return content[0]
         else:
             return None

     def messages(self) -> List[OnceConversation]:
@@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
                            [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
         else:
             cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
+                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
         cursor.commit()
         self.connect.commit()

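
Note: the old branch called cursor.fetchone() a second time after the single matching row had already been consumed, so it returned None (and the [0] subscript then failed) even when history existed; the fix reuses the row already fetched. A small self-contained sketch of the corrected pattern against an in-memory DuckDB table, with the table layout assumed from the queries above:

    import duckdb

    conn = duckdb.connect()
    conn.execute("CREATE TABLE chat_history(conv_uid VARCHAR, user_name VARCHAR, messages VARCHAR)")
    conn.execute("INSERT INTO chat_history VALUES ('c1', '', '[]')")

    cursor = conn.cursor()
    cursor.execute("SELECT messages FROM chat_history where conv_uid=?", ["c1"])
    content = cursor.fetchone()                 # fetchone() consumes the only row
    result = content[0] if content else None    # reuse it instead of fetching again
    print(result)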
@@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         elif "ai:" in message:
             history.append(
                 {
-                    "role": "ai",
+                    "role": "assistant",
                     "content": message.split("ai:")[1],
                 }
             )
@@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     for m in temp_his:
         if m["role"] == "user":
             last_user_input = m
+            break
     if last_user_input:
         history.remove(last_user_input)
         history.append(last_user_input)
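
Note: the proxy path builds an OpenAI-style message list, and OpenAI-compatible chat APIs accept the role name "assistant", not "ai"; the added break stops at the first matching user turn instead of scanning the rest of temp_his. A hedged sketch of the message shape this code produces (the "human:" prefix for the user branch is an assumption; only the "ai:" branch appears in the hunk):

    def to_proxy_messages(raw_lines):
        history = []
        for message in raw_lines:
            if "human:" in message:   # assumed user-side prefix
                history.append({"role": "user", "content": message.split("human:")[1]})
            elif "ai:" in message:
                # "assistant" is the role an OpenAI-compatible endpoint expects
                history.append({"role": "assistant", "content": message.split("ai:")[1]})
        return history

    print(to_proxy_messages(["human: hi", "ai: hello there"]))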
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks

 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
 knowledge_service = KnowledgeService()

+model_semaphore = None
+global_counter = 0

 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
@@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str):
 @router.post('/v1/chat/completions')
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()

     if not ChatScene.is_valid_mode(dialogue.chat_mode):
         raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
@@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
     if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return chat.nostream_call()
     else:
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
-
-
-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(release_model_semaphore)
+        return StreamingResponse(stream_generator(chat), background=background_tasks)


+def release_model_semaphore():
+    model_semaphore.release()

 def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
     chat.memory.append(chat.current_message)


 def message2Vo(message: dict, order) -> MessageVo:
-    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
     return MessageVo(role=message['type'], context=message['data']['content'], order=order)
-
-
-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get('/v1/db/types', response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get('/v1/db/list', response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get('/v1/knowledge/list')
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/add')
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/delete')
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/types')
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/detail')
-async def knowledge_detail():
-    return ["test1", "test2"]
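
Note: /v1/chat/completions now throttles concurrent model calls with a module-level asyncio.Semaphore and, for streaming replies, releases the slot only after the whole stream has been sent, by attaching release_model_semaphore as a background task on the StreamingResponse. A minimal self-contained sketch of that pattern; the names and the limit value are placeholders, not the project's config:

    import asyncio
    from fastapi import FastAPI, BackgroundTasks
    from fastapi.responses import StreamingResponse

    app = FastAPI()
    LIMIT_MODEL_CONCURRENCY = 5        # stand-in for CFG.LIMIT_MODEL_CONCURRENCY
    model_semaphore = None

    def release_model_semaphore():
        model_semaphore.release()

    async def stream_tokens():
        for token in ["hello", " ", "world"]:
            yield token

    @app.post("/v1/chat/completions")
    async def chat_completions():
        global model_semaphore
        if model_semaphore is None:
            model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
        await model_semaphore.acquire()
        # tasks attached to the response run after the body has been fully streamed,
        # so the concurrency slot is held for the whole generation
        background_tasks = BackgroundTasks()
        background_tasks.add_task(release_model_semaphore)
        return StreamingResponse(stream_tokens(), background=background_tasks)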
@@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
             return code

     def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        if b"\0" in chunk:
+            chunk = chunk.replace(b"\0", b"")
         data = json.loads(chunk.decode())

         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
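
Note: chunks coming off the streaming endpoint are NUL-delimited and may carry trailing NUL bytes, so the parser now strips b"\0" before JSON-decoding each chunk. A small sketch of the chunk shape this assumes, with field names taken from the surrounding parser and llmserver code:

    import json

    chunk = b'{"text": "partial answer", "error_code": 0}\x00'
    if b"\0" in chunk:
        chunk = chunk.replace(b"\0", b"")
    data = json.loads(chunk.decode())
    print(data["text"])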
@@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
     def parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
-        respObj = json.loads(text)
-        xx = respObj["response"]
-        xx = xx.strip(b"\x00".decode())
-        respObj_ex = json.loads(xx)
+        text = text.strip(b"\x00".decode())
+        respObj_ex = json.loads(text)
         if respObj_ex["error_code"] == 0:
             all_text = respObj_ex["text"]
             ### parse the returned text and extract the AI reply part
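
Note: the non-stream parser no longer unwraps a {"response": ...} envelope; it strips NUL padding from the raw body and parses it directly, matching the llmserver change near the end of this commit where /generate returns rsp_str instead of {"response": rsp_str}. A sketch of the payload shape the new code expects (the sample value is made up):

    import json

    raw = '{"error_code": 0, "text": "hello"}\x00\x00'   # assumed: JSON body plus NUL padding
    text = raw.strip().strip(b"\x00".decode())
    respObj_ex = json.loads(text)
    if respObj_ex["error_code"] == 0:
        print(respObj_ex["text"])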
@@ -123,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. how many '{' are not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
+        for j, c in enumerate(s[i + 1:], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -131,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final closing '}' was found
-        return s[i : j + 1]
+        return s[i: j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -148,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -159,9 +158,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
             .replace("\n", " ")
             .replace("\\n", " ")
             .replace("\\", " ")
         )
         return cleaned_output

@@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC):
         else:
             return self.__few_shot_context(count)

+    def __examples_text(self, used_examples):
+
+
+
     def __few_shot_context(self, count: int = 2) -> List[List]:
         """
         Use 2 or more examples, default 2
@@ -39,6 +39,8 @@ from pilot.scene.base_message import (
     ViewMessage,
 )
 from pilot.configs.config import Config
+from pilot.server.llmserver import worker
+

 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -59,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
         self,
         chat_mode,
         chat_session_id,
         current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -95,7 +96,6 @@ class BaseChat(ABC):
     def generate_input_values(self):
         pass

-
     def do_action(self, prompt_response):
         return prompt_response

@@ -138,24 +138,17 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            show_info = ""
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=120,
-            )
-            return response
-            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
-
-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-
-            self.current_message.add_ai_message(show_info)
-
+            if not CFG.NEW_SERVER_MODE:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=120,
+                )
+                return response
+            else:
+                return worker.generate_stream_gate(payload)
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))
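
Note: in NEW_SERVER_MODE the chat layer calls the in-process worker.generate_stream_gate(payload) generator directly instead of POSTing to CFG.MODEL_SERVER, which is why stream_generator in api_v1 above needs two branches: a requests.Response exposes iter_lines(), while the in-process path just yields byte chunks. A hedged sketch of a consumer that copes with both return types (the helper name is made up):

    def iter_model_chunks(model_response):
        if hasattr(model_response, "iter_lines"):
            # HTTP path: a streaming requests.Response with NUL-delimited chunks
            yield from model_response.iter_lines(decode_unicode=False, delimiter=b"\0")
        else:
            # in-process path: generate_stream_gate() already yields byte chunks
            yield from model_response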
@@ -170,39 +163,28 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            ### call the non-streaming model service API
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate"),
-                headers=headers,
-                json=payload,
-                timeout=120,
-            )
+            rsp_str = ""
+            if not CFG.NEW_SERVER_MODE:
+                ### call the non-streaming model service API
+                rsp_str = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+            else:
+                ### TODO no stream mode need independent
+                output = worker.generate_stream_gate(payload)
+                for rsp in output:
+                    rsp_str = str(rsp, "utf-8")
+                    print("[TEST: output]:", rsp_str)

             ### output parse
-            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(
-                    response, self.prompt_template.sep
-                )
-            )
-
-            # ### MOCK
-            # ai_response_text = """{
-            #     "thoughts": "The users table and the tran_order table can be joined, grouped by city to count orders, and shown as a bar chart.",
-            #     "reasoning": "To analyze how users are distributed across cities, query the users table and the tran_order table and combine them with a LEFT JOIN. Group by city and count the orders for each city. A bar chart shows the order count of each city at a glance and makes comparison easy.",
-            #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders for each city, and generated a bar chart.",
-            #     "command": {
-            #         "name": "histogram-executor",
-            #         "args": {
-            #             "title": "Bar chart of order distribution by city",
-            #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-            #         }
-            #     }
-            # }"""
-
+            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
+                                                                                             self.prompt_template.sep)
+            ### model result deal
             self.current_message.add_ai_message(ai_response_text)
             prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)

             result = self.do_action(prompt_define_response)

             if hasattr(prompt_define_response, "thoughts"):
@@ -248,41 +230,42 @@ class BaseChat(ABC):
         ### handle history messages
         if len(self.history_message) > self.chat_retention_rounds:
             ### merge the first round and the last n rounds of history into the conversation record; messages shown only to the user must be filtered out when building the context prompt
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
+            for first_message in self.history_message[0]['messages']:
+                if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
                     text += (
-                        first_message.type
+                        first_message['type']
                         + ":"
-                        + first_message.content
+                        + first_message['data']['content']
                         + self.prompt_template.sep
                     )

             index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += (
-                        last_message.type
-                        + ":"
-                        + last_message.content
-                        + self.prompt_template.sep
-                    )
+            for round_conv in self.history_message[-index:]:
+                for round_message in round_conv['messages']:
+                    if not isinstance(round_message, ViewMessage):
+                        text += (
+                            round_message['type']
+                            + ":"
+                            + round_message['data']['content']
+                            + self.prompt_template.sep
+                        )

         else:
             ### concatenate the history directly
             for conversation in self.history_message:
-                for message in conversation.messages:
+                for message in conversation['messages']:
                     if not isinstance(message, ViewMessage):
                         text += (
-                            message.type
+                            message['type']
                             + ":"
-                            + message.content
+                            + message['data']['content']
                             + self.prompt_template.sep
                         )
         ### current conversation

         for now_message in self.current_message.messages:
             text += (
                 now_message.type + ":" + now_message.content + self.prompt_template.sep
             )

         return text
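
Note: history rounds are now read as plain dicts (presumably as they are serialized into chat_history by DuckdbHistoryMemory) rather than message objects, so entries are accessed as message['type'] and message['data']['content']. A sketch of the record shape this code indexes; the field names come from the diff, the sample values are made up:

    round_conv = {
        "messages": [
            {"type": "human", "data": {"content": "Show sales by city"}},
            {"type": "ai", "data": {"content": "SELECT city, COUNT(*) FROM orders GROUP BY city"}},
            {"type": "view", "data": {"content": "rendered chart"}},  # filtered out of the prompt
        ]
    }

    sep = "\n"
    text = ""
    for message in round_conv["messages"]:
        if message["type"] not in ("view", "system"):
            text += message["type"] + ":" + message["data"]["content"] + sep
    print(text)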
@@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle

 CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
+PROMPT_SCENE_DEFINE = None


 _DEFAULT_TEMPLATE = """
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector

 ## Two examples are defined by default
 EXAMPLES = [
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]

 example = ExampleSelector(examples=EXAMPLES, use_example=True)
@@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example

 CFG = Config()

-# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
 PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

-
 _DEFAULT_TEMPLATE = """
 Goals:
     {input}
@@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle

 from pilot.scene.chat_normal.out_parser import NormalChatOutputParser

-PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
-The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+PROMPT_SCENE_DEFINE = None

 CFG = Config()

@@ -95,7 +95,7 @@ class OnceConversation:

 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if once.start_date:
+    if hasattr(once, 'start_date') and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:
pilot/server/dbgpt_server.py  (new file, 74 lines)
@@ -0,0 +1,74 @@
+import traceback
+import os
+import shutil
+import argparse
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.utils import build_logger
+
+from pilot.server.webserver_base import server_init
+
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+
+CFG = Config()
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+origins = ["*"]
+
+# add cross-origin (CORS) middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+# app.mount("static", StaticFiles(directory="static"), name="static")
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+
+    # old version server config
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--share", default=False, action="store_true")
+
+    # init server config
+    args = parser.parse_args()
+    server_init(args)
+    CFG.NEW_SERVER_MODE = True
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5000)
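
Note: this entry point serves the web API on its own uvicorn instance (port 5000 is hard-coded here) with the LLM worker loaded in-process through server_init, so the HTTP API can run independently of the existing webserver. A hedged usage sketch; the payload field names are assumptions about ConversationVo, not confirmed by this diff:

    # launch the standalone API server:
    #   python pilot/server/dbgpt_server.py
    # then stream a completion from another process:
    import requests

    payload = {
        "chat_mode": "chat_normal",   # assumed ConversationVo fields
        "user_input": "hello",
        "conv_uid": "test-conv-1",
    }
    with requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None):
            print(chunk.decode("utf-8", errors="ignore"), end="")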
@@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
 CFG = Config()


 class ModelWorker:
     def __init__(self, model_path, model_name, device, num_gpus=1):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
-        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
         self.ml = ModelLoader(model_path=model_path)
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
@@ -42,11 +41,11 @@ class ModelWorker:

         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
                 self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings

@@ -56,29 +55,32 @@ class ModelWorker:
         self.llm_chat_adapter = get_llm_chat_adapter(model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

+    def start_check(self):
+        print("LLM Model Loading Success!")
+
     def get_queue_length(self):
         if (
             model_semaphore is None
             or model_semaphore._value is None
             or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
                 CFG.LIMIT_MODEL_CONCURRENCY
                 - model_semaphore._value
                 + len(model_semaphore._waiters)
             )

     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
                 # and opening it may affect the frontend output.
-                # print("output: ", output)
+                print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,
@@ -106,6 +108,7 @@ worker = ModelWorker(

 app = FastAPI()
 from pilot.openapi.knowledge.knowledge_controller import router
+
 app.include_router(router)

 origins = [
||||||
@ -122,6 +125,7 @@ app.add_middleware(
|
|||||||
allow_headers=["*"]
|
allow_headers=["*"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
|
class PromptRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
temperature: float
|
temperature: float
|
||||||
@@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest):
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
-        print("[TEST: output]:", rsp_str)
         response.append(rsp_str)

-    return {"response": rsp_str}
+    return rsp_str


 @app.post("/embedding")
@@ -39,6 +39,8 @@ def server_init(args):
     # init config
     cfg = Config()

+    from pilot.server.llmserver import worker
+    worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
     async_db_summery()
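
Note: importing pilot.server.llmserver inside server_init constructs the module-level ModelWorker (the model loads at import time, as the "worker = ModelWorker(" hunk context above shows), and the new start_check() call only confirms that the load finished; this is what lets dbgpt_server.py reuse the same in-process worker that NEW_SERVER_MODE relies on. A generic sketch of the eager-load pattern, using stand-in classes rather than the project's modules:

    class Worker:
        def __init__(self):
            print("loading model ...")        # heavy work happens at construction time

        def start_check(self):
            print("LLM Model Loading Success!")

    worker = Worker()                         # module-level instance, built on first import

    def server_init(args=None):
        worker.start_check()                  # touching it here forces the load before serving

    server_init()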