Add a Guanaco stream-generate method

This commit is contained in:
zhanghy-sketchzh 2023-06-04 21:35:25 +08:00
parent ff9179bc56
commit 9d6acfb9cd
4 changed files with 31 additions and 2 deletions

View File

@ -34,7 +34,8 @@ LLM_MODEL_CONFIG = {
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
}
# Load model config

View File

@ -82,6 +82,21 @@ class ChatGLMAdapater(BaseLLMAdaper):
)
return model, tokenizer
class GuanacoAdapter(BaseLLMAdaper):
    """Model-loading adapter for Guanaco checkpoints (selected by model path)."""

    def match(self, model_path: str):
        # This adapter handles any path containing "guanaco".
        return "guanaco" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # 4-bit quantized load, mapped entirely onto device 0.
        # NOTE(review): assumes bitsandbytes is available for load_in_4bit — confirm.
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_4bit=True,
            device_map={"": 0},
            **from_pretrained_kwargs,
        )
        return loaded_model, LlamaTokenizer.from_pretrained(model_path)
class CodeGenAdapter(BaseLLMAdaper):
    """Placeholder adapter for CodeGen models; no match/loader implemented yet."""

    pass
@ -122,6 +137,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
# Register model-loading adapters; presumably the first matching adapter wins,
# so the catch-all base adapter goes last — verify against register_llm_model_adapters.
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
# TODO: Vicuna is supported by default; other models still need testing and evaluation.
register_llm_model_adapters(BaseLLMAdaper)

View File

@ -3,7 +3,8 @@ from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
tokenizer.bos_token_id = 1
print(params)

View File

@ -59,6 +59,16 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return chatglm_generate_stream
class GuanacoChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco."""

    def match(self, model_path: str):
        # This adapter handles any path containing "guanaco".
        return "guanaco" in model_path

    def get_generate_stream_func(self):
        # Imported lazily so Guanaco dependencies load only when this adapter is used.
        from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output

        # BUG FIX: previously returned the undefined name `guanaco_generate_output`,
        # which raised NameError at call time; return the function actually imported.
        return guanaco_stream_generate_output
class CodeT5ChatAdapter(BaseChatAdpter):
@ -86,5 +96,6 @@ class CodeGenChatAdapter(BaseChatAdpter):
# Register chat adapters; presumably matched in registration order, so the
# generic BaseChatAdpter fallback goes last — verify matcher ordering.
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)