mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-23 02:27:55 +00:00
Merge branch 'main' into dev
This commit is contained in:
commit
61113727a4
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import copy
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
@ -57,3 +56,55 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
out = decoded_output.split("### Response:")[-1].strip()
|
||||
|
||||
yield out
|
||||
|
||||
|
||||
def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
tokenizer.bos_token_id = 1
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
prompt = params["prompt"]
|
||||
max_new_tokens = params.get("max_new_tokens", 512)
|
||||
temerature = params.get("temperature", 1.0)
|
||||
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[-1][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
stop = StopOnTokens()
|
||||
|
||||
generate_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temerature,
|
||||
do_sample=True,
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
model.generate(**generate_kwargs)
|
||||
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield out
|
||||
|
@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
"max_tokens": params.get("max_new_tokens"),
|
||||
}
|
||||
|
||||
print(payloads)
|
||||
print(headers)
|
||||
res = requests.post(
|
||||
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
|
||||
)
|
||||
|
||||
text = ""
|
||||
print("====================================res================")
|
||||
print(res)
|
||||
for line in res.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode("utf-8")
|
||||
|
@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
|
||||
model.to(self.device)
|
||||
except ValueError:
|
||||
pass
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
|
@ -56,8 +56,11 @@ class BaseOutputParser(ABC):
|
||||
# output = data["text"][skip_echo_len + 11:].strip()
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
elif "guanaco" in CFG.LLM_MODEL:
|
||||
# output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
|
||||
output = data["text"][skip_echo_len:].replace("<s>", "").strip()
|
||||
# NO stream output
|
||||
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
||||
|
||||
# stream out output
|
||||
output = data["text"][11:].replace("<s>", "").strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
|
@ -101,9 +101,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
|
||||
|
||||
return guanaco_generate_output
|
||||
return guanaco_generate_stream
|
||||
|
||||
|
||||
class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
|
Loading…
Reference in New Issue
Block a user