Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-28 04:44:14 +00:00
feature: gorllia support (#173)
This commit is contained in:
parent bea3f7c8d2
commit 716460460f

@@ -146,7 +146,7 @@ class Config(metaclass=Singleton):
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

         # QLoRA
         self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")

         ### EMBEDDING Configuration
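
Note: os.getenv always returns a string here, so CFG.QLoRA holds the text "True" or "False" rather than a boolean, and a check like "if CFG.QLoRA:" stays truthy even when QUANTIZE_QLORA=False is exported. A minimal sketch of stricter parsing (the parse_bool helper is illustrative only, not part of this commit):

    import os

    def parse_bool(name: str, default: str = "True") -> bool:
        # Treat "false", "0" and "no" (any casing) as False; everything else as True.
        return os.getenv(name, default).strip().lower() not in ("false", "0", "no")

    QLORA_ENABLED = parse_bool("QUANTIZE_QLORA")
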
@@ -4,13 +4,25 @@
 import torch
 from typing import List
 from functools import cache
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config

-bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
 CFG = Config()

+
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT"""
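
The module-level bnb_config above sets up 4-bit NF4 quantization with bfloat16 compute, and is consumed by the Falcon adapter further down in this diff. A condensed, self-contained sketch of the same flow (the checkpoint name is a placeholder; recent transformers accepts either the string "bfloat16" or the torch dtype):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )
    model_path = "tiiuae/falcon-7b-instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map={"": 0},
        trust_remote_code=True,
    )
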
@@ -98,8 +110,8 @@ class GuanacoAdapter(BaseLLMAdaper):
             model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
         )
         return model, tokenizer


 class FalconAdapater(BaseLLMAdaper):
     """falcon Adapter"""

@@ -111,23 +123,23 @@ class FalconAdapater(BaseLLMAdaper):

         if CFG.QLoRA:
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
-                load_in_4bit=True, #quantize
+                load_in_4bit=True, # quantize
                 quantization_config=bnb_config,
                 device_map={"": 0},
                 trust_remote_code=True,
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 trust_remote_code=True,
                 device_map={"": 0},
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         return model, tokenizer


 class GorillaAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""

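
The GorillaAdapter body is cut off at the hunk boundary above. A plausible sketch of what such a loader usually looks like in this file, not the commit's actual implementation (the "gorilla" match string mirrors GorillaChatAdapter below; the fp16 dtype and the method signatures are assumptions):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    class GorillaAdapter(BaseLLMAdaper):  # BaseLLMAdaper is defined earlier in this module
        def match(self, model_path: str):
            return "gorilla" in model_path

        def loader(self, model_path: str, from_pretrained_kwargs: dict):
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,  # assumption: half-precision weights on GPU
                low_cpu_mem_usage=True,
                **from_pretrained_kwargs,
            )
            return model, tokenizer
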
@@ -19,7 +19,7 @@ def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
     streamer = TextIteratorStreamer(
         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
     )

     tokenizer.bos_token_id = 1
     stop_token_ids = [0]

@@ -1,5 +1,6 @@
 import torch

+
 @torch.inference_mode()
 def generate_stream(
     model, tokenizer, params, device, context_len=42048, stream_interval=2
@@ -22,7 +23,7 @@ def generate_stream(
             out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
             logits = out.logits
             past_key_values = out.past_key_values
         else:
             out = model(
                 input_ids=torch.as_tensor([[token]], device=device),
                 use_cache=True,
@@ -37,7 +38,6 @@ def generate_stream(
         token = int(torch.multinomial(probs, num_samples=1))
         output_ids.append(token)
-

         if token == tokenizer.eos_token_id:
             stopped = True
         else:
@@ -45,7 +45,11 @@ def generate_stream(

         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
             tmp_output_ids = output_ids[input_echo_len:]
-            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
             pos = output.rfind(stop_str, l_prompt)
             if pos != -1:
                 output = output[:pos]
@@ -55,4 +59,4 @@ def generate_stream(
         if stopped:
             break

     del past_key_values
@@ -62,7 +62,7 @@ class BaseOutputParser(ABC):
             # stream out output
             output = data["text"][11:].replace("<s>", "").strip()

         # TODO gorilla and falcon output
         else:
             output = data["text"].strip()

@@ -94,7 +94,7 @@ class GuanacoChatAdapter(BaseChatAdpter):

         return guanaco_generate_stream


 class FalconChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""

@@ -105,7 +105,8 @@ class FalconChatAdapter(BaseChatAdpter):
         from pilot.model.llm_out.falcon_llm import falcon_generate_output

         return falcon_generate_output

+
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -116,8 +117,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return proxyllm_generate_stream


 class GorillaChatAdapter(BaseChatAdpter):
-
     def match(self, model_path: str):
         return "gorilla" in model_path

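
GorillaChatAdapter's body is likewise truncated; by analogy with the Guanaco, Falcon and proxy adapters shown above it presumably hands back a streaming generate function. A sketch under the assumption that the generate_stream function edited earlier in this diff lives in pilot.model.llm_out.gorilla_llm (both the module path and the accessor method name are guesses based on the surrounding pattern):

    class GorillaChatAdapter(BaseChatAdpter):  # BaseChatAdpter is defined earlier in this module
        def match(self, model_path: str):
            return "gorilla" in model_path

        def get_generate_stream_func(self):
            from pilot.model.llm_out.gorilla_llm import generate_stream

            return generate_stream
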
@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
         # textsplitter = CHNDocumentSplitter(
         # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+        )
         return loader.load_and_split(textsplitter)

     @register
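
The new SpacyTextSplitter call depends on the spaCy zh_core_web_sm pipeline being installed (python -m spacy download zh_core_web_sm). A minimal standalone sketch of the same splitter configuration, assuming LangChain's text_splitter module:

    from langchain.text_splitter import SpacyTextSplitter

    textsplitter = SpacyTextSplitter(
        pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
    )
    chunks = textsplitter.split_text("这是一段用于演示分块的中文文本。" * 80)
    print(len(chunks), "chunks")
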