Files
DB-GPT/pilot/model/llm_out/chatglm_llm.py
2023-06-08 21:29:02 +08:00

73 lines
2.0 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import copy
import torch
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@torch.inference_mode()
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    """Stream a ChatGLM reply for a flattened conversation prompt.

    The prompt is split on the ``stop`` separator into segments. Segments
    containing ``human:`` / ``ai:`` are paired into ``[question, answer]``
    history turns, and the most recent ``human:`` segment becomes the query
    passed to ``model.stream_chat``.

    Args:
        model: Object exposing ``stream_chat(tokenizer, query, history,
            **kwargs)`` that yields ``(response, history)`` pairs.
        tokenizer: Passed through to ``model.stream_chat`` unchanged.
        params: Mapping with a required ``"prompt"`` string and optional
            ``"temperature"`` (default 1.0), ``"top_p"`` (default 1.0),
            ``"stop"`` (default ``"###"``; a ``None`` or empty value also
            falls back to the default) and ``"echo"`` (default ``False``).
        device: Not used by this function; kept for interface compatibility.
        context_len: Not used by this function; kept for compatibility.
        stream_interval: Not used by this function; kept for compatibility.

    Yields:
        Each response emitted by ``model.stream_chat``, prefixed with the
        query and a space when ``echo`` is truthy. The last value is yielded
        a second time once the stream ends; nothing is yielded if the model
        produced no output.

    Raises:
        KeyError: If ``params`` has no ``"prompt"`` entry.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    # ``str.split(None)`` splits on whitespace and ``str.split("")`` raises,
    # so an explicit ``None``/empty separator falls back to the default.
    stop = params.get("stop") or "###"
    echo = params.get("echo", False)

    # Sampling is only enabled for a non-negligible temperature; otherwise
    # the temperature key is omitted entirely.
    do_sample = temperature > 1e-5
    generate_kwargs = {
        "do_sample": do_sample,
        "top_p": top_p,
        "repetition_penalty": 1.0,
        "logits_processor": None,
    }
    if do_sample:
        generate_kwargs["temperature"] = temperature

    # TODO, Fix this
    print(prompt)
    messages = prompt.split(stop)

    # Rebuild the chat history from every segment except the last two, which
    # are expected to hold the pending question and the open "ai:" slot.
    # NOTE: ``split(marker)[1]`` keeps only the text up to the next occurrence
    # of the same marker inside a segment.
    hist = []
    once_conversation = []
    for message in messages[:-2]:
        if not message:
            continue
        if "human:" in message:
            once_conversation.append(message.split("human:")[1])
        elif "ai:" in message:
            # An "ai:" segment closes the current turn.
            once_conversation.append(message.split("ai:")[1])
            hist.append(once_conversation)
            once_conversation = []

    # The pending question is normally the second-to-last segment, or the
    # third-to-last when the prompt ends with a trailing separator. If neither
    # exists or carries a "human:" marker, use the whole prompt as the query
    # instead of failing with an IndexError.
    query = prompt
    for offset in (2, 3):
        if len(messages) >= offset and "human:" in messages[-offset]:
            query = messages[-offset].split("human:")[1]
            break
    print("Query Message: ", query)

    output = None
    for response, _ in model.stream_chat(tokenizer, query, hist, **generate_kwargs):
        output = query + " " + response if echo else response
        yield output

    # Repeat the final chunk as before, but only if the model produced one;
    # previously an empty stream raised UnboundLocalError here.
    if output is not None:
        yield output