llms: add chatglm model

csunny committed 2023-05-21 14:08:18 +08:00
parent 42b76979a3
commit 7b454d8867
3 changed files with 43 additions and 16 deletions

@@ -5,8 +5,15 @@ import requests
 import json
 import time
 import uuid
+import os
+import sys
+
 from urllib.parse import urljoin
 import gradio as gr
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
@@ -21,24 +28,24 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
 
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
 
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
 
-    print(result)
+    # print(result)
 
-    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
 
     prompt = state.get_prompt()
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
         "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
@@ -48,11 +55,17 @@ def generate(query):
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
 
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 state.messages[-1][-1] = output + ""
                 yield(output)
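
Note: the hunk above captures only the closing parenthesis of the requests.post(...) call, so what follows is a minimal standalone sketch of the streaming exchange the new chatglm-6b branch consumes, useful for testing the endpoint by hand. The payload mirrors the params dict from this commit; the server address and the generate_stream path are assumptions, and the prompt string is purely illustrative.

import json

import requests

# Standalone sketch (not part of this commit). Assumes the model server is
# reachable at the address below and streams NUL-delimited JSON blobs,
# matching iter_lines(delimiter=b"\0") in the code above.
params = {
    "model": "chatglm-6b",
    "prompt": "A chat between a human and an AI assistant. Human: hello ### Assistant:",
    "temperature": 1.0,
    "max_new_tokens": 1024,
    "stop": "###",
}

response = requests.post(
    "http://127.0.0.1:8000/generate_stream",  # assumed server address and path
    json=params,
    stream=True,  # stream so iter_lines() yields chunks as they arrive
)

for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # chatglm-6b does not echo the prompt back, so no skip_echo_len
            # slicing is needed on this path (the vicuna branch still slices).
            print(data["text"].strip())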