fix(model): resolve the proxy LLM model name from model params instead of the global CONFIG

FangYin Cheng 2023-09-08 19:46:50 +08:00
parent a8846c40aa
commit 4744d2161d


@@ -2,21 +2,18 @@
 # -*- coding: utf-8 -*-
 import time
-from pilot.configs.config import Config
 from pilot.model.proxy.llms.chatgpt import chatgpt_generate_stream
 from pilot.model.proxy.llms.bard import bard_generate_stream
 from pilot.model.proxy.llms.claude import claude_generate_stream
 from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
 from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
 from pilot.model.proxy.llms.zhipu import zhipu_generate_stream
+from pilot.model.proxy.llms.proxy_model import ProxyModel
 # from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream
-CFG = Config()
-def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
+def proxyllm_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
     generator_mapping = {
         "proxyllm": chatgpt_generate_stream,
         "chatgpt_proxyllm": chatgpt_generate_stream,
@@ -27,10 +24,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
     }
-    default_error_message = f"{CFG.LLM_MODEL} LLM is not supported"
+    model_params = model.get_params()
+    model_name = model_params.model_name
+    default_error_message = f"{model_name} LLM is not supported"
     generator_function = generator_mapping.get(
-        CFG.LLM_MODEL, lambda: default_error_message
+        model_name, lambda: default_error_message
     )
     yield from generator_function(model, tokenizer, params, device, context_len)
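
For readers unfamiliar with the dispatch pattern this commit touches, the sketch below shows the shape of the name-based lookup after the change: the backend generator is chosen by the model name carried in the model's own params rather than by a process-wide Config singleton. ModelParams, ProxyModel, and fake_chatgpt_stream here are simplified stand-ins written for illustration, not the real classes and functions under pilot.model.

from dataclasses import dataclass
from typing import Callable, Dict, Iterator


@dataclass
class ModelParams:
    # Stand-in for the params object returned by model.get_params().
    model_name: str


class ProxyModel:
    # Stand-in for pilot.model.proxy.llms.proxy_model.ProxyModel.
    def __init__(self, params: ModelParams):
        self._params = params

    def get_params(self) -> ModelParams:
        return self._params


def fake_chatgpt_stream(model, tokenizer, params, device, context_len) -> Iterator[str]:
    # Stand-in for chatgpt_generate_stream: yields response chunks.
    yield "chatgpt chunk"


def generate_stream(model: ProxyModel, tokenizer, params, device, context_len=2048):
    generator_mapping: Dict[str, Callable] = {
        "chatgpt_proxyllm": fake_chatgpt_stream,
    }
    # After this commit, the backend is resolved from the model's own params,
    # not from a global CFG.LLM_MODEL setting.
    model_name = model.get_params().model_name
    default_error_message = f"{model_name} LLM is not supported"
    generator_function = generator_mapping.get(
        model_name,
        # Fallback accepts the same call arguments and yields the error message.
        lambda *args, **kwargs: iter([default_error_message]),
    )
    yield from generator_function(model, tokenizer, params, device, context_len)


if __name__ == "__main__":
    model = ProxyModel(ModelParams(model_name="chatgpt_proxyllm"))
    for chunk in generate_stream(model, tokenizer=None, params={}, device="cpu"):
        print(chunk)

One small difference from the diff above: the fallback in this sketch takes the same call arguments and returns an iterator, since the zero-argument lambda in the actual code would raise a TypeError if an unsupported model name ever reached the yield from call.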