Merge branch 'llm_fxp' of https://github.com/csunny/DB-GPT into llm_fxp

csunny 2023-06-09 16:20:11 +08:00
commit d587f59143
8 changed files with 48 additions and 27 deletions

View File

@@ -146,7 +146,7 @@ class Config(metaclass=Singleton):
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
### EMBEDDING Configuration
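A note on this flag: `os.getenv` always returns a string, so even `QUANTIZE_QLORA=False` stays truthy when checked with `if CFG.QLoRA:` later in the adapter code. A minimal sketch of parsing the flag into a real boolean (the `env_flag` helper is hypothetical, not part of this commit):

```python
import os

def env_flag(name: str, default: str = "True") -> bool:
    # os.getenv returns a string; "False" is still truthy, so parse explicitly.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

QLORA_ENABLED = env_flag("QUANTIZE_QLORA")
```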

View File

@@ -4,13 +4,25 @@
import torch
from typing import List
from functools import cache
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
LlamaTokenizer,
BitsAndBytesConfig,
)
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype="bfloat16",
bnb_4bit_use_double_quant=False,
)
CFG = Config()
class BaseLLMAdaper:
"""The Base class for multi model, in our project.
We will support those model, which performance resemble ChatGPT"""
@@ -98,8 +110,8 @@ class GuanacoAdapter(BaseLLMAdaper):
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
)
return model, tokenizer
class FalconAdapater(BaseLLMAdaper):
"""falcon Adapter"""
@@ -111,23 +123,23 @@ class FalconAdapater(BaseLLMAdaper):
if CFG.QLoRA:
model = AutoModelForCausalLM.from_pretrained(
model_path,
load_in_4bit=True, #quantize
quantization_config=bnb_config,
device_map={"": 0},
trust_remote_code=True,
**from_pretrained_kwagrs
model_path,
load_in_4bit=True, # quantize
quantization_config=bnb_config,
device_map={"": 0},
trust_remote_code=True,
**from_pretrained_kwagrs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
device_map={"": 0},
**from_pretrained_kwagrs
**from_pretrained_kwagrs,
)
return model, tokenizer
class GorillaAdapter(BaseLLMAdaper):
"""TODO Support guanaco"""

View File

@@ -51,7 +51,10 @@ def chatglm_generate_stream(
# else:
# once_conversation.append(f"""###system:{message} """)
query = messages[-2].split("human:")[1]
try:
query = messages[-2].split("human:")[1]
except IndexError:
query = messages[-3].split("human:")[1]
print("Query Message: ", query)
# output = ""
# i = 0
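The try/except above falls back from `messages[-2]` to `messages[-3]` when the expected slot has no `human:` turn. A more general variant, shown only as a sketch and not what this commit does, walks the history backwards for the last message that actually contains a `human:` marker:

```python
from typing import List, Optional

def last_human_query(messages: List[str]) -> Optional[str]:
    # Return the text after the most recent "human:" marker, if any,
    # instead of hard-coding messages[-2] / messages[-3].
    for message in reversed(messages):
        if "human:" in message:
            return message.split("human:", 1)[1].strip()
    return None
```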

View File

@@ -19,7 +19,7 @@ def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
tokenizer.bos_token_id = 1
stop_token_ids = [0]
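For reference, a `TextIteratorStreamer` like the one created here is normally consumed by running generation in a background thread and iterating the streamer on the main thread. This mirrors the standard transformers pattern, not the full `falcon_generate_output` implementation; the model id and prompt are placeholders.

```python
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_path = "tiiuae/falcon-7b"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64),
)
thread.start()
for new_text in streamer:  # yields decoded text chunks as they are generated
    print(new_text, end="", flush=True)
thread.join()
```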

View File

@@ -1,5 +1,6 @@
import torch
@torch.inference_mode()
def generate_stream(
model, tokenizer, params, device, context_len=42048, stream_interval=2
@@ -22,7 +23,7 @@ def generate_stream(
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
@@ -37,7 +38,6 @@ def generate_stream(
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
@@ -45,7 +45,11 @@ def generate_stream(
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
tmp_output_ids = output_ids[input_echo_len:]
output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
@@ -55,4 +59,4 @@ def generate_stream(
if stopped:
break
del past_key_values

View File

@@ -62,7 +62,7 @@ class BaseOutputParser(ABC):
# stream out output
output = data["text"][11:].replace("<s>", "").strip()
# TODO gorilla and falcon output
else:
output = data["text"].strip()

View File

@@ -94,7 +94,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
return guanaco_generate_stream
class FalconChatAdapter(BaseChatAdpter):
"""Model chat adapter for Guanaco"""
@@ -105,7 +105,8 @@ class FalconChatAdapter(BaseChatAdpter):
from pilot.model.llm_out.falcon_llm import falcon_generate_output
return falcon_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "proxyllm" in model_path
@@ -116,8 +117,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
return proxyllm_generate_stream
class GorillaChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gorilla" in model_path

View File

@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
# textsplitter = CHNDocumentSplitter(
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
)
return loader.load_and_split(textsplitter)
@register
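One practical note on the new splitter: LangChain's `SpacyTextSplitter` needs the named spaCy pipeline installed before it can split, so `zh_core_web_sm` has to be downloaded separately. A minimal, self-contained sketch with placeholder text and the same sizes as the diff:

```python
# Prerequisite: pip install spacy langchain && python -m spacy download zh_core_web_sm
from langchain.text_splitter import SpacyTextSplitter

splitter = SpacyTextSplitter(
    pipeline="zh_core_web_sm",  # Chinese spaCy pipeline, matching the diff
    chunk_size=1000,
    chunk_overlap=200,
)

chunks = splitter.split_text("这是一段用于演示分块的示例文本。" * 50)  # placeholder text
print(len(chunks))
```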