Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 06:30:02 +00:00)
feature: add model server proxy

@@ -84,7 +84,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
         pass
 
 
-class GuanacoAdapter(BaseChatAdpter):
+class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco """
 
     def match(self, model_path: str):
@@ -94,7 +94,20 @@ class GuanacoAdapter(BaseChatAdpter):
         # TODO
         pass
 
 
+class ProxyllmChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "proxyllm" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.proxy_llm import proxyllm_generate_stream
+        return proxyllm_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
+
+# Proxy model for test and develop, it's cheap for us now.
+register_llm_model_chat_adapter(ProxyllmChatAdapter)
+
 register_llm_model_chat_adapter(BaseChatAdpter)
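
The registry resolves adapters by first match, which is why ProxyllmChatAdapter is registered ahead of the catch-all BaseChatAdpter. A minimal, self-contained sketch of that dispatch; the class and registration names mirror the diff, but the registry internals below are assumptions, not taken from this commit:

    # Sketch of first-match adapter dispatch; internals are assumptions.
    class BaseChatAdpter:
        def match(self, model_path: str) -> bool:
            return True  # catch-all fallback, so it must be registered last

    class ProxyllmChatAdapter(BaseChatAdpter):
        def match(self, model_path: str) -> bool:
            return "proxyllm" in model_path

    llm_model_chat_adapters = []

    def register_llm_model_chat_adapter(cls) -> None:
        # Order matters: lookup returns the first adapter that matches.
        llm_model_chat_adapters.append(cls())

    def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
        for adapter in llm_model_chat_adapters:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"no chat adapter found for {model_path}")

    register_llm_model_chat_adapter(ProxyllmChatAdapter)
    register_llm_model_chat_adapter(BaseChatAdpter)  # fallback goes last

    assert type(get_llm_chat_adapter("proxyllm")) is ProxyllmChatAdapter

Registering BaseChatAdpter first would shadow every other adapter, since its match() accepts any path.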
@@ -37,11 +37,12 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
 
-        if hasattr(self.model.config, "max_sequence_length"):
-            self.context_len = self.model.config.max_sequence_length
-        elif hasattr(self.model.config, "max_position_embeddings"):
-            self.context_len = self.model.config.max_position_embeddings
+        if not isinstance(self.model, str):
+            if hasattr(self.model.config, "max_sequence_length"):
+                self.context_len = self.model.config.max_sequence_length
+            elif hasattr(self.model.config, "max_position_embeddings"):
+                self.context_len = self.model.config.max_position_embeddings
 
         else:
             self.context_len = 2048
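
The new isinstance guard exists because the proxy model is not a transformers model object: for proxyllm the loader hands back a plain string, which has no .config attribute to inspect, so the worker falls back to a 2048-token context. The same resolution logic as a standalone helper, a sketch with an assumed function name:

    # Sketch: the context-length resolution above, factored into a helper.
    # resolve_context_len is an assumed name, not part of the commit.
    def resolve_context_len(model, default: int = 2048) -> int:
        if isinstance(model, str):
            # Proxy "models" are bare strings with no config to inspect.
            return default
        if hasattr(model.config, "max_sequence_length"):
            return model.config.max_sequence_length
        if hasattr(model.config, "max_position_embeddings"):
            return model.config.max_position_embeddings
        return default

In ModelWorker this would read self.context_len = resolve_context_len(self.model).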
@@ -434,6 +434,7 @@ def http_bot(
     """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
     """
     if data["error_code"] == 0:
         print("****************:",data)
+        if "vicuna" in CFG.LLM_MODEL:
             output = data["text"][skip_echo_len:].strip()
         else:
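
The TODO above already names the intended fix: route output post-processing through the adapter layer instead of branching on CFG.LLM_MODEL inline. One hypothetical shape for that, sketched only from the TODO; none of these parser names appear in the commit, and only skip_echo_len and data["text"] come from the diff itself:

    # Hypothetical output-handler sketch implied by the TODO.
    class OutputParser:
        def parse(self, data: dict, skip_echo_len: int) -> str:
            return data["text"].strip()

    class VicunaOutputParser(OutputParser):
        def parse(self, data: dict, skip_echo_len: int) -> str:
            # Vicuna echoes the prompt back, so strip the first
            # skip_echo_len characters before returning the answer.
            return data["text"][skip_echo_len:].strip()

    def get_output_parser(model_name: str) -> OutputParser:
        return VicunaOutputParser() if "vicuna" in model_name else OutputParser()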