fix: gorilla chat adapter and config

csunny 2023-06-08 17:35:17 +08:00
parent 8b3d7b0ba7
commit 0948bc45bc
4 changed files with 21 additions and 2 deletions


@@ -146,6 +146,9 @@ class Config(metaclass=Singleton):
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
+       # QLoRA
+       self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
+
        ### EMBEDDING Configuration
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
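
Worth noting when reading this hunk: os.getenv returns a string, so self.QLoRA holds the text "True" or "False" rather than a boolean, and any non-empty value (including "False") is truthy under a bare if. A minimal sketch of parsing the variable into a real boolean; the env_flag helper is illustrative and not part of this commit:

    import os

    def env_flag(name: str, default: str = "True") -> bool:
        # Accept common affirmative spellings; anything else reads as False.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    QLORA_ENABLED = env_flag("QUANTIZE_QLORA")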


@@ -6,8 +6,10 @@ from typing import List
from functools import cache
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
from pilot.configs.model_config import DEVICE
+from pilot.configs.config import Config

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)

+CFG = Config()

class BaseLLMAdaper:
    """The Base class for multi model, in our project.

@@ -106,7 +108,8 @@ class FalconAdapater(BaseLLMAdaper):
    def loader(self, model_path: str, from_pretrained_kwagrs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-       if QLORA:
+       if CFG.QLoRA:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_4bit=True,  # quantize
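
For context, a rough sketch of what the CFG.QLoRA branch amounts to with the transformers API, reusing the bnb_config defined above; the load_falcon_4bit name and the device_map/trust_remote_code arguments are assumptions for illustration, not code from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # store weights as 4-bit NF4
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # do matmuls in bfloat16
        bnb_4bit_use_double_quant=False,
    )

    def load_falcon_4bit(model_path: str):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=bnb_config,  # one config object for all quant settings
            device_map="auto",               # spread layers across available devices
            trust_remote_code=True,          # Falcon checkpoints ship custom modeling code
        )
        return model, tokenizer

Passing the whole BitsAndBytesConfig via quantization_config keeps the quantization settings in one place instead of repeating load_in_4bit at each call site.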


@@ -61,6 +61,8 @@ class BaseOutputParser(ABC):
            # stream out output
            output = data["text"][11:].replace("<s>", "").strip()
+           # TODO gorilla and falcon output
        else:
            output = data["text"].strip()
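
The stream branch above slices a fixed 11-character prefix off the generated text before stripping the <s> token. As a small illustrative rewrite (not from this commit), naming that magic number makes the intent visible and gives the TODO'd gorilla/falcon handling an obvious extension point; skip_echo_len is a hypothetical parameter:

    def parse_stream_output(text: str, skip_echo_len: int = 11) -> str:
        # Drop the echoed prompt prefix, then remove special tokens.
        return text[skip_echo_len:].replace("<s>", "").strip()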


@@ -116,10 +116,21 @@ class ProxyllmChatAdapter(BaseChatAdpter):
        return proxyllm_generate_stream

+class GorillaChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "gorilla" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gorilla_llm import generate_stream
+
+        return generate_stream
+
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
-register_llm_model_adapters(FalconChatAdapter)
+register_llm_model_chat_adapter(FalconChatAdapter)
+register_llm_model_chat_adapter(GorillaChatAdapter)

# Proxy model for test and develop, it's cheap for us now.
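
For readers new to this file, a sketch of the registry pattern those register_llm_model_chat_adapter calls feed into, inferred from the call sites above rather than copied from the repo: each adapter is instantiated once into a list, and lookup walks the list until one adapter's match() accepts the model path:

    llm_model_chat_adapters = []

    def register_llm_model_chat_adapter(cls):
        # Instantiate the adapter once and keep it for later lookup.
        llm_model_chat_adapters.append(cls())

    def get_llm_chat_adapter(model_path: str):
        # Return the first registered adapter that claims this model path.
        for adapter in llm_model_chat_adapters:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model adapter for {model_path}")

Registration order matters with this scheme: a substring test like "gorilla" in model_path means more specific adapters should be registered before catch-all ones.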