Merge branch 'llm_fxp' of https://github.com/csunny/DB-GPT into llm_fxp

csunny 2023-06-09 16:20:11 +08:00
commit d587f59143
8 changed files with 48 additions and 27 deletions


@@ -4,13 +4,25 @@
 import torch
 from typing import List
 from functools import cache
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config
-bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
 CFG = Config()
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT"""
@@ -116,14 +128,14 @@ class FalconAdapater(BaseLLMAdaper):
                 quantization_config=bnb_config,
                 device_map={"": 0},
                 trust_remote_code=True,
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 trust_remote_code=True,
                 device_map={"": 0},
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         return model, tokenizer

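For reference, the bnb_config introduced above is the standard transformers/bitsandbytes 4-bit loading setup. A minimal, self-contained sketch of how such a config is typically passed to from_pretrained (the checkpoint name below is only an illustrative placeholder, not part of this change):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # NF4 4-bit weights, bfloat16 compute, no nested quantization
    # (the same settings as the bnb_config in the diff above).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )

    model_path = "tiiuae/falcon-7b-instruct"  # placeholder checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map={"": 0},  # keep all weights on GPU 0
        trust_remote_code=True,
    )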

@@ -51,7 +51,10 @@ def chatglm_generate_stream(
         # else:
         #     once_conversation.append(f"""###system:{message} """)
-    query = messages[-2].split("human:")[1]
+    try:
+        query = messages[-2].split("human:")[1]
+    except IndexError:
+        query = messages[-3].split("human:")[1]
     print("Query Message: ", query)
     # output = ""
     # i = 0

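The added try/except covers conversations whose second-to-last entry is not the human turn; in that case the query falls back to the third-to-last message. A small self-contained sketch of that extraction logic (the helper name and sample message lists are made up for illustration):

    def extract_query(messages):
        # Prefer the second-to-last entry; if it contains no "human:" segment,
        # split() returns a single element and indexing [1] raises IndexError,
        # so fall back to the third-to-last entry.
        try:
            return messages[-2].split("human:")[1]
        except IndexError:
            return messages[-3].split("human:")[1]

    print(extract_query(["###system: be helpful", "human: what is DB-GPT?", ""]))
    print(extract_query(["human: what is DB-GPT?", "assistant: a database GPT", ""]))
    # Both calls print " what is DB-GPT?"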

@@ -1,5 +1,6 @@
 import torch
 @torch.inference_mode()
 def generate_stream(
     model, tokenizer, params, device, context_len=42048, stream_interval=2
@@ -37,7 +38,6 @@ def generate_stream(
         token = int(torch.multinomial(probs, num_samples=1))
         output_ids.append(token)
         if token == tokenizer.eos_token_id:
             stopped = True
         else:
@@ -45,7 +45,11 @@ def generate_stream(
         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
             tmp_output_ids = output_ids[input_echo_len:]
-            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
             pos = output.rfind(stop_str, l_prompt)
             if pos != -1:
                 output = output[:pos]

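The reformatted decode call is part of the usual token-streaming loop: every stream_interval steps the accumulated ids are re-decoded and the text is cut at the stop string. A minimal sketch of just that bookkeeping, with a toy decoder standing in for a real tokenizer (all names here are illustrative):

    def stream_text(token_ids, decode, stop_str, l_prompt=0, stream_interval=2, max_new_tokens=8):
        output_ids = []
        for i, token in enumerate(token_ids):
            output_ids.append(token)
            stopped = i == max_new_tokens - 1
            if i % stream_interval == 0 or stopped:
                output = decode(output_ids)
                # Look for the stop string only after the echoed prompt.
                pos = output.rfind(stop_str, l_prompt)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
                yield output
            if stopped:
                break

    toy_decode = lambda ids: "".join(chr(i) for i in ids)  # stands in for tokenizer.decode
    for chunk in stream_text([ord(c) for c in "hello###end"], toy_decode, stop_str="###"):
        print(chunk)  # prints partial outputs, ending with "hello"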

@@ -106,6 +106,7 @@ class FalconChatAdapter(BaseChatAdpter):
         return falcon_generate_output
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -117,7 +118,6 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 class GorillaChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "gorilla" in model_path

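Both adapters above follow the project's selection convention: each chat adapter claims a model by substring-matching the model path, and the first match wins. A stripped-down sketch of that dispatch pattern (the registry below is illustrative, not the project's actual lookup code):

    class BaseChatAdapter:
        def match(self, model_path: str) -> bool:
            raise NotImplementedError

    class ProxyLLMAdapter(BaseChatAdapter):
        def match(self, model_path: str) -> bool:
            return "proxyllm" in model_path

    class GorillaAdapter(BaseChatAdapter):
        def match(self, model_path: str) -> bool:
            return "gorilla" in model_path

    ADAPTERS = [ProxyLLMAdapter(), GorillaAdapter()]

    def get_chat_adapter(model_path: str) -> BaseChatAdapter:
        # Return the first adapter whose match() accepts this model path.
        for adapter in ADAPTERS:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"no chat adapter registered for {model_path}")

    print(type(get_chat_adapter("/models/gorilla-7b")).__name__)  # GorillaAdapter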

@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
         # textsplitter = CHNDocumentSplitter(
         #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+        )
         return loader.load_and_split(textsplitter)

     @register
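The new SpacyTextSplitter call is LangChain's sentence-aware splitter configured with the Chinese spaCy pipeline. A minimal usage sketch (assumes spacy and the zh_core_web_sm model are installed, e.g. via python -m spacy download zh_core_web_sm; the sample text is made up):

    from langchain.text_splitter import SpacyTextSplitter

    # Same settings as in the diff: Chinese pipeline, ~1000-char chunks, 200-char overlap.
    splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200)

    sample = "数据库是按照数据结构来组织、存储和管理数据的仓库。" * 60  # made-up document text
    chunks = splitter.split_text(sample)
    print(len(chunks), len(chunks[0]))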