llms: add chatglm model

csunny 2023-05-21 14:08:18 +08:00
parent 42b76979a3
commit 7b454d8867
3 changed files with 43 additions and 16 deletions

View File

@@ -5,8 +5,15 @@ import requests
 import json
 import time
 import uuid
+import os
+import sys
+
 from urllib.parse import urljoin
 import gradio as gr
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
@@ -21,24 +28,24 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
 
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
-    print(result)
-    state.append_message(state.roles[0], result)
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
+    # print(result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
     prompt = state.get_prompt()
 
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
         "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
@@ -48,11 +55,17 @@ def generate(query):
     )
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
 
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
+
                 state.messages[-1][-1] = output + ""
                 yield(output)
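
For context, a minimal client sketch of the streaming protocol this file speaks. The base URL, endpoint path, and timeout below are assumptions; the payload shape, the NUL-delimited JSON chunks, and the vicuna/chatglm echo handling come from the diff above.

import json
import requests
from urllib.parse import urljoin

def stream_completion(base_url: str, prompt: str, model: str = "chatglm-6b"):
    # Payload mirrors the params dict in the diff.
    params = {
        "model": model,
        "prompt": prompt,
        "temperature": 1.0,
        "max_new_tokens": 1024,
        "stop": "###",
    }
    response = requests.post(
        urljoin(base_url, "generate_stream"),  # assumed endpoint name
        json=params,
        stream=True,
        timeout=120,
    )
    # Length of the prompt that a vicuna-style worker echoes back.
    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
    # The server streams NUL-delimited JSON chunks.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # vicuna echoes the prompt, so strip it; chatglm returns only new text.
            if "vicuna" in model:
                yield data["text"][skip_echo_len:].strip()
            else:
                yield data["text"].strip()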

View File

@@ -7,10 +7,11 @@ import torch
 
 def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     """Generate text using chatglm model's chat api """
-    messages = params["prompt"]
+    prompt = params["prompt"]
     max_new_tokens = int(params.get("max_new_tokens", 256))
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
+    stop = params.get("stop", "###")
     echo = params.get("echo", True)
 
     generate_kwargs = {
@@ -23,11 +24,16 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
     if temperature > 1e-5:
         generate_kwargs["temperature"] = temperature
 
+    # TODO, Fix this
     hist = []
-    for i in range(0, len(messages) - 2, 2):
-        hist.append(messages[i][1], messages[i + 1][1])
-    query = messages[-2][1]
+
+    messages = prompt.split(stop)
+    # Add history chat to hist for model.
+    for i in range(1, len(messages) - 2, 2):
+        hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
+
+    query = messages[-2].split(":")[1]
     output = ""
     i = 0
     for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
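
A rough, self-contained illustration of the prompt parsing this hunk introduces: the flat prompt is split on the stop token, messages[0] is the system preamble, alternating Human/Assistant turns become history pairs, and the last Human turn becomes the query. The example prompt string is invented; the split logic mirrors the diff.

def parse_prompt(prompt: str, stop: str = "###"):
    messages = prompt.split(stop)
    hist = []
    # messages[0] is the system preamble; turns alternate Human/Assistant after it.
    for i in range(1, len(messages) - 2, 2):
        hist.append((messages[i].split(":")[1], messages[i + 1].split(":")[1]))
    # Second-to-last element is the latest Human turn; the last is the empty Assistant slot.
    query = messages[-2].split(":")[1]
    return hist, query

prompt = (
    "A chat between a user and an assistant."
    "### Human: hello### Assistant: hi there"
    "### Human: what is chatglm?### Assistant:"
)
hist, query = parse_prompt(prompt)
# hist  -> [(" hello", " hi there")]
# query -> " what is chatglm?"

Note that split(":")[1] drops everything after a second colon in a turn, which is presumably what the "TODO, Fix this" marker refers to.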

View File

@@ -364,8 +364,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
+
+            """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+            """
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
+
                 output = post_process_code(output)
                 state.messages[-1][-1] = output + ""
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
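
The TODO above points toward an adapter-style rewrite of this output handling. A purely hypothetical sketch of that direction, not part of this commit; all names are invented, only the per-model stripping behavior comes from the diff.

# One output handler per model family instead of an inline if/else on CFG.LLM_MODEL.
def vicuna_output(data: dict, skip_echo_len: int) -> str:
    # vicuna workers echo the prompt, so strip it before returning new text.
    return data["text"][skip_echo_len:].strip()

def chatglm_output(data: dict, skip_echo_len: int) -> str:
    # chatglm returns only the generated text.
    return data["text"].strip()

OUTPUT_ADAPTERS = {
    "vicuna-13b": vicuna_output,
    "chatglm-6b": chatglm_output,
}

def handle_output(model_name: str, data: dict, skip_echo_len: int) -> str:
    # Fall back to returning the raw text when the model is unknown.
    handler = OUTPUT_ADAPTERS.get(model_name, chatglm_output)
    return handler(data, skip_echo_len)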