guanaco: add stream output func (#154)

csunny 2023-06-04 21:20:09 +08:00
parent fe8291b198
commit ff6cc05e11
3 changed files with 15 additions and 12 deletions

View File

@@ -64,6 +64,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
max_new_tokens = params.get("max_new_tokens", 512)
temperature = params.get("temperature", 1.0)
query = prompt
print("Query Message: ", query)
@@ -82,7 +85,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
if input_ids[-1][-1] == stop_id:
return True
return False
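
The change above makes the stop check read the last sequence in the batch (`input_ids[-1][-1]`) rather than the first. For reference, a minimal self-contained sketch of this kind of token-based stopping criterion with Hugging Face transformers; the class name and the use of the tokenizer's EOS id are illustrative assumptions, not code from this commit:

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    """Stop generation once the newest generated token is one of the stop ids."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.stop_token_ids:
            # input_ids[-1] is the last sequence in the batch, [-1] its newest token
            if input_ids[-1][-1] == stop_id:
                return True
        return False

# usage (hypothetical): stopping_criteria=StoppingCriteriaList([StopOnTokens([tokenizer.eos_token_id])])
```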
@@ -90,8 +93,8 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=512,
temperature=1.0,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
top_k=1,
streamer=streamer,
@@ -100,9 +103,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
)
generator = model.generate(**generate_kwargs)
model.generate(**generate_kwargs)
out = ""
for new_text in streamer:
out += new_text
yield new_text
return out
yield out
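
Note on the streaming loop above: with transformers' TextIteratorStreamer, `model.generate` is normally launched in a background thread so the caller can drain the streamer while generation is still running. A minimal sketch of that pattern, assuming a standard transformers model and tokenizer; the parameter names and the accumulate-then-yield style mirror the diff, everything else is an assumption:

```python
from threading import Thread

from transformers import TextIteratorStreamer

def generate_stream(model, tokenizer, prompt, max_new_tokens=512, temperature=1.0):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        streamer=streamer,
    )
    # generate() blocks until it finishes, so run it in a thread and consume tokens here
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    out = ""
    for new_text in streamer:
        out += new_text
        yield out  # yield the running text accumulated so far
```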

View File

@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
"max_tokens": params.get("max_new_tokens"),
}
print(payloads)
print(headers)
res = requests.post(
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
)
text = ""
print("====================================res================")
print(res)
for line in res.iter_lines():
if line:
decoded_line = line.decode("utf-8")

View File

@@ -56,8 +56,12 @@ class BaseOutputParser(ABC):
# output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip()
elif "guanaco" in CFG.LLM_MODEL:
# output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
# NO stream output
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
# stream out output
output = data["text"][11:].replace("<s>", "").strip()
else:
output = data["text"].strip()
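
The parser change above switches the guanaco branch from a skip_echo_len-based slice to a fixed 11-character prefix for the new stream output. A simplified sketch of that branching, not the full BaseOutputParser class; the helper name is hypothetical:

```python
def parse_stream_text(data: dict, skip_echo_len: int, llm_model: str) -> str:
    """Strip the echoed prompt or fixed prefix from a model's streamed text."""
    if "guanaco" in llm_model:
        # stream output: drop a fixed-length prefix and any <s> markers
        return data["text"][11:].replace("<s>", "").strip()
    # other branches in the snippet slice off the echoed prompt by skip_echo_len
    return data["text"][skip_echo_len:].strip()
```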