Merge branch 'llm_fxp' of https://github.com/csunny/DB-GPT into llm_fxp

This commit is contained in:
csunny 2023-06-04 20:31:06 +08:00
commit 32ce199173
2 changed files with 58 additions and 9 deletions

View File

@ -1,5 +1,4 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@ -57,3 +56,53 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
out = decoded_output.split("### Response:")[-1].strip()
yield out
def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
tokenizer.bos_token_id = 1
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
tokenizer.bos_token_id = 1
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
stop = StopOnTokens()
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=512,
temperature=1.0,
do_sample=True,
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
stopping_criteria=StoppingCriteriaList([stop]),
)
generator = model.generate(**generate_kwargs)
out = ""
for new_text in streamer:
out += new_text
yield new_text
return out

View File

@ -4,7 +4,7 @@
from functools import cache
from typing import List
from pilot.model.llm_out.vicuna_base_llm import generate_stream
from pilot.model.inference import generate_stream
class BaseChatAdpter:
@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return "chatglm" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
from pilot.model.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
@ -85,15 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter):
class GuanacoChatAdapter(BaseChatAdpter):
"""Model chat adapter for Guanaco"""
"""Model chat adapter for Guanaco """
def match(self, model_path: str):
return "guanaco" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
return guanaco_generate_output
from pilot.model.guanaco_llm import guanaco_generate_stream
return guanaco_generate_stream
class ProxyllmChatAdapter(BaseChatAdpter):
@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
return "proxyllm" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
from pilot.model.proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream