diff --git a/pilot/model/chatglm_llm.py b/pilot/model/chatglm_llm.py
index 656252785..f8279be7f 100644
--- a/pilot/model/chatglm_llm.py
+++ b/pilot/model/chatglm_llm.py
@@ -8,16 +8,15 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
 
     """Generate text using chatglm model's chat api """
     prompt = params["prompt"]
-    max_new_tokens = int(params.get("max_new_tokens", 256))
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
     stop = params.get("stop", "###")
     echo = params.get("echo", True)
 
     generate_kwargs = {
-        "max_new_tokens": max_new_tokens,
         "do_sample": True if temperature > 1e-5 else False,
         "top_p": top_p,
+        "repetition_penalty": 1.0,
         "logits_processor": None
     }
 
@@ -34,6 +33,7 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
         hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
 
     query = messages[-2].split(":")[1]
+    print("Query Message: ", query)
     output = ""
     i = 0
     for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
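Note: for context, a minimal runnable sketch of the history-parsing logic the
second hunk touches and of the debug print it adds. The prompt string below is
a hypothetical example for illustration only; the real value comes from
params["prompt"] and uses the stop token "###":

    # Assumed prompt layout: a preamble segment, then role-prefixed turns
    # ("user: ..." / "assistant: ...") separated by the stop token.
    prompt = "A chat.###user: hi###assistant: hello###user: how are you?###assistant:"
    stop = "###"

    messages = prompt.split(stop)

    # Pair up (user, assistant) turns, skipping the preamble at index 0,
    # mirroring the loop just above the hunk.
    hist = []
    for i in range(1, len(messages) - 2, 2):
        hist.append((messages[i].split(":")[1], messages[i + 1].split(":")[1]))

    # The second-to-last segment is the pending user turn; the diff logs it
    # before it is passed to model.stream_chat().
    query = messages[-2].split(":")[1]
    print("Query Message: ", query)  # query == " how are you?"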