mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 14:34:28 +00:00
guanaco: add stream output func (#154)
This commit is contained in:
parent
fe8291b198
commit
ff6cc05e11
@ -64,6 +64,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
|||||||
print(params)
|
print(params)
|
||||||
stop = params.get("stop", "###")
|
stop = params.get("stop", "###")
|
||||||
prompt = params["prompt"]
|
prompt = params["prompt"]
|
||||||
|
max_new_tokens = params.get("max_new_tokens", 512)
|
||||||
|
temerature = params.get("temperature", 1.0)
|
||||||
|
|
||||||
query = prompt
|
query = prompt
|
||||||
print("Query Message: ", query)
|
print("Query Message: ", query)
|
||||||
|
|
||||||
@ -82,7 +85,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
|||||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||||
) -> bool:
|
) -> bool:
|
||||||
for stop_id in stop_token_ids:
|
for stop_id in stop_token_ids:
|
||||||
if input_ids[0][-1] == stop_id:
|
if input_ids[-1][-1] == stop_id:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -90,8 +93,8 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
|||||||
|
|
||||||
generate_kwargs = dict(
|
generate_kwargs = dict(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=512,
|
max_new_tokens=max_new_tokens,
|
||||||
temperature=1.0,
|
temperature=temerature,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_k=1,
|
top_k=1,
|
||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
@ -100,9 +103,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
generator = model.generate(**generate_kwargs)
|
model.generate(**generate_kwargs)
|
||||||
|
|
||||||
out = ""
|
out = ""
|
||||||
for new_text in streamer:
|
for new_text in streamer:
|
||||||
out += new_text
|
out += new_text
|
||||||
yield new_text
|
yield out
|
||||||
return out
|
|
@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
|||||||
"max_tokens": params.get("max_new_tokens"),
|
"max_tokens": params.get("max_new_tokens"),
|
||||||
}
|
}
|
||||||
|
|
||||||
print(payloads)
|
|
||||||
print(headers)
|
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
|
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
print("====================================res================")
|
|
||||||
print(res)
|
|
||||||
for line in res.iter_lines():
|
for line in res.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode("utf-8")
|
decoded_line = line.decode("utf-8")
|
||||||
|
@ -56,8 +56,12 @@ class BaseOutputParser(ABC):
|
|||||||
# output = data["text"][skip_echo_len + 11:].strip()
|
# output = data["text"][skip_echo_len + 11:].strip()
|
||||||
output = data["text"][skip_echo_len:].strip()
|
output = data["text"][skip_echo_len:].strip()
|
||||||
elif "guanaco" in CFG.LLM_MODEL:
|
elif "guanaco" in CFG.LLM_MODEL:
|
||||||
# output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
|
|
||||||
output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
# NO stream output
|
||||||
|
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
||||||
|
|
||||||
|
# stream out output
|
||||||
|
output = data["text"][11:].replace("<s>", "").strip()
|
||||||
else:
|
else:
|
||||||
output = data["text"].strip()
|
output = data["text"].strip()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user