From ff9179bc5608e62cbf98a0a05accb91962f1cf1c Mon Sep 17 00:00:00 2001
From: zhanghy-sketchzh <1750410339@qq.com>
Date: Sun, 4 Jun 2023 19:37:56 +0800
Subject: [PATCH 1/2] test guanaco_llm

---
 pilot/model/guanaco_llm.py | 54 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 pilot/model/guanaco_llm.py

diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
new file mode 100644
index 000000000..627179dc5
--- /dev/null
+++ b/pilot/model/guanaco_llm.py
@@ -0,0 +1,54 @@
+import torch
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+
+def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield new_text
+    return out
\ No newline at end of file

From 9d6acfb9cd043c53ad78e33fda0966f5cfa363ee Mon Sep 17 00:00:00 2001
From: zhanghy-sketchzh <1750410339@qq.com>
Date: Sun, 4 Jun 2023 21:35:25 +0800
Subject: [PATCH 2/2] make a guanaco stream generate method

---
 pilot/configs/model_config.py                 |  3 ++-
 pilot/model/adapter.py                        | 16 ++++++++++++++++
 .../{guanaco_llm.py => guanaco_stream_llm.py} |  3 ++-
 pilot/server/chat_adapter.py                  | 11 +++++++++++
 4 files changed, 31 insertions(+), 2 deletions(-)
 rename pilot/model/{guanaco_llm.py => guanaco_stream_llm.py} (94%)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 09b428327..2a9d48346 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -34,7 +34,8 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
-    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
 }
 
 # Load model config
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 83fad3d5f..208d9f438 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -82,6 +82,21 @@ class ChatGLMAdapater(BaseLLMAdaper):
         )
         return model, tokenizer
 
+
+class GuanacoAdapter(BaseLLMAdaper):
+    """TODO Support guanaco"""
+
+    def match(self, model_path: str):
+        return "guanaco" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+        )
+        return model, tokenizer
+
+
 class CodeGenAdapter(BaseLLMAdaper):
     pass
 
@@ -122,6 +137,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
 
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
+register_llm_model_adapters(GuanacoAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 register_llm_model_adapters(BaseLLMAdaper)
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_stream_llm.py
similarity index 94%
rename from pilot/model/guanaco_llm.py
rename to pilot/model/guanaco_stream_llm.py
index 627179dc5..8f72699d1 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/guanaco_stream_llm.py
@@ -3,7 +3,8 @@ from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 
 
-def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
+
+def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
     tokenizer.bos_token_id = 1
     print(params)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index b7e102be3..400b27697 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -59,6 +59,16 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream
 
 
+class GuanacoChatAdapter(BaseChatAdpter):
+    """Model chat adapter for Guanaco"""
+
+    def match(self, model_path: str):
+        return "guanaco" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.guanaco_stream_llm import guanaco_stream_generate_output
+
+        return guanaco_stream_generate_output
 
 
 class CodeT5ChatAdapter(BaseChatAdpter):
@@ -86,5 +96,6 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
+register_llm_model_chat_adapter(GuanacoChatAdapter)
 
 register_llm_model_chat_adapter(BaseChatAdpter)
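
Note (not part of the patches above): a minimal sketch of how the new streaming entry point might be exercised once both commits are applied. The checkpoint path and prompt format below are assumptions for illustration, not values taken from the repository; 4-bit loading additionally requires bitsandbytes.

    # Illustrative sketch only: driving guanaco_stream_generate_output directly.
    # The local path and "### Human/### Assistant" prompt format are assumed.
    from transformers import AutoModelForCausalLM, LlamaTokenizer

    from pilot.model.guanaco_stream_llm import guanaco_stream_generate_output

    model_path = "/data/models/guanaco-33b-merged"  # hypothetical local checkpoint
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, load_in_4bit=True, device_map={"": 0}
    )

    params = {
        "prompt": "### Human: Hello, who are you?\n### Assistant:",  # assumed format
        "stop": "###",
    }

    # The function is a generator: model.generate runs on a background thread and
    # TextIteratorStreamer yields decoded text fragments as they are produced.
    for fragment in guanaco_stream_generate_output(model, tokenizer, params, device="cuda"):
        print(fragment, end="", flush=True)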