Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-24 11:03:19 +00:00)

Commit bb9081e00f: Add quantize_qlora support for falcon
Parent: b357fd9d0c
@@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
-
+QUANTIZE_QLORA=True
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
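The new QUANTIZE_QLORA entry above is an environment flag read as plain text. A minimal sketch of how such a flag can be consumed outside DB-GPT, assuming the template is copied to a .env file and loaded with python-dotenv (the project's own config layer may load it differently):

# Sketch only: reading the new flag. The use of python-dotenv here is an
# assumption for illustration, not part of the commit.
import os
from dotenv import load_dotenv

load_dotenv()  # pull KEY=VALUE pairs from a local .env file into os.environ

# Environment values are strings, so the flag is compared as text.
quantize_qlora = os.getenv("QUANTIZE_QLORA", "False") == "True"
print(f"4-bit QLoRA loading enabled: {quantize_qlora}")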
@@ -42,7 +42,7 @@ LLM_MODEL_CONFIG = {
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
-
+QLORA = os.getenv("QUANTIZE_QLORA") == "True"
 
 VECTOR_SEARCH_TOP_K = 10
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
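Because os.getenv returns a string (or None when the variable is unset), the comparison above enables QLORA only for the exact value "True". A short illustration of that behavior:

# Illustration only: any value other than the exact string "True"
# (unset, "true", "1", "False") leaves QLORA disabled.
import os

for value in (None, "True", "true", "1", "False"):
    if value is None:
        os.environ.pop("QUANTIZE_QLORA", None)
    else:
        os.environ["QUANTIZE_QLORA"] = value
    qlora = os.getenv("QUANTIZE_QLORA") == "True"
    print(f"QUANTIZE_QLORA={value!r} -> QLORA={qlora}")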
@@ -7,6 +7,7 @@ from functools import cache
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
 from pilot.configs.model_config import DEVICE
 
+bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
 
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
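The one-line bnb_config added above configures 4-bit NF4 loading through bitsandbytes via transformers. An equivalent, expanded form with the meaning of each option spelled out (same values, behavior unchanged):

# Equivalent to the module-level bnb_config in the diff, written out for readability.
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                  # store weights in 4-bit precision via bitsandbytes
    bnb_4bit_quant_type="nf4",          # NormalFloat4, the quantization type used by QLoRA
    bnb_4bit_compute_dtype="bfloat16",  # run matmuls in bfloat16 rather than float32
    bnb_4bit_use_double_quant=False,    # skip the second quantization pass over the quant constants
)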
@@ -105,20 +106,22 @@ class FalconAdapater(BaseLLMAdaper):
 
     def loader(self, model_path: str, from_pretrained_kwagrs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype="bfloat16",
-            bnb_4bit_use_double_quant=False,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            #load_in_4bit=True, #quantize
-            quantization_config=bnb_config,
-            device_map={"": 0},
-            trust_remote_code=True,
-            **from_pretrained_kwagrs
-        )
+        if QLORA == True:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                load_in_4bit=True, #quantize
+                quantization_config=bnb_config,
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwagrs
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                device_map={"": 0},
+                **from_pretrained_kwagrs
+            )
         return model, tokenizer
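With this change, FalconAdapater.loader branches on the QLORA flag: 4-bit NF4 loading when QUANTIZE_QLORA=True, otherwise a plain from_pretrained call on a single GPU. A minimal usage sketch, assuming the adapter lives in pilot.model.adapter, can be instantiated directly (in DB-GPT it is normally resolved through the adapter registry), and that a local Falcon checkpoint is available:

# Hypothetical usage sketch; the checkpoint path and direct instantiation are
# illustrative assumptions, not part of the commit.
import os

# QLORA is evaluated at import time of the config module, so set the flag first.
os.environ["QUANTIZE_QLORA"] = "True"

from pilot.model.adapter import FalconAdapater

adapter = FalconAdapater()
model, tokenizer = adapter.loader(
    "/data/models/falcon-40b",  # local checkpoint directory (assumption)
    {},                         # extra from_pretrained kwargs, passed through unchanged
)
print(type(model).__name__, type(tokenizer).__name__)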