Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-13 05:55:54 +00:00)

Merge branch 'llm_fxp' of https://github.com/csunny/DB-GPT into llm_fxp

Commit 32ce199173
@@ -1,5 +1,4 @@
 import torch
-import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,53 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     out = decoded_output.split("### Response:")[-1].strip()
 
     yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+
+    generator = model.generate(**generate_kwargs)
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield new_text
+    return out
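One thing worth noting when reading the new generator: `model.generate(**generate_kwargs)` runs on the calling thread, so the `for new_text in streamer` loop only drains the queued text after generation has already finished. `Thread` is imported at the top of this same file; below is a minimal sketch, not part of the commit, of the threaded consumption pattern that `TextIteratorStreamer` is designed for (the helper name `stream_with_thread` is hypothetical):

    # A minimal sketch, not part of the commit: model.generate() blocks, so
    # running it on a background Thread lets the caller yield text in real time.
    from threading import Thread

    from transformers import TextIteratorStreamer


    def stream_with_thread(model, tokenizer, prompt, max_new_tokens=512):
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            streamer=streamer,
        )
        # generate() feeds the streamer's internal queue as tokens are produced;
        # iterating the streamer on this thread drains it concurrently.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        for new_text in streamer:
            yield new_text
        thread.join()

With the synchronous call in the diff, the function still yields every chunk, just all at once at the end rather than token by token.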
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from pilot.model.inference import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
+        from pilot.model.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -85,15 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco"""
+    """Model chat adapter for Guanaco """
 
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
+        from pilot.model.guanaco_llm import guanaco_generate_stream
 
-        return guanaco_generate_output
+        return guanaco_generate_stream
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
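With this hunk the Guanaco adapter hands back `guanaco_generate_stream`, which yields incremental chunks, instead of `guanaco_generate_output`, which yields the finished response once. A hedged sketch of a caller, using only the `params` keys the new function actually reads ("prompt" and "stop"); `run_stream` and the prompt text are illustrative, not DB-GPT code:

    # Illustrative consumer of a generate-stream function; not DB-GPT's server
    # loop. The params keys mirror those read by guanaco_generate_stream, and
    # the prompt format is assumed from the "### Response:" split seen above.
    def run_stream(stream_func, model, tokenizer, device, context_len=2048):
        params = {"prompt": "### Human: Hello\n### Response:", "stop": "###"}
        for chunk in stream_func(model, tokenizer, params, device, context_len):
            print(chunk, end="", flush=True)  # each chunk is newly decoded text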
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return "proxyllm" in model_path
 
    def get_generate_stream_func(self):
-        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
+        from pilot.model.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream
 
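For context, every adapter touched by this diff follows the same two-method shape: `match()` inspects the model path string and `get_generate_stream_func()` lazily imports the matching stream function. A minimal sketch of that dispatch pattern under assumed names (`llm_adapters`, `register_llm_adapter`, and `get_llm_chat_adapter` are illustrative; only `BaseChatAdpter` and the Guanaco adapter's two methods come from the diff):

    # Sketch of the adapter-dispatch pattern visible in this diff. The registry
    # helpers are hypothetical; only BaseChatAdpter and the Guanaco adapter's
    # two methods appear in the commit itself.
    from typing import Callable, List, Type


    class BaseChatAdpter:
        def match(self, model_path: str) -> bool:
            return False

        def get_generate_stream_func(self) -> Callable:
            raise NotImplementedError


    llm_adapters: List[Type[BaseChatAdpter]] = []


    def register_llm_adapter(cls: Type[BaseChatAdpter]) -> Type[BaseChatAdpter]:
        llm_adapters.append(cls)  # lookup scans adapters in registration order
        return cls


    @register_llm_adapter
    class GuanacoChatAdapter(BaseChatAdpter):
        """Model chat adapter for Guanaco"""

        def match(self, model_path: str) -> bool:
            return "guanaco" in model_path

        def get_generate_stream_func(self) -> Callable:
            # Lazy import keeps heavy model code out of module import time,
            # matching the style of the adapters in the diff.
            from pilot.model.guanaco_llm import guanaco_generate_stream

            return guanaco_generate_stream


    def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
        for adapter_cls in llm_adapters:
            adapter = adapter_cls()
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model adapter for {model_path}")

The lazy imports also explain why this merge only had to touch import paths inside each `get_generate_stream_func` body rather than the module header.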