diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index c4458eaf7..d14d16808 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -146,6 +146,9 @@ class Config(metaclass=Singleton):
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
 
+        # QLoRA
+        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
+
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index c914195d8..9da5cbd04 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -6,8 +6,10 @@ from typing import List
 from functools import cache
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
 from pilot.configs.model_config import DEVICE
+from pilot.configs.config import Config
 
 bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+CFG = Config()
 
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
@@ -106,7 +108,8 @@ class FalconAdapater(BaseLLMAdaper):
 
     def loader(self, model_path: str, from_pretrained_kwagrs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-        if QLORA:
+
+        if CFG.QLoRA:
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 load_in_4bit=True,  #quantize
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 909023f07..c91987579 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -61,6 +61,8 @@ class BaseOutputParser(ABC):
 
             # stream out output
             output = data["text"][11:].replace("", "").strip()
+
+            # TODO gorilla and falcon output
         else:
             output = data["text"].strip()
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index a311312a2..b5c7128e7 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -116,10 +116,21 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return proxyllm_generate_stream
 
 
+class GorillaChatAdapter(BaseChatAdpter):
+
+    def match(self, model_path: str):
+        return "gorilla" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gorilla_llm import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
-register_llm_model_adapters(FalconChatAdapter)
+register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
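
Review note (not part of the diff): os.getenv("QUANTIZE_QLORA", "True") returns a
string, so CFG.QLoRA holds "True" or "False" rather than a bool, and the
"if CFG.QLoRA:" check in FalconAdapater takes the 4-bit branch even when
QUANTIZE_QLORA=False, since any non-empty string is truthy. A minimal sketch of
an explicit boolean parse, assuming the same env-var name; env_flag is a
hypothetical helper, not code from this PR:

    import os

    # Hypothetical helper (illustration only): map common truthy strings to a
    # real bool so that QUANTIZE_QLORA=False actually disables the 4-bit path.
    def env_flag(name: str, default: bool = True) -> bool:
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    # Inside Config.__init__ the assignment would then read:
    # self.QLoRA = env_flag("QUANTIZE_QLORA", True)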
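
The adapter.py change defines a module-level bnb_config, but the visible part of
FalconAdapater.loader only passes load_in_4bit=True; whether bnb_config is threaded
through to from_pretrained is outside this hunk. For reference, a hedged sketch of
passing the same 4-bit settings explicitly via quantization_config (the model id
below is an example, not taken from the diff):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Same 4-bit settings as the diff's bnb_config; torch.bfloat16 is the dtype
    # that the string "bfloat16" resolves to inside BitsAndBytesConfig.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "tiiuae/falcon-7b",  # example model id, not from the diff
        quantization_config=bnb_config,
        trust_remote_code=True,
        device_map="auto",
    )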
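
The chat_adapter.py hunk also fixes a mis-registration: FalconChatAdapter had been
passed to register_llm_model_adapters (presumably the model-side registry in
pilot/model/adapter.py) instead of register_llm_model_chat_adapter. For context, a
sketch of the match-based dispatch such a registry typically performs; the registry
internals below are an assumption for illustration, and only the
match()/get_generate_stream_func() interface comes from the diff:

    # Sketch of match-based dispatch (assumed internals, illustration only).
    class BaseChatAdpter:
        def match(self, model_path: str) -> bool:
            return False

    class GorillaChatAdapter(BaseChatAdpter):
        def match(self, model_path: str) -> bool:
            return "gorilla" in model_path

    llm_model_chat_adapters = []

    def register_llm_model_chat_adapter(cls):
        llm_model_chat_adapters.append(cls())

    def get_llm_chat_adapter(model_path: str):
        # First adapter whose match() accepts the path wins, so registration
        # order matters when patterns overlap.
        for adapter in llm_model_chat_adapters:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model adapter for {model_path}")

    register_llm_model_chat_adapter(GorillaChatAdapter)
    assert isinstance(get_llm_chat_adapter("gorilla-7b-hf-v1"), GorillaChatAdapter)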