Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-02 00:28:00 +00:00)
llms: add chatglm model
This commit is contained in:
commit 7b454d8867
parent 42b76979a3
@@ -5,8 +5,15 @@ import requests
 import json
 import time
 import uuid
+import os
+import sys
 from urllib.parse import urljoin
 import gradio as gr
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
@@ -21,24 +28,24 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
 
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
 
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
 
-    print(result)
+    # print(result)
 
-    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
 
     prompt = state.get_prompt()
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
         "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
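
Note: the PromptTemplate block is only commented out, not deleted. If the QA-style prompt is wanted again, it could be re-enabled roughly as sketched below, reusing only names already imported in this file; the context string is a placeholder, not taken from the repo.

    # Sketch of re-enabling the commented-out QA prompt (placeholder context text).
    pt = PromptTemplate(
        template=conv_qa_prompt_template,
        input_variables=["context", "question"],
    )
    result = pt.format(context="<retrieved context goes here>", question=query)
    state.append_message(state.roles[0], result)  # instead of the raw query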
@@ -48,11 +55,17 @@ def generate(query):
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
 
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 state.messages[-1][-1] = output + "▌"
                 yield(output)
+
@@ -7,10 +7,11 @@ import torch
 def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
 
     """Generate text using chatglm model's chat api """
-    messages = params["prompt"]
+    prompt = params["prompt"]
     max_new_tokens = int(params.get("max_new_tokens", 256))
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
+    stop = params.get("stop", "###")
     echo = params.get("echo", True)
 
     generate_kwargs = {
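
For orientation, the params dict this function expects can be inferred from the keys read above; a minimal sketch follows, with illustrative values only.

    # Keys mirror the params.get(...) calls in chatglm_generate_stream; values are examples.
    params = {
        "prompt": "A chat between a user and an AI.###Human: hello###Assistant:",
        "max_new_tokens": 1024,
        "temperature": 1.0,
        "top_p": 1.0,
        "stop": "###",
        "echo": True,
    }
    # Assuming model and tokenizer are a loaded ChatGLM checkpoint:
    # for text in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
    #     print(text)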
@@ -23,11 +24,16 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
     if temperature > 1e-5:
         generate_kwargs["temperature"] = temperature
 
+    # TODO, Fix this
     hist = []
-    for i in range(0, len(messages) - 2, 2):
-        hist.append(messages[i][1], messages[i + 1][1])
 
-    query = messages[-2][1]
+    messages = prompt.split(stop)
+
+    # Add history chat to hist for model.
+    for i in range(1, len(messages) - 2, 2):
+        hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
+
+    query = messages[-2].split(":")[1]
     output = ""
     i = 0
     for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
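
The new parsing turns the flat "###"-separated prompt into (question, answer) history pairs plus the current query for ChatGLM's stream_chat. A standalone sketch of that logic is below; the exact role names and wording of the example prompt are assumptions based on the conv templates, not copied from this diff.

    # Standalone sketch of the prompt parsing done in chatglm_generate_stream.
    prompt = (
        "A chat between a user and an assistant."
        "###Human: What is DB-GPT?"
        "###Assistant: An experimental project for chatting with databases."
        "###Human: Which model does it use?"
        "###Assistant:"
    )
    stop = "###"

    messages = prompt.split(stop)
    # messages[0] is the system preamble; pair up the earlier Human/Assistant turns.
    hist = []
    for i in range(1, len(messages) - 2, 2):
        hist.append((messages[i].split(":")[1], messages[i + 1].split(":")[1]))

    # The second-to-last chunk holds the current question; the last one is the
    # empty "Assistant:" slot the model is expected to fill.
    query = messages[-2].split(":")[1]

    print(hist)   # [(' What is DB-GPT?', ' An experimental project for chatting with databases.')]
    print(query)  # ' Which model does it use?'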
@@ -364,8 +364,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
+
+            """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+            """
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
+
                 output = post_process_code(output)
                 state.messages[-1][-1] = output + "▌"
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
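
The TODO notes that this per-model branching should eventually move behind an adapter. One possible shape for such a hook, purely illustrative (the function name is not from the repo):

    # Illustrative only: a per-model output parser so http_bot needs no model branches.
    def parse_stream_output(model_name: str, data: dict, skip_echo_len: int) -> str:
        if "vicuna" in model_name:
            # vicuna workers echo the prompt back, so cut it off
            return data["text"][skip_echo_len:].strip()
        # chatglm's stream_chat already returns only the reply text
        return data["text"].strip()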