diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 9fe6207c1..6fd6143ff 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton): model.to(self.device) except ValueError: pass + except AttributeError: + pass if debug: print(model) diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 63d922672..8db61d09f 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -4,7 +4,7 @@ from functools import cache from typing import List -from pilot.model.inference import generate_stream +from pilot.model.llm_out.vicuna_base_llm import generate_stream class BaseChatAdpter: @@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter): return "chatglm" in model_path def get_generate_stream_func(self): - from pilot.model.chatglm_llm import chatglm_generate_stream + from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream return chatglm_generate_stream @@ -85,14 +85,14 @@ class CodeGenChatAdapter(BaseChatAdpter): class GuanacoChatAdapter(BaseChatAdpter): - """Model chat adapter for Guanaco """ - + """Model chat adapter for Guanaco""" + def match(self, model_path: str): return "guanaco" in model_path def get_generate_stream_func(self): - from pilot.model.guanaco_llm import guanaco_generate_stream - + from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream + return guanaco_generate_stream @@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): return "proxyllm" in model_path def get_generate_stream_func(self): - from pilot.model.proxy_llm import proxyllm_generate_stream + from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream return proxyllm_generate_stream