Merge branch 'llm_fxp' of https://github.com/csunny/DB-GPT into llm_fxp
commit d587f59143
@@ -4,13 +4,25 @@
 import torch
 from typing import List
 from functools import cache
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config

-bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
 CFG = Config()


 class BaseLLMAdaper:
     """The base class for multi-model support in our project.
     We will support models whose performance resembles ChatGPT's."""
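Note: a minimal usage sketch of the 4-bit config above, assuming transformers with bitsandbytes installed; the checkpoint name is illustrative, and the string "bfloat16" used in the hunk is accepted interchangeably with torch.bfloat16:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    bnb_4bit_use_double_quant=False,        # skip nested quantization
)
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",                 # illustrative checkpoint
    quantization_config=bnb_config,
    device_map={"": 0},                 # place the whole model on GPU 0
    trust_remote_code=True,
)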
@@ -116,14 +128,14 @@ class FalconAdapater(BaseLLMAdaper):
                 quantization_config=bnb_config,
                 device_map={"": 0},
                 trust_remote_code=True,
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 trust_remote_code=True,
                 device_map={"": 0},
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         return model, tokenizer
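Note: the only change in this hunk is the trailing comma after the splat (the "kwagrs" spelling is the repo's own identifier). A hedged sketch of what a forwarded keyword argument looks like once it reaches from_pretrained; the checkpoint and kwarg are illustrative:

import torch
from transformers import AutoModelForCausalLM

# Anything collected in **from_pretrained_kwagrs arrives here as an
# ordinary from_pretrained keyword argument, e.g. an explicit dtype:
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",          # illustrative checkpoint
    trust_remote_code=True,
    device_map={"": 0},
    torch_dtype=torch.bfloat16,  # example of a forwarded kwarg
)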
@@ -51,7 +51,10 @@ def chatglm_generate_stream(
     # else:
     #     once_conversation.append(f"""###system:{message} """)

-    query = messages[-2].split("human:")[1]
+    try:
+        query = messages[-2].split("human:")[1]
+    except IndexError:
+        query = messages[-3].split("human:")[1]
     print("Query Message: ", query)
     # output = ""
     # i = 0
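Note: a self-contained sketch of the fallback this hunk introduces; the message strings are illustrative. split("human:")[1] raises IndexError when the marker is absent (split returns a single-element list), which triggers the step back by one message:

# Illustrative reproduction of the fallback: take the last "human:" turn,
# stepping back one message when the expected slot lacks the prefix.
messages = [
    "###system: You are a helpful assistant. ",
    "###human: What is DB-GPT? ",
    "###assistant:",
]

try:
    query = messages[-2].split("human:")[1]
except IndexError:
    # messages[-2] had no "human:" marker; try one message earlier.
    query = messages[-3].split("human:")[1]

print("Query Message: ", query)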
@@ -1,5 +1,6 @@
 import torch

+
 @torch.inference_mode()
 def generate_stream(
     model, tokenizer, params, device, context_len=42048, stream_interval=2
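Note: @torch.inference_mode() disables autograd tracking for everything the decorated function runs, making generation cheaper than a gradient-tracked forward pass; a micro-example:

import torch

@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside this call.
    return x * 2

y = double(torch.ones(3, requires_grad=True))
print(y.requires_grad)  # False: outputs are inference tensors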
@@ -37,7 +38,6 @@ def generate_stream(
         token = int(torch.multinomial(probs, num_samples=1))
         output_ids.append(token)

-
         if token == tokenizer.eos_token_id:
             stopped = True
         else:
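Note: for context, torch.multinomial draws the next token id from the (temperature-scaled) softmax distribution; a minimal sketch with an illustrative three-token vocabulary:

import torch

logits = torch.tensor([2.0, 0.5, 1.0])       # stand-in for last-step logits
probs = torch.softmax(logits / 0.7, dim=-1)  # 0.7 is an illustrative temperature
token = int(torch.multinomial(probs, num_samples=1))
print(token)  # index of the sampled token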
@@ -45,7 +45,11 @@ def generate_stream(

         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
             tmp_output_ids = output_ids[input_echo_len:]
-            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
             pos = output.rfind(stop_str, l_prompt)
             if pos != -1:
                 output = output[:pos]
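Note: str.rfind(stop_str, l_prompt) returns the highest match index at or after l_prompt, so a stop string echoed inside the prompt itself is ignored; a hedged sketch with illustrative values:

# Illustrative values; in the real loop, output is the decoded text so far.
output = "The answer is 42. ###human:"
stop_str = "###human:"
l_prompt = 0  # in the loop this is the prompt's length in characters

pos = output.rfind(stop_str, l_prompt)  # search only at/after l_prompt
if pos != -1:
    output = output[:pos]               # truncate at the stop string
print(output)  # "The answer is 42. "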
@@ -106,6 +106,7 @@ class FalconChatAdapter(BaseChatAdpter):

         return falcon_generate_output

+
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -117,7 +118,6 @@ class ProxyllmChatAdapter(BaseChatAdpter):


 class GorillaChatAdapter(BaseChatAdpter):
-
     def match(self, model_path: str):
         return "gorilla" in model_path

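Note: a hedged sketch of how these substring match() methods typically drive adapter selection; the registry list and helper function are illustrative, not the repo's actual lookup code:

# Hypothetical dispatch: the first adapter whose match() accepts the path wins.
adapters = [FalconChatAdapter(), ProxyllmChatAdapter(), GorillaChatAdapter()]

def get_chat_adapter(model_path: str):
    for adapter in adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no chat adapter matches {model_path}")

adapter = get_chat_adapter("/models/gorilla-7b")  # -> the GorillaChatAdapter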
@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
         # textsplitter = CHNDocumentSplitter(
         #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+        )
         return loader.load_and_split(textsplitter)

     @register
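Note: a hedged usage sketch of langchain's SpacyTextSplitter, assuming the zh_core_web_sm spaCy model has been downloaded (python -m spacy download zh_core_web_sm); the sample text is illustrative:

from langchain.text_splitter import SpacyTextSplitter

# Sentence-aware splitting for Chinese text: ~1000-character chunks with
# 200 characters of overlap between neighbouring chunks.
textsplitter = SpacyTextSplitter(
    pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
)
chunks = textsplitter.split_text("数据库把数据组织成表。向量库按相似度检索。" * 100)
print(len(chunks), "chunks")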