feature: add model server proxy

Author: csunny
Date: 2023-05-30 17:16:29 +08:00
parent 4c60ab1ea2
commit ea334b172e
9 changed files with 106 additions and 8 deletions


@@ -84,7 +84,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
         pass


-class GuanacoAdapter(BaseChatAdpter):
+class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""

     def match(self, model_path: str):
@@ -94,7 +94,20 @@ class GuanacoAdapter(BaseChatAdpter):
         # TODO
         pass


+class ProxyllmChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "proxyllm" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.proxy_llm import proxyllm_generate_stream
+
+        return proxyllm_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
+
+# Proxy model for testing and development; it's cheap for us for now.
+register_llm_model_chat_adapter(ProxyllmChatAdapter)
 register_llm_model_chat_adapter(BaseChatAdpter)
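The new adapter only dispatches to proxyllm_generate_stream; pilot/model/proxy_llm.py itself is not shown in this commit view. Below is a minimal sketch of what such a streaming function could look like, assuming the proxy speaks an OpenAI-compatible chat API; PROXY_API_KEY, PROXY_SERVER_URL, and the exact signature are assumptions, not taken from this commit.

# Sketch only: the real pilot/model/proxy_llm.py is not part of this view.
import json

import requests

# Hypothetical configuration; the real code would read these from pilot's config.
PROXY_API_KEY = "sk-..."
PROXY_SERVER_URL = "https://api.openai.com/v1/chat/completions"


def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
    """Forward the prompt to the proxy endpoint and yield the text so far."""
    headers = {"Authorization": "Bearer " + PROXY_API_KEY}
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": params["prompt"]}],
        "temperature": float(params.get("temperature", 0.7)),
        "stream": True,
    }
    res = requests.post(PROXY_SERVER_URL, headers=headers, json=payload, stream=True)
    text = ""
    for raw in res.iter_lines():
        if not raw:
            continue
        line = raw.decode("utf-8")
        if line.startswith("data: "):
            line = line[len("data: "):]
        if line.strip() == "[DONE]":
            break
        delta = json.loads(line)["choices"][0]["delta"]
        text += delta.get("content", "")
        # Yield the accumulated output, matching the convention of the local
        # generate_stream functions that the chat adapters return.
        yield text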


@@ -37,11 +37,12 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
-        if hasattr(self.model.config, "max_sequence_length"):
-            self.context_len = self.model.config.max_sequence_length
-        elif hasattr(self.model.config, "max_position_embeddings"):
-            self.context_len = self.model.config.max_position_embeddings
+        if not isinstance(self.model, str):
+            if hasattr(self.model.config, "max_sequence_length"):
+                self.context_len = self.model.config.max_sequence_length
+            elif hasattr(self.model.config, "max_position_embeddings"):
+                self.context_len = self.model.config.max_position_embeddings
         else:
             self.context_len = 2048
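One note on the hunk above: because the else now binds to the outer isinstance check, a locally loaded model whose config exposes neither attribute would leave self.context_len unset. A minimal sketch of the resolution logic with the default folded in, assuming (as the guard implies) that the proxy path leaves self.model as a bare placeholder string with no .config attribute:

def resolve_context_len(model, default: int = 2048) -> int:
    # Proxy models are represented by a plain string, which has no .config,
    # so probing attributes on it would raise AttributeError without the guard.
    if not isinstance(model, str):
        config = model.config
        # Different HF model families name their context window differently.
        if hasattr(config, "max_sequence_length"):
            return config.max_sequence_length
        if hasattr(config, "max_position_embeddings"):
            return config.max_position_embeddings
    return default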


@@ -434,6 +434,7 @@ def http_bot(
     """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
     """
     if data["error_code"] == 0:
+        print("****************:", data)
         if "vicuna" in CFG.LLM_MODEL:
             output = data["text"][skip_echo_len:].strip()
         else:
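For context, http_bot consumes a streaming worker response: each chunk carries the full text generated so far, which is why the vicuna branch slices with skip_echo_len to drop the echoed prompt. A hedged sketch of that consumer loop follows; the NUL chunk delimiter and payload shape follow the common FastChat-style worker protocol and are assumptions here, not shown in this commit.

import json

import requests


def stream_chat(worker_url: str, payload: dict, skip_echo_len: int):
    # Each NUL-delimited chunk is a JSON object carrying error_code and the
    # cumulative text; on success, strip the echoed prompt from every chunk.
    response = requests.post(worker_url, json=payload, stream=True)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        if data["error_code"] == 0:
            yield data["text"][skip_echo_len:].strip()
        else:
            # Surface the worker's error message and stop streaming.
            yield data["text"]
            break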