Merge branch 'llm_proxy' into dev

# Conflicts:
#	pilot/server/webserver.py
yhjun1026
2023-05-30 17:21:40 +08:00
9 changed files with 119 additions and 8 deletions

@@ -84,7 +84,30 @@ class CodeGenChatAdapter(BaseChatAdpter):
        pass


class GuanacoChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco"""

    def match(self, model_path: str):
        return "guanaco" in model_path

    def get_generate_stream_func(self):
        # TODO
        pass


class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "proxyllm" in model_path

    def get_generate_stream_func(self):
        from pilot.model.proxy_llm import proxyllm_generate_stream

        return proxyllm_generate_stream


register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)

# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)

register_llm_model_chat_adapter(BaseChatAdpter)
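
Note on the adapter registration above: each adapter class is appended to a module-level registry and later selected by substring-matching the requested model path, with BaseChatAdpter registered last as the fallback. A minimal sketch of that lookup flow follows; the registry list and helper shown here are illustrative, not necessarily the repository's exact API.

# Minimal sketch of the registry/dispatch pattern used by the chat adapters.
# Assumes each adapter exposes match(model_path) and get_generate_stream_func().
llm_model_chat_adapters = []


def register_llm_model_chat_adapter(cls):
    # Store an instance so match() can be called directly during lookup.
    llm_model_chat_adapters.append(cls())


def get_llm_chat_adapter(model_path: str):
    # Return the first registered adapter whose match() accepts this path.
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")


# Example: a path containing "proxyllm" resolves to ProxyllmChatAdapter,
# whose get_generate_stream_func() returns proxyllm_generate_stream.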

@@ -37,11 +37,12 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
-        if hasattr(self.model.config, "max_sequence_length"):
-            self.context_len = self.model.config.max_sequence_length
-        elif hasattr(self.model.config, "max_position_embeddings"):
-            self.context_len = self.model.config.max_position_embeddings
+        if not isinstance(self.model, str):
+            if hasattr(self.model.config, "max_sequence_length"):
+                self.context_len = self.model.config.max_sequence_length
+            elif hasattr(self.model.config, "max_position_embeddings"):
+                self.context_len = self.model.config.max_position_embeddings
         else:
             self.context_len = 2048
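
Note on the added isinstance guard: when the proxy model is selected, the loader presumably returns a plain string placeholder rather than a locally loaded transformers model, so self.model has no .config to read and the context length must fall back to 2048. A minimal standalone sketch of that fallback logic, using a hypothetical helper name:

# Sketch of the context-length fallback introduced above (hypothetical helper).
# A proxied model arrives as a plain string, so it carries no config object.
DEFAULT_CONTEXT_LEN = 2048


def resolve_context_len(model):
    # Only real model objects expose a config with a maximum sequence length.
    if not isinstance(model, str):
        config = getattr(model, "config", None)
        if config is not None and hasattr(config, "max_sequence_length"):
            return config.max_sequence_length
        if config is not None and hasattr(config, "max_position_embeddings"):
            return config.max_position_embeddings
    return DEFAULT_CONTEXT_LEN


# resolve_context_len("proxyllm") -> 2048
# resolve_context_len(hf_model)   -> the model's configured maximum length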