Add support for falcon (#175)

Set QUANTIZE_QLORA=True in the .env file to load the falcon model with
4-bit (QLoRA-style) quantization.
Commit 8b3d7b0ba7 by magic.chen, 2023-06-08 16:32:35 +08:00 (committed by GitHub)
5 changed files with 94 additions and 23 deletions
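For orientation, a standalone sketch of what the new flag turns on, assuming the transformers and bitsandbytes packages are installed; the Hugging Face model id below is illustrative, while the repository itself resolves the local path through LLM_MODEL_CONFIG["falcon-40b"]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization config, mirroring bnb_config in the adapter diff below
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-40b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-40b",              # illustrative model id
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
)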

View File

@@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
LIMIT_MODEL_CONCURRENCY=5
MAX_POSITION_EMBEDDINGS=4096
QUANTIZE_QLORA=True
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
@@ -112,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
#*******************************************************************#
# ** SUMMARY_CONFIG
#*******************************************************************#
SUMMARY_CONFIG=FAST
SUMMARY_CONFIG=FAST

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
    "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
    "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
    "proxyllm": "proxyllm",
}
@@ -42,7 +43,7 @@ LLM_MODEL_CONFIG = {
# Load model config
ISLOAD_8BIT = True
ISDEBUG = False
QLORA = os.getenv("QUANTIZE_QLORA") == "True"
VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")

View File

@@ -1,12 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import cache
import torch
from typing import List
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from functools import cache
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
from pilot.configs.model_config import DEVICE
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
class BaseLLMAdaper:
"""The Base class for multi model, in our project.
@@ -95,19 +96,32 @@ class GuanacoAdapter(BaseLLMAdaper):
            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
        )
        return model, tokenizer
class GuanacoAdapter(BaseLLMAdaper):
    """TODO Support guanaco"""
class FalconAdapater(BaseLLMAdaper):
    """falcon Adapter"""
    def match(self, model_path: str):
        return "guanaco" in model_path
        return "falcon" in model_path
    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
        )
    def loader(self, model_path: str, from_pretrained_kwagrs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        if QLORA:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_4bit=True,  # quantize
                quantization_config=bnb_config,
                device_map={"": 0},
                trust_remote_code=True,
                **from_pretrained_kwagrs
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                device_map={"": 0},
                **from_pretrained_kwagrs
            )
        return model, tokenizer
@@ -180,6 +194,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
# TODO: vicuna is supported by default; other models still need testing and evaluation
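For orientation, a minimal sketch of the dispatch pattern that the registration above feeds into; the helper name and the llm_model_adapters registry name are assumptions, not necessarily what the repository uses:

def pick_llm_adapter(model_path: str):
    # Assumed registry: a list of adapter classes appended to by
    # register_llm_model_adapters(...) above.
    for adapter_cls in llm_model_adapters:
        adapter = adapter_cls()
        if adapter.match(model_path):  # e.g. FalconAdapater matches any path containing "falcon"
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")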

View File

@@ -0,0 +1,54 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    tokenizer.bos_token_id = 1
    print(params)
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    query = prompt
    print("Query Message: ", query)
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    tokenizer.bos_token_id = 1
    stop_token_ids = [0]
    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False
    stop = StopOnTokens()
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        temperature=1.0,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    out = ""
    for new_text in streamer:
        out += new_text
        yield out
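A hedged usage sketch, not part of the commit: model and tokenizer are assumed to come from FalconAdapater.loader, and the prompt is illustrative. Each value yielded is the cumulative text generated so far, so a consumer keeps only the latest one.

params = {"prompt": "### Human: Explain QLoRA in one sentence.\n### Assistant:", "stop": "###"}
output = ""
for output in falcon_generate_output(model, tokenizer, params, device="cuda"):
    pass  # each yield replaces the previous partial result
print(output)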

View File

@@ -3,7 +3,6 @@
from functools import cache
from typing import List
from pilot.model.llm_out.vicuna_base_llm import generate_stream
@@ -95,17 +94,18 @@ class GuanacoChatAdapter(BaseChatAdpter):
        return guanaco_generate_stream
class GorillaChatAdapter(BaseChatAdpter):
class FalconChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco"""
    def match(self, model_path: str):
        return "gorilla" in model_path
        return "falcon" in model_path
    def get_generate_stream_func(self):
        from pilot.model.llm_out.gorilla_llm import generate_stream
        return generate_stream
        from pilot.model.llm_out.falcon_llm import falcon_generate_output
        return falcon_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "proxyllm" in model_path
@@ -119,6 +119,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.