Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-24 19:08:15 +00:00.
Commit message: "llms: add chatglm model"
This commit is contained in the branches/tags listed on the original page (list lost in extraction).
@@ -7,10 +7,11 @@ import torch
|
||||
def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||
|
||||
"""Generate text using chatglm model's chat api """
|
||||
messages = params["prompt"]
|
||||
prompt = params["prompt"]
|
||||
max_new_tokens = int(params.get("max_new_tokens", 256))
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
stop = params.get("stop", "###")
|
||||
echo = params.get("echo", True)
|
||||
|
||||
generate_kwargs = {
|
||||
@@ -23,11 +24,16 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
|
||||
if temperature > 1e-5:
|
||||
generate_kwargs["temperature"] = temperature
|
||||
|
||||
# TODO, Fix this
|
||||
hist = []
|
||||
for i in range(0, len(messages) - 2, 2):
|
||||
hist.append(messages[i][1], messages[i + 1][1])
|
||||
|
||||
query = messages[-2][1]
|
||||
messages = prompt.split(stop)
|
||||
|
||||
# Add history chat to hist for model.
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
|
||||
|
||||
query = messages[-2].split(":")[1]
|
||||
output = ""
|
||||
i = 0
|
||||
for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
|
||||
|
@@ -364,8 +364,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
|
||||
if "vicuna" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
output = post_process_code(output)
|
||||
state.messages[-1][-1] = output + "▌"
|
||||
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
||||
|
Reference in New Issue
Block a user