mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 14:06:43 +00:00
fix: gorilla chat adapter and config
This commit is contained in:
parent
8b3d7b0ba7
commit
0948bc45bc
@ -146,6 +146,9 @@ class Config(metaclass=Singleton):
|
||||
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
|
||||
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
|
||||
|
||||
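# QLoRA switch. NOTE(review): os.getenv returns a string, so any non-empty value (including "False") is truthy in `if CFG.QLoRA` — confirm this is intended.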
|
||||
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
|
||||
|
||||
### EMBEDDING Configuration
|
||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
|
||||
|
@ -6,8 +6,10 @@ from typing import List
|
||||
from functools import cache
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.config import Config
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
|
||||
CFG = Config()
|
||||
|
||||
class BaseLLMAdaper:
|
||||
"""The Base class for multi model, in our project.
|
||||
@ -106,7 +108,8 @@ class FalconAdapater(BaseLLMAdaper):
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
if QLORA:
|
||||
|
||||
if CFG.QLoRA:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
load_in_4bit=True, #quantize
|
||||
|
@ -61,6 +61,8 @@ class BaseOutputParser(ABC):
|
||||
|
||||
# stream out output
|
||||
output = data["text"][11:].replace("<s>", "").strip()
|
||||
|
||||
# TODO gorilla and falcon output
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
|
@ -116,10 +116,21 @@ class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
return proxyllm_generate_stream
|
||||
|
||||
|
||||
class GorillaChatAdapter(BaseChatAdpter):
    """Chat adapter used for model paths that contain "gorilla"."""

    def match(self, model_path: str):
        """Return True when ``model_path`` contains the substring "gorilla"."""
        is_gorilla = "gorilla" in model_path
        return is_gorilla

    def get_generate_stream_func(self):
        """Return ``generate_stream`` from ``pilot.model.llm_out.gorilla_llm``.

        The import happens inside the method, so the module is only loaded
        when this method is actually called.
        """
        from pilot.model.llm_out.gorilla_llm import generate_stream as stream_fn

        return stream_fn
|
||||
|
||||
|
||||
# Register the chat adapters. Every adapter, including FalconChatAdapter,
# goes through the same register_llm_model_chat_adapter helper exactly once.
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
|
||||
|
||||
# Proxy model for testing and development; it's cheap for us now.
|
||||
|
Loading…
Reference in New Issue
Block a user