diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
index c8febba9b..a0a4dd514 100644
--- a/pilot/model/llm_out/proxy_llm.py
+++ b/pilot/model/llm_out/proxy_llm.py
@@ -2,21 +2,18 @@
 # -*- coding: utf-8 -*-
 import time
 
-from pilot.configs.config import Config
-
 from pilot.model.proxy.llms.chatgpt import chatgpt_generate_stream
 from pilot.model.proxy.llms.bard import bard_generate_stream
 from pilot.model.proxy.llms.claude import claude_generate_stream
 from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
 from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
 from pilot.model.proxy.llms.zhipu import zhipu_generate_stream
-
-# from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream
-
-CFG = Config()
+from pilot.model.proxy.llms.proxy_model import ProxyModel
 
 
-def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
+def proxyllm_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
     generator_mapping = {
         "proxyllm": chatgpt_generate_stream,
         "chatgpt_proxyllm": chatgpt_generate_stream,
@@ -27,10 +24,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
     }
-
-    default_error_message = f"{CFG.LLM_MODEL} LLM is not supported"
+    model_params = model.get_params()
+    model_name = model_params.model_name
+    default_error_message = f"{model_name} LLM is not supported"
     generator_function = generator_mapping.get(
-        CFG.LLM_MODEL, lambda: default_error_message
+        model_name, lambda: default_error_message
    )
 
     yield from generator_function(model, tokenizer, params, device, context_len)
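
For context, the change above drops the global `Config()` lookup and instead reads the model name from the `ProxyModel` instance via `get_params().model_name`, then dispatches to the matching `*_generate_stream` function. Below is a minimal, self-contained sketch of that dispatch pattern; the stub classes and the echo generator are hypothetical stand-ins (not the DB-GPT API), and the fallback lambda is adjusted to accept the call arguments so the sketch actually runs.

```python
from dataclasses import dataclass


@dataclass
class StubModelParams:
    # Stand-in for the params object returned by ProxyModel.get_params().
    model_name: str


class StubProxyModel:
    # Stand-in for ProxyModel: the model carries its own parameters,
    # so no global Config() is needed.
    def __init__(self, params: StubModelParams):
        self._params = params

    def get_params(self) -> StubModelParams:
        return self._params


def echo_generate_stream(model, tokenizer, params, device, context_len):
    # Stand-in for chatgpt_generate_stream and friends.
    yield f"echo: {params['prompt']}"


def generate_stream(model: StubProxyModel, tokenizer, params, device, context_len=2048):
    generator_mapping = {"chatgpt_proxyllm": echo_generate_stream}
    model_name = model.get_params().model_name
    default_error_message = f"{model_name} LLM is not supported"
    # Fallback returns an iterator so `yield from` works for unknown models.
    generator_function = generator_mapping.get(
        model_name, lambda *args: iter([default_error_message])
    )
    yield from generator_function(model, tokenizer, params, device, context_len)


if __name__ == "__main__":
    model = StubProxyModel(StubModelParams(model_name="chatgpt_proxyllm"))
    for chunk in generate_stream(model, None, {"prompt": "hi"}, "cpu"):
        print(chunk)
```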