Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-14 14:34:28 +00:00)
fix and update
parent 89970bd71c
commit 604d269797
@@ -8,16 +8,15 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
     """Generate text using chatglm model's chat api """
     prompt = params["prompt"]
-    max_new_tokens = int(params.get("max_new_tokens", 256))
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
     stop = params.get("stop", "###")
     echo = params.get("echo", True)
 
     generate_kwargs = {
-        "max_new_tokens": max_new_tokens,
         "do_sample": True if temperature > 1e-5 else False,
         "top_p": top_p,
+        "repetition_penalty": 1.0,
         "logits_processor": None
     }
 
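Side note on the first hunk: `max_new_tokens` is still read from `params` but no longer forwarded, so `generate_kwargs` reduces to sampling controls only. A sketch of the resulting dict under the default `params` values follows; the inline values are derived from the defaults above, and the comment about ChatGLM's own length default is an assumption about the checkpoint's `stream_chat` signature, not something this commit states:

```python
# generate_kwargs as built after this commit, with default params:
generate_kwargs = {
    "do_sample": True,           # temperature defaults to 1.0 > 1e-5
    "top_p": 1.0,
    "repetition_penalty": 1.0,   # newly added; 1.0 disables the penalty
    "logits_processor": None,
}
# "max_new_tokens" is dropped, so output length presumably falls back to
# stream_chat's own max_length default rather than a caller-supplied cap.
```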
@@ -34,6 +33,7 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
         hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
 
     query = messages[-2].split(":")[1]
+    print("Query Message: ", query)
     output = ""
     i = 0
     for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
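For context, a minimal sketch of how the patched generator might be driven end to end. Everything here is illustrative: the checkpoint name, the `"role: content###"` prompt shape (inferred from the `split(stop)` / `split(":")` parsing in the function), and the assumption that the function yields the accumulated response on each step of the stream loop:

```python
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint; any ChatGLM model exposing stream_chat should behave alike.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# Prompt format inferred from the parsing above: turns separated by the stop
# token "###", each shaped like "role: content". messages[-2] is then the
# latest user turn, and earlier pairs become the chat history.
params = {
    "prompt": "user: What does DB-GPT do?###assistant:",
    "temperature": 0.7,
    "top_p": 0.9,
}

for output in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
    print(output, end="\r")  # each yield is assumed to be the response so far
```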