diff --git a/.env.template b/.env.template
index 234b12738..2fb5ff649 100644
--- a/.env.template
+++ b/.env.template
@@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
-
+QUANTIZE_QLORA=True
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
@@ -112,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
 #*******************************************************************#
 # ** SUMMARY_CONFIG
 #*******************************************************************#
-SUMMARY_CONFIG=FAST
\ No newline at end of file
+SUMMARY_CONFIG=FAST
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index adfc62f1a..4f3e635d6 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -42,7 +42,7 @@ LLM_MODEL_CONFIG = {
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
-
+QLORA = os.getenv("QUANTIZE_QLORA") == "True"
 
 VECTOR_SEARCH_TOP_K = 10
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index f8c65af77..76eb51f26 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -7,6 +7,7 @@ from functools import cache
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
-from pilot.configs.model_config import DEVICE
+from pilot.configs.model_config import DEVICE, QLORA
+bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
 
 
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
@@ -105,19 +106,21 @@ class FalconAdapater(BaseLLMAdaper):
 
     def loader(self, model_path: str, from_pretrained_kwagrs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
 
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype="bfloat16",
-            bnb_4bit_use_double_quant=False,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            #load_in_4bit=True, #quantize
-            quantization_config=bnb_config,
-            device_map={"": 0},
-            trust_remote_code=True,
-            **from_pretrained_kwagrs
-        )
+        if QLORA:
+            # quantization_config already carries load_in_4bit=True, so the flag is not repeated here
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                quantization_config=bnb_config,
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwagrs
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                device_map={"": 0},
+                **from_pretrained_kwagrs
+            )
         return model, tokenizer
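
For reviewers, a minimal self-contained sketch of the loading path this patch wires up. It mirrors the `BitsAndBytesConfig` and the `QLORA` env toggle from the diff; the model path `tiiuae/falcon-7b` and the standalone env handling are illustrative assumptions, not part of the patch itself.

```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Mirrors the patch: setting QUANTIZE_QLORA=True in .env enables 4-bit loading.
QLORA = os.getenv("QUANTIZE_QLORA") == "True"

# Same settings as the module-level bnb_config in pilot/model/adapter.py:
# 4-bit NF4 weights, bfloat16 compute, no nested (double) quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model_path = "tiiuae/falcon-7b"  # illustrative; the adapter receives this as an argument

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if QLORA:
    # quantization_config implies load_in_4bit=True, so the bare flag is not
    # passed to from_pretrained a second time.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map={"": 0},  # place the whole model on GPU 0
        trust_remote_code=True,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": 0},
        trust_remote_code=True,
    )
```

Two details worth noting: `quantization_config` already carries `load_in_4bit=True`, and recent transformers releases raise a ValueError when both the config and the bare flag are passed together; and because `bnb_config` is built once at module import time, every adapter that opts into quantization shares the same config object.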