fix and update

This commit is contained in:
csunny 2023-05-21 16:11:52 +08:00
parent 89970bd71c
commit 604d269797

View File

@@ -8,16 +8,15 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
"""Generate text using chatglm model's chat api """
prompt = params["prompt"]
max_new_tokens = int(params.get("max_new_tokens", 256))
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
stop = params.get("stop", "###")
echo = params.get("echo", True)
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": True if temperature > 1e-5 else False,
"top_p": top_p,
"repetition_penalty": 1.0,
"logits_processor": None
}
@@ -34,6 +33,7 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
query = messages[-2].split(":")[1]
print("Query Message: ", query)
output = ""
i = 0
for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):