Mirror of https://github.com/csunny/DB-GPT.git — synced 2025-10-13 04:38:25 +00:00 (73 lines, 2.0 KiB, Python)
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
import copy
|
|
|
|
import torch
|
|
|
|
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
|
|
|
|
|
@torch.inference_mode()
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    """Stream generated text using the ChatGLM model's ``stream_chat`` API.

    Parses the flat ``prompt`` string (turns separated by ``stop``, each turn
    tagged ``human:`` or ``ai:``) back into a chat history, then yields each
    partial response as the model streams it, followed by one final yield of
    the complete answer.

    Args:
        model: A ChatGLM model exposing ``stream_chat(tokenizer, query,
            history, **kwargs)`` that yields ``(response, history)`` pairs.
        tokenizer: Tokenizer passed through to ``stream_chat``.
        params: Dict with ``prompt`` (required) and optional ``temperature``,
            ``top_p``, ``stop`` (default ``"###"``) and ``echo`` keys.
        device: Target device identifier (currently unused here; the model is
            assumed to already live on it).
        context_len: Maximum context length (unused by the chat API).
        stream_interval: Streaming granularity hint (unused by the chat API).

    Yields:
        str: The partial (or, when ``echo`` is set, query-prefixed) response
        text after each streaming step, and the final text once more at the
        end.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    stop = params.get("stop", "###")
    echo = params.get("echo", False)

    generate_kwargs = {
        # Effectively-zero temperature means greedy decoding.
        "do_sample": temperature > 1e-5,
        "top_p": top_p,
        "repetition_penalty": 1.0,
        "logits_processor": None,
    }
    if temperature > 1e-5:
        generate_kwargs["temperature"] = temperature

    # TODO: replace these debug prints with proper logging.
    print(prompt)
    messages = prompt.split(stop)

    # Rebuild the conversation history as [human_text, ai_text] pairs.
    # The last two segments are the pending user query and the trailing
    # separator remnant, so they are excluded from history.
    hist = []
    once_conversation = []
    for message in messages[:-2]:
        if len(message) <= 0:
            continue
        if "human:" in message:
            once_conversation.append(message.split("human:")[1])
        elif "ai:" in message:
            once_conversation.append(message.split("ai:")[1])
            # A completed human/ai exchange: snapshot it into the history.
            hist.append(copy.deepcopy(once_conversation))
            once_conversation = []

    try:
        query = messages[-2].split("human:")[1]
    except IndexError:
        # The second-to-last segment had no "human:" tag; fall back one
        # segment further (e.g. when the prompt ends with an extra separator).
        query = messages[-3].split("human:")[1]
    print("Query Message: ", query)

    # Pre-bind output so the final yield is safe even if the model stream
    # produces nothing (the original raised NameError in that case).
    output = ""
    for response, _ in model.stream_chat(tokenizer, query, hist, **generate_kwargs):
        output = query + " " + response if echo else response
        yield output

    # Yield the complete answer one last time so callers always see the
    # final state even if they missed intermediate chunks.
    yield output
|