Mirror of https://github.com/csunny/DB-GPT.git
feat: multi-llm select

commit 63ad842daa (parent 1a2bf96767)
@@ -215,8 +215,8 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
 @router.post("/v1/chat/mode/params/file/load")
-async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
-    print(f"params_load: {conv_uid},{chat_mode}")
+async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
+    print(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
             ## file save
@@ -235,7 +235,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
         )
         ## chat prepare
         dialogue = ConversationVo(
-            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
+            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
         )
         chat: BaseChat = get_chat_instance(dialogue)
         resp = await chat.prepare()
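With `model_name` added to the endpoint signature, callers of `/v1/chat/mode/params/file/load` now pass the selected model alongside the uploaded file. A minimal client sketch, assuming the webserver runs locally and mounts this router under /api; the conversation id, chat mode, and model name below are illustrative, not values from this commit:

import requests

# Plain str parameters on a FastAPI file-upload endpoint are sent as query parameters.
params = {
    "conv_uid": "some-conv-uid",    # hypothetical conversation id
    "chat_mode": "chat_knowledge",  # hypothetical chat scene
    "model_name": "vicuna-13b",     # hypothetical model picked in the UI
}
with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:5000/api/v1/chat/mode/params/file/load",
        params=params,
        files={"doc_file": f},
    )
print(resp.json())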
@@ -287,15 +287,17 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
 
     chat_param = {
         "chat_session_id": dialogue.conv_uid,
-        "user_input": dialogue.user_input,
+        "current_user_input": dialogue.user_input,
         "select_param": dialogue.select_param,
+        "model_name": dialogue.model_name
     }
-    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
     return chat
 
 
 @router.post("/v1/chat/prepare")
 async def chat_prepare(dialogue: ConversationVo = Body()):
+    # dialogue.model_name = CFG.LLM_MODEL
     logger.info(f"chat_prepare:{dialogue}")
     ## check conv_uid
     chat: BaseChat = get_chat_instance(dialogue)
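The factory is now invoked with `**{"chat_param": chat_param}`, so each chat implementation receives a single `chat_param` dict rather than the previous individual keyword arguments. A hypothetical subclass sketch showing the constructor shape this implies; everything except the four dict keys is an assumption, not code from this commit:

from pilot.scene.base_chat import BaseChat  # repo's base class (import path assumed)

class ChatExample(BaseChat):
    """Hypothetical chat implementation, not part of this commit."""

    def __init__(self, chat_param: dict):
        # The keys mirror the dict built in get_chat_instance above.
        self.chat_session_id = chat_param["chat_session_id"]
        self.current_user_input = chat_param["current_user_input"]
        self.select_param = chat_param["select_param"]
        self.model_name = chat_param["model_name"]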
@@ -307,7 +309,8 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
 
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
-    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    # dialogue.model_name = CFG.LLM_MODEL
+    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
     chat: BaseChat = get_chat_instance(dialogue)
     # background_tasks = BackgroundTasks()
     # background_tasks.add_task(release_model_semaphore)
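The commented-out `# dialogue.model_name = CFG.LLM_MODEL` lines in chat_prepare and chat_completions hint at a server-side default. A one-line guard sketch for those endpoints, assuming ConversationVo.model_name may arrive empty; the commit itself leaves this disabled:

# Hypothetical fallback: use the configured default model when the client sends none.
if not dialogue.model_name:
    dialogue.model_name = CFG.LLM_MODEL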
@@ -332,6 +335,28 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )
 
 
+@router.get("/v1/model/types")
+async def model_types():
+    print(f"/controller/model/types")
+    try:
+        import requests
+
+        response = requests.get(
+            f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true"
+        )
+        types = set()
+        if response.status_code == 200:
+            models = json.loads(response.text)
+            for model in models:
+                worker_type = model["model_name"].split("@")[1]
+                if worker_type == "llm":
+                    types.add(model["model_name"].split("@")[0])
+        return Result.succ(list(types))
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"controller model types error {e}")
+
+
 async def no_stream_generator(chat):
     msg = await chat.nostream_call()
     msg = msg.replace("\n", "\\n")
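The new `/v1/model/types` endpoint asks the model controller for its healthy workers, splits each registered name on the `model_name@worker_type` convention, and returns the deduplicated names of `llm` workers for the frontend model selector. A client-side sketch, again assuming a local webserver with the router mounted under /api; the printed models are examples, not guaranteed output:

import requests

resp = requests.get("http://127.0.0.1:5000/api/v1/model/types")
# Result envelope with the deduplicated model list as its data payload,
# e.g. {"success": true, ..., "data": ["vicuna-13b", "chatglm-6b"]}
print(resp.json())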