mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 00:28:00 +00:00
Add support for falcon (#175)
We can set QUANTIZE_QLORA=True in the .env file to enable falcon's quantization model
This commit is contained in:
commit
8b3d7b0ba7
@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
|
||||
MODEL_SERVER=http://127.0.0.1:8000
|
||||
LIMIT_MODEL_CONCURRENCY=5
|
||||
MAX_POSITION_EMBEDDINGS=4096
|
||||
|
||||
QUANTIZE_QLORA=True
|
||||
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
|
||||
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
|
||||
# SMART_LLM_MODEL=vicuna-13b
|
||||
@ -112,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
|
||||
#*******************************************************************#
|
||||
# ** SUMMARY_CONFIG
|
||||
#*******************************************************************#
|
||||
SUMMARY_CONFIG=FAST
|
||||
SUMMARY_CONFIG=FAST
|
||||
|
@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
|
||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
||||
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
||||
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
||||
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
||||
"proxyllm": "proxyllm",
|
||||
}
|
||||
@ -42,7 +43,7 @@ LLM_MODEL_CONFIG = {
|
||||
# Load model config
|
||||
ISLOAD_8BIT = True
|
||||
ISDEBUG = False
|
||||
|
||||
QLORA = os.getenv("QUANTIZE_QLORA") == "True"
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 10
|
||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
||||
|
@ -1,12 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from functools import cache
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from functools import cache
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
|
||||
from pilot.configs.model_config import DEVICE
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
|
||||
|
||||
class BaseLLMAdaper:
|
||||
"""The Base class for multi model, in our project.
|
||||
@ -95,19 +96,32 @@ class GuanacoAdapter(BaseLLMAdaper):
|
||||
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class GuanacoAdapter(BaseLLMAdaper):
|
||||
"""TODO Support guanaco"""
|
||||
|
||||
|
||||
class FalconAdapater(BaseLLMAdaper):
|
||||
"""falcon Adapter"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "guanaco" in model_path
|
||||
return "falcon" in model_path
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
|
||||
)
|
||||
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
if QLORA:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
load_in_4bit=True, #quantize
|
||||
quantization_config=bnb_config,
|
||||
device_map={"": 0},
|
||||
trust_remote_code=True,
|
||||
**from_pretrained_kwagrs
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
device_map={"": 0},
|
||||
**from_pretrained_kwagrs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@ -180,6 +194,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
|
||||
register_llm_model_adapters(VicunaLLMAdapater)
|
||||
register_llm_model_adapters(ChatGLMAdapater)
|
||||
register_llm_model_adapters(GuanacoAdapter)
|
||||
register_llm_model_adapters(FalconAdapater)
|
||||
register_llm_model_adapters(GorillaAdapter)
|
||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||
|
||||
|
54
pilot/model/llm_out/falcon_llm.py
Normal file
54
pilot/model/llm_out/falcon_llm.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import copy
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
|
||||
|
||||
def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
tokenizer.bos_token_id = 1
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
prompt = params["prompt"]
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
stop = StopOnTokens()
|
||||
|
||||
generate_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=512,
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
t.start()
|
||||
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield out
|
@ -3,7 +3,6 @@
|
||||
|
||||
from functools import cache
|
||||
from typing import List
|
||||
|
||||
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
||||
|
||||
|
||||
@ -95,17 +94,18 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
|
||||
return guanaco_generate_stream
|
||||
|
||||
class GorillaChatAdapter(BaseChatAdpter):
|
||||
|
||||
class FalconChatAdapter(BaseChatAdpter):
|
||||
"""Model chat adapter for Guanaco"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "gorilla" in model_path
|
||||
return "falcon" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.gorilla_llm import generate_stream
|
||||
|
||||
return generate_stream
|
||||
from pilot.model.llm_out.falcon_llm import falcon_generate_output
|
||||
|
||||
return falcon_generate_output
|
||||
|
||||
class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
def match(self, model_path: str):
|
||||
return "proxyllm" in model_path
|
||||
@ -119,6 +119,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
||||
register_llm_model_adapters(FalconChatAdapter)
|
||||
register_llm_model_chat_adapter(GorillaChatAdapter)
|
||||
|
||||
# Proxy model for test and develop, it's cheap for us now.
|
||||
|
Loading…
Reference in New Issue
Block a user