Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-08 03:44:14 +00:00)

feat:multi-llm select

commit 63ad842daa, parent 1a2bf96767
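This change threads a user-selected model name through the chat API: params_load and ConversationVo gain a model_name field, get_chat_instance forwards it to the chat factory inside a single chat_param dict, and a new /v1/model/types endpoint lists the healthy LLM workers registered with the model controller.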
@@ -215,8 +215,8 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):


 @router.post("/v1/chat/mode/params/file/load")
-async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
-    print(f"params_load: {conv_uid},{chat_mode}")
+async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
+    print(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
             ## file save
@@ -235,7 +235,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
         )
         ## chat prepare
         dialogue = ConversationVo(
-            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
+            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
         )
         chat: BaseChat = get_chat_instance(dialogue)
         resp = await chat.prepare()
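For callers, model_name is now a required query parameter on the file-load endpoint. A minimal client sketch, assuming the router is mounted under /api on a local deployment (the URL, conversation id, file name, and model name are illustrative, not part of the commit):

import requests

with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/api/v1/chat/mode/params/file/load",
        params={
            "conv_uid": "conv-123",          # existing conversation id
            "chat_mode": "chat_knowledge",   # chat scene that consumes the file
            "model_name": "vicuna-13b",      # newly required: which LLM to use
        },
        files={"doc_file": f},
    )
    print(resp.json())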
@@ -287,15 +287,17 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:

     chat_param = {
         "chat_session_id": dialogue.conv_uid,
-        "user_input": dialogue.user_input,
+        "current_user_input": dialogue.user_input,
         "select_param": dialogue.select_param,
+        "model_name": dialogue.model_name
     }
-    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
     return chat


 @router.post("/v1/chat/prepare")
 async def chat_prepare(dialogue: ConversationVo = Body()):
+    # dialogue.model_name = CFG.LLM_MODEL
     logger.info(f"chat_prepare:{dialogue}")
     ## check conv_uid
     chat: BaseChat = get_chat_instance(dialogue)
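Note the calling-convention change: CHAT_FACTORY.get_implementation used to receive the fields as separate keyword arguments and now receives one chat_param dict, which presumably means every BaseChat implementation is updated to accept a single chat_param argument. A hedged sketch of what a matching constructor could look like (any handling beyond the keys shown in the diff is illustrative, not taken from the commit):

class BaseChat:
    def __init__(self, chat_param: dict):
        # Unpack the single dict that get_chat_instance now forwards.
        self.chat_session_id = chat_param["chat_session_id"]
        self.current_user_input = chat_param["current_user_input"]
        self.select_param = chat_param.get("select_param")
        self.llm_model = chat_param.get("model_name")  # per-request model selection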
@@ -307,7 +309,8 @@ async def chat_prepare(dialogue: ConversationVo = Body()):

 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
-    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    # dialogue.model_name = CFG.LLM_MODEL
+    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
     chat: BaseChat = get_chat_instance(dialogue)
     # background_tasks = BackgroundTasks()
     # background_tasks.add_task(release_model_semaphore)
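On the completions side the model now travels in the ConversationVo body rather than falling back to a server-wide default (the commented-out dialogue.model_name = CFG.LLM_MODEL line marks the old behavior). A minimal request sketch, again with an assumed local URL and illustrative values:

import requests

payload = {
    "conv_uid": "conv-123",
    "chat_mode": "chat_normal",
    "user_input": "Summarize this quarter's sales.",
    "model_name": "vicuna-13b",  # model picked by the user in the UI
}
resp = requests.post(
    "http://localhost:5000/api/v1/chat/completions",
    json=payload,
    stream=True,  # the endpoint returns a streaming response
)
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode(), end="")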
@@ -332,6 +335,28 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )


+@router.get("/v1/model/types")
+async def model_types():
+    print(f"/controller/model/types")
+    try:
+        import requests
+
+        response = requests.get(
+            f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true"
+        )
+        types = set()
+        if response.status_code == 200:
+            models = json.loads(response.text)
+            for model in models:
+                worker_type = model["model_name"].split("@")[1]
+                if worker_type == "llm":
+                    types.add(model["model_name"].split("@")[0])
+        return Result.succ(list(types))
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"controller model types error {e}")
+
+
 async def no_stream_generator(chat):
     msg = await chat.nostream_call()
     msg = msg.replace("\n", "\\n")
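The new endpoint relies on the controller registering each worker under a name of the form "<model_name>@<worker_type>" and keeps only the llm entries. A standalone sketch of that filtering step (the sample registry payload is made up for illustration):

import json

# Hypothetical /api/controller/models response body.
body = '[{"model_name": "vicuna-13b@llm"}, {"model_name": "chatglm2-6b@llm"}, {"model_name": "text2vec@embedding"}]'

types = set()
for model in json.loads(body):
    name, worker_type = model["model_name"].split("@")
    if worker_type == "llm":
        types.add(name)

print(sorted(types))  # ['chatglm2-6b', 'vicuna-13b']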