Merge branch 'dev' of https://github.com/csunny/DB-GPT into dev

This commit is contained in:
csunny 2023-06-04 21:47:56 +08:00
commit 759798948b
4 changed files with 82 additions and 0 deletions

View File

@ -34,6 +34,7 @@ LLM_MODEL_CONFIG = {
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"proxyllm": "proxyllm",

View File

@ -83,6 +83,21 @@ class ChatGLMAdapater(BaseLLMAdaper):
return model, tokenizer
class GuanacoAdapter(BaseLLMAdaper):
"""TODO Support guanaco"""
def match(self, model_path: str):
return "guanaco" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
)
return model, tokenizer
class GuanacoAdapter(BaseLLMAdaper):
"""TODO Support guanaco"""

View File

@ -0,0 +1,55 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
tokenizer.bos_token_id = 1
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
tokenizer.bos_token_id = 1
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
stop = StopOnTokens()
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=512,
temperature=1.0,
do_sample=True,
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
stopping_criteria=StoppingCriteriaList([stop]),
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
out = ""
for new_text in streamer:
out += new_text
yield new_text
return out

View File

@ -59,6 +59,16 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return chatglm_generate_stream
class GuanacoChatAdapter(BaseChatAdpter):
"""Model chat adapter for Guanaco"""
def match(self, model_path: str):
return "guanaco" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output
return guanaco_generate_output
class CodeT5ChatAdapter(BaseChatAdpter):
@ -110,6 +120,7 @@ register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)