From b357fd9d0c4b42e576bfcc686470691a1aacf0a9 Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Thu, 8 Jun 2023 12:17:13 +0800 Subject: [PATCH] Add support for falcon --- pilot/configs/model_config.py | 1 + pilot/model/adapter.py | 38 ++++++++++++++-------- pilot/model/llm_out/falcon_llm.py | 54 +++++++++++++++++++++++++++++++ pilot/server/chat_adapter.py | 13 ++++++-- 4 files changed, 91 insertions(+), 15 deletions(-) create mode 100644 pilot/model/llm_out/falcon_llm.py diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 759245864..adfc62f1a 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = { "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), + "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "proxyllm": "proxyllm", } diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 05c55fa74..f8c65af77 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from functools import cache + +import torch from typing import List - -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer - +from functools import cache +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig from pilot.configs.model_config import DEVICE @@ -95,19 +95,30 @@ class GuanacoAdapter(BaseLLMAdaper): model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs ) return model, tokenizer - - -class GuanacoAdapter(BaseLLMAdaper): - """TODO Support guanaco""" + + +class FalconAdapater(BaseLLMAdaper): + """falcon Adapter""" def match(self, model_path: str): - return "guanaco" in model_path + return "falcon" in model_path - def loader(self, model_path: str, from_pretrained_kwargs: dict): - tokenizer = LlamaTokenizer.from_pretrained(model_path) + def loader(self, model_path: str, from_pretrained_kwagrs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_use_double_quant=False, + ) model = AutoModelForCausalLM.from_pretrained( - model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs - ) + model_path, + #load_in_4bit=True, #quantize + quantization_config=bnb_config, + device_map={"": 0}, + trust_remote_code=True, + **from_pretrained_kwagrs + ) return model, tokenizer @@ -166,6 +177,7 @@ class ProxyllmAdapter(BaseLLMAdaper): register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) +register_llm_model_adapters(FalconAdapater) # TODO Default support vicuna, other model need to tests and Evaluate # just for test, remove this later diff --git a/pilot/model/llm_out/falcon_llm.py b/pilot/model/llm_out/falcon_llm.py new file mode 100644 index 000000000..f4cb53eff --- /dev/null +++ b/pilot/model/llm_out/falcon_llm.py @@ -0,0 +1,54 @@ +import torch +import copy +from threading import Thread +from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria + + +def falcon_generate_output(model, tokenizer, params, device, context_len=2048): + """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py""" + tokenizer.bos_token_id = 1 + print(params) + stop = params.get("stop", "###") + prompt = params["prompt"] + query = prompt + print("Query Message: ", query) + + input_ids = tokenizer(query, return_tensors="pt").input_ids + input_ids = input_ids.to(model.device) + + streamer = TextIteratorStreamer( + tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True + ) + + tokenizer.bos_token_id = 1 + stop_token_ids = [0] + + class StopOnTokens(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + for stop_id in stop_token_ids: + if input_ids[0][-1] == stop_id: + return True + return False + + stop = StopOnTokens() + + generate_kwargs = dict( + input_ids=input_ids, + max_new_tokens=512, + temperature=1.0, + do_sample=True, + top_k=1, + streamer=streamer, + repetition_penalty=1.7, + stopping_criteria=StoppingCriteriaList([stop]), + ) + + t = Thread(target=model.generate, kwargs=generate_kwargs) + t.start() + + out = "" + for new_text in streamer: + out += new_text + yield out diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index d4ab8ae09..1db3beee7 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -3,7 +3,6 @@ from functools import cache from typing import List - from pilot.model.llm_out.vicuna_base_llm import generate_stream @@ -95,7 +94,17 @@ class GuanacoChatAdapter(BaseChatAdpter): return guanaco_generate_stream +class FalconChatAdapter(BaseChatAdpter): + """Model chat adapter for Guanaco""" + def match(self, model_path: str): + return "falcon" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.falcon_llm import falcon_generate_output + + return falcon_generate_output + class ProxyllmChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "proxyllm" in model_path @@ -109,7 +118,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) - +register_llm_model_adapters(FalconChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter)