Add support for falcon (#175)

Set QUANTIZE_QLORA=True in the .env file to load the falcon model with 4-bit QLoRA
quantization.
commit 8b3d7b0ba7
magic.chen, 2023-06-08 16:32:35 +08:00, committed by GitHub
5 changed files with 94 additions and 23 deletions

View File

@@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
+QUANTIZE_QLORA=True
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
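This flag simply gates the QLoRA code path in the falcon loader further down. As a rough sketch of what it toggles (mirroring the bnb_config added in the adapter below; the snippet itself is illustrative and not part of the commit):

import os
from transformers import BitsAndBytesConfig

# QUANTIZE_QLORA=True switches the falcon loader to 4-bit NF4 quantization.
if os.getenv("QUANTIZE_QLORA") == "True":
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                  # store weights in 4-bit precision
        bnb_4bit_quant_type="nf4",          # NormalFloat4 quantization
        bnb_4bit_compute_dtype="bfloat16",  # run compute in bfloat16
        bnb_4bit_use_double_quant=False,    # no nested quantization of the scales
    )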

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
+    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
     "proxyllm": "proxyllm",
 }
@@ -42,7 +43,7 @@ LLM_MODEL_CONFIG = {
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
+QLORA = os.getenv("QUANTIZE_QLORA") == "True"
 VECTOR_SEARCH_TOP_K = 10
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
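The new "falcon-40b" entry expects the weights to live under MODEL_PATH. One way to put them there (illustrative only; the Hugging Face repo id tiiuae/falcon-40b and the use of huggingface_hub are assumptions, not part of this commit):

import os
from huggingface_hub import snapshot_download

MODEL_PATH = "./models"  # placeholder for the project's model directory

# Download the falcon checkpoint into the directory LLM_MODEL_CONFIG["falcon-40b"] points at.
snapshot_download(
    repo_id="tiiuae/falcon-40b",
    local_dir=os.path.join(MODEL_PATH, "falcon-40b"),
)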

View File

@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+from functools import cache
+import torch
 from typing import List
-from functools import cache
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
 from pilot.configs.model_config import DEVICE

+bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)

 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
@@ -97,16 +98,29 @@ class GuanacoAdapter(BaseLLMAdaper):
         return model, tokenizer

-class GuanacoAdapter(BaseLLMAdaper):
-    """TODO Support guanaco"""
+class FalconAdapter(BaseLLMAdaper):
+    """falcon Adapter"""

     def match(self, model_path: str):
-        return "guanaco" in model_path
+        return "falcon" in model_path

-    def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
-        )
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        if QLORA:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                load_in_4bit=True,  # quantize via bitsandbytes
+                quantization_config=bnb_config,
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwargs,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                device_map={"": 0},
+                **from_pretrained_kwargs,
+            )
         return model, tokenizer
@@ -180,6 +194,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(FalconAdapter)
 register_llm_model_adapters(GorillaAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
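A rough usage sketch of the new adapter (illustrative only; the model path and the empty kwargs dict are assumptions, and in the project the path would come from LLM_MODEL_CONFIG["falcon-40b"]):

# Illustrative: exercise FalconAdapter directly rather than through the registry.
adapter = FalconAdapter()
assert adapter.match("/data/models/falcon-40b")

# With QUANTIZE_QLORA=True the model is loaded 4-bit through bnb_config;
# otherwise it falls back to a plain trust_remote_code load.
model, tokenizer = adapter.loader("/data/models/falcon-40b", {})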

View File

@@ -0,0 +1,54 @@
+import torch
+import copy
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+
+def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
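A small sketch of how a caller could drive this generator (the prompt text and the printing logic are illustrative; each yield is the cumulative output so far, matching the other llm_out helpers):

# Illustrative driver for the streaming generator above.
params = {"prompt": "What is QLoRA quantization?", "stop": "###"}

printed = ""
for partial in falcon_generate_output(model, tokenizer, params, device="cuda"):
    print(partial[len(printed):], end="", flush=True)  # print only the newly streamed piece
    printed = partial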

View File

@@ -3,7 +3,6 @@
 from functools import cache
 from typing import List
 from pilot.model.llm_out.vicuna_base_llm import generate_stream
@@ -95,16 +94,17 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return guanaco_generate_stream

-class GorillaChatAdapter(BaseChatAdpter):
+class FalconChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""

     def match(self, model_path: str):
-        return "gorilla" in model_path
+        return "falcon" in model_path

     def get_generate_stream_func(self):
-        from pilot.model.llm_out.gorilla_llm import generate_stream
+        from pilot.model.llm_out.falcon_llm import falcon_generate_output

-        return generate_stream
+        return falcon_generate_output

 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
@@ -119,6 +119,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
+register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 # Proxy model for test and develop, it's cheap for us now.
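For completeness, a hedged sketch of the dispatch this registration enables (assuming the chat-adapter registry returns the first adapter whose match() succeeds):

# Illustrative: the falcon chat adapter hands back the streaming function above.
chat_adapter = FalconChatAdapter()
if chat_adapter.match("falcon-40b"):
    stream_func = chat_adapter.get_generate_stream_func()
    # stream_func is pilot.model.llm_out.falcon_llm.falcon_generate_output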