Add support for falcon

zhanghy-sketchzh 2023-06-08 12:17:13 +08:00
parent b24d2fe0c0
commit b357fd9d0c
4 changed files with 91 additions and 15 deletions

pilot/configs/model_config.py

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
+    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "proxyllm": "proxyllm",
 }
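The new entry resolves the local checkpoint directory the same way the existing models do. A minimal sketch of how a caller picks it up (illustrative, not part of the commit):

# Illustrative only: resolving the new entry; MODEL_PATH is the local model root.
from pilot.configs.model_config import LLM_MODEL_CONFIG

falcon_path = LLM_MODEL_CONFIG["falcon-40b"]  # -> <MODEL_PATH>/falcon-40b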

pilot/model/adapter.py

@@ -1,10 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from functools import cache
 import torch
 from typing import List
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
+from functools import cache
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
 from pilot.configs.model_config import DEVICE
@@ -97,16 +97,27 @@ class GuanacoAdapter(BaseLLMAdaper):
         return model, tokenizer


-class GuanacoAdapter(BaseLLMAdaper):
-    """TODO Support guanaco"""
+class FalconAdapater(BaseLLMAdaper):
+    """Falcon adapter"""

     def match(self, model_path: str):
-        return "guanaco" in model_path
+        return "falcon" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+        # 4-bit NF4 quantization so the 40B checkpoint fits on a single GPU
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype="bfloat16",
+            bnb_4bit_use_double_quant=False,
+        )
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+            model_path,
+            quantization_config=bnb_config,  # supersedes a bare load_in_4bit=True
+            device_map={"": 0},
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
         )
         return model, tokenizer
@@ -166,6 +177,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(FalconAdapater)
 # TODO: vicuna is supported by default; other models still need testing and evaluation
 # just for test, remove this later
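The loader above quantizes the checkpoint to 4-bit NF4 at load time. For reference, a self-contained sketch of the same loading path (illustrative, not part of the commit; the model path is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "/path/to/falcon-40b"  # placeholder

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bfloat16
    bnb_4bit_use_double_quant=False,        # no nested quantization
)

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map={"": 0},      # place all weights on GPU 0
    trust_remote_code=True,  # Falcon shipped custom modeling code at the time
)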

pilot/model/llm_out/falcon_llm.py

@@ -0,0 +1,54 @@
+import torch
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+
+def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    # Stream tokens back as they are generated instead of waiting for the
+    # full completion.
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=stopping_criteria,
+    )
+
+    # Run generation on a worker thread; the streamer yields on this one.
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
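Call pattern for the generator, as a hedged usage sketch (the params keys follow what the function reads; model and tokenizer come from the adapter's loader):

# Illustrative only: each yield is the accumulated output so far.
params = {"prompt": "What is Apache Spark?", "stop": "###"}
for partial_output in falcon_generate_output(model, tokenizer, params, device="cuda"):
    print(partial_output)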

pilot/server/chat_adapter.py

@@ -3,7 +3,6 @@
 from functools import cache
 from typing import List
-from pilot.model.llm_out.vicuna_base_llm import generate_stream

@@ -95,6 +94,16 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return guanaco_generate_stream


+class FalconChatAdapter(BaseChatAdpter):
+    """Model chat adapter for Falcon"""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.falcon_llm import falcon_generate_output
+
+        return falcon_generate_output
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):

@@ -109,7 +118,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
+register_llm_model_chat_adapter(FalconChatAdapter)

 # Proxy model for testing and development; it's cheap for us right now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
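Both registries dispatch on a substring match against the model path, so any path containing "falcon" now routes to the new adapters. A minimal sketch of the lookup (illustrative; llm_model_chat_adapters stands in for whatever list register_llm_model_chat_adapter appends to):

# Illustrative only: substring-based adapter dispatch.
def get_chat_adapter(model_path: str):
    for adapter_cls in llm_model_chat_adapters:  # assumed registry list
        adapter = adapter_cls()
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No chat adapter found for {model_path}")

# e.g. a path like <MODEL_PATH>/falcon-40b matches FalconChatAdapter above.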