Merge branch 'main' into dev

csunny 2023-06-05 22:25:35 +08:00
commit 61113727a4
5 changed files with 61 additions and 9 deletions

View File

@@ -1,5 +1,4 @@
 import torch
-import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,55 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     out = decoded_output.split("### Response:")[-1].strip()
     yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    max_new_tokens = params.get("max_new_tokens", 512)
+    temperature = params.get("temperature", 1.0)
+
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            # Stop as soon as the last generated token is one of the stop ids.
+            for stop_id in stop_token_ids:
+                if input_ids[-1][-1] == stop_id:
+                    return True
+            return False
+
+    stop_criteria = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop_criteria]),
+    )
+
+    # Run generation in a background thread so the streamer yields tokens
+    # as they are produced instead of only after generation finishes.
+    generate_thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    generate_thread.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
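For reference, a minimal usage sketch of the new streaming generator (not part of this commit; the prompt format, sampling values, and the already-loaded model/tokenizer objects are assumptions):

from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

# `model` and `tokenizer` are assumed to be a loaded Guanaco checkpoint and its tokenizer.
params = {
    "prompt": "### Human: Hello\n### Assistant:",  # assumed prompt format
    "temperature": 0.7,
    "max_new_tokens": 256,
}
for partial in guanaco_generate_stream(model, tokenizer, params, model.device):
    print(partial)  # each yield is the cumulative text generated so far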

View File

@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         "max_tokens": params.get("max_new_tokens"),
     }
 
-    print(payloads)
-    print(headers)
     res = requests.post(
         CFG.proxy_server_url, headers=headers, json=payloads, stream=True
     )
 
     text = ""
-    print("====================================res================")
-    print(res)
     for line in res.iter_lines():
         if line:
             decoded_line = line.decode("utf-8")

View File

@@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
                 model.to(self.device)
             except ValueError:
                 pass
+            except AttributeError:
+                pass
 
         if debug:
             print(model)
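For context, a minimal sketch (not from this commit) of the failure the new except AttributeError branch would swallow, i.e. a model object that does not implement .to(); FakeProxyModel is a hypothetical stand-in:

class FakeProxyModel:  # hypothetical: any model wrapper without a .to() method
    pass

model = FakeProxyModel()
try:
    model.to("cuda")  # raises AttributeError rather than ValueError
except ValueError:
    pass
except AttributeError:  # the case newly handled by this commit
    pass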

View File

@@ -56,8 +56,11 @@ class BaseOutputParser(ABC):
             # output = data["text"][skip_echo_len + 11:].strip()
             output = data["text"][skip_echo_len:].strip()
         elif "guanaco" in CFG.LLM_MODEL:
-            # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-            output = data["text"][skip_echo_len:].replace("<s>", "").strip()
+            # NO stream output
+            # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
+
+            # stream out output
+            output = data["text"][11:].replace("<s>", "").strip()
         else:
             output = data["text"].strip()
 

View File

@@ -101,9 +101,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
 
-        return guanaco_generate_output
+        return guanaco_generate_stream
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
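
A rough sketch of how the updated adapter appears to be wired up on the serving side (not part of this diff); the match method name is inferred from the `return "guanaco" in model_path` context line, and send_to_client is a hypothetical transport function:

adapter = GuanacoChatAdapter()
if adapter.match(model_path):  # selects this adapter for Guanaco checkpoints
    generate_stream = adapter.get_generate_stream_func()  # now returns guanaco_generate_stream
    for partial_output in generate_stream(model, tokenizer, params, device):
        send_to_client(partial_output)  # hypothetical: push each partial result to the client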