Add support for Falcon

zhanghy-sketchzh 2023-06-08 12:17:13 +08:00
parent b24d2fe0c0
commit b357fd9d0c
4 changed files with 91 additions and 15 deletions

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
+    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "proxyllm": "proxyllm",
 }

View File

@@ -1,10 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from functools import cache
 import torch
 from typing import List
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
+from functools import cache
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
 from pilot.configs.model_config import DEVICE
@@ -95,19 +95,30 @@ class GuanacoAdapter(BaseLLMAdaper):
             model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
         )
         return model, tokenizer


-class GuanacoAdapter(BaseLLMAdaper):
-    """TODO Support guanaco"""
+class FalconAdapter(BaseLLMAdaper):
+    """Falcon adapter"""

     def match(self, model_path: str):
-        return "guanaco" in model_path
+        return "falcon" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype="bfloat16",
+            bnb_4bit_use_double_quant=False,
+        )
+
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
-        )
+            model_path,
+            # load_in_4bit=True,  # quantize
+            quantization_config=bnb_config,
+            device_map={"": 0},
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
         return model, tokenizer
@@ -166,6 +177,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(FalconAdapter)

 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
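For readers unfamiliar with the bitsandbytes options used above, here is a minimal standalone sketch of the same 4-bit NF4 load that FalconAdapter.loader performs; the checkpoint path is a hypothetical placeholder and the snippet is not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "./models/falcon-40b"  # hypothetical local path, mirrors LLM_MODEL_CONFIG["falcon-40b"]

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4 bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
    bnb_4bit_use_double_quant=False,        # no nested quantization of the constants
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map={"": 0},       # place the whole model on GPU 0
    trust_remote_code=True,   # Falcon checkpoints ship custom modeling code
)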

View File

@@ -0,0 +1,54 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria


def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    tokenizer.bos_token_id = 1
    print(params)
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    query = prompt
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        # Stop generation as soon as the last generated token is one of stop_token_ids.
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stop_criteria = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        temperature=1.0,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
    )

    # Run generation in a background thread and stream partial output as it arrives.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    out = ""
    for new_text in streamer:
        out += new_text
        yield out
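A rough usage sketch of the generator above, assuming model and tokenizer come from FalconAdapter.loader; the prompt text is illustrative and the keys mirror what falcon_generate_output reads from params:

params = {"prompt": "What is a vector database?", "stop": "###"}

printed = ""
for partial in falcon_generate_output(model, tokenizer, params, device="cuda", context_len=2048):
    # each yield is the full text generated so far, so print only the new suffix
    print(partial[len(printed):], end="", flush=True)
    printed = partial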

View File

@@ -3,7 +3,6 @@
 from functools import cache
 from typing import List
 from pilot.model.llm_out.vicuna_base_llm import generate_stream
@@ -95,7 +94,17 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return guanaco_generate_stream


+class FalconChatAdapter(BaseChatAdpter):
+    """Model chat adapter for Falcon"""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.falcon_llm import falcon_generate_output
+
+        return falcon_generate_output
+
+
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -109,7 +118,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
+register_llm_model_chat_adapter(FalconChatAdapter)

 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
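The diff does not show how the registries are queried when a model is loaded; below is a plausible lookup, consistent with the match() methods above (the function and registry names here are illustrative assumptions, not the project's actual API):

def get_chat_adapter(model_path: str):
    # Illustrative only: return the first registered chat adapter whose match() hits.
    for adapter in llm_model_chat_adapters:  # hypothetical list filled by register_llm_model_chat_adapter
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no chat adapter registered for {model_path}")

# A model path containing "falcon" would resolve to FalconChatAdapter, whose
# get_generate_stream_func() returns falcon_generate_output for streaming.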