diff --git a/examples/embdserver.py b/examples/embdserver.py
index bb0016f00..b8525fe15 100644
--- a/examples/embdserver.py
+++ b/examples/embdserver.py
@@ -5,8 +5,15 @@ import requests
 import json
 import time
 import uuid
+import os
+import sys
 from urllib.parse import urljoin
 import gradio as gr
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
@@ -21,24 +28,24 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
 
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
 
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
 
-    print(result)
+    # print(result)
 
-    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
 
     prompt = state.get_prompt()
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
         "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
@@ -48,11 +55,17 @@ def generate(query):
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 state.messages[-1][-1] = output + "▌"
                 yield(output)
 
diff --git a/pilot/model/chatglm_llm.py b/pilot/model/chatglm_llm.py
index ef54e92d7..656252785 100644
--- a/pilot/model/chatglm_llm.py
+++ b/pilot/model/chatglm_llm.py
@@ -7,10 +7,11 @@ import torch
 
 def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     """Generate text using chatglm model's chat api """
-    messages = params["prompt"]
+    prompt = params["prompt"]
     max_new_tokens = int(params.get("max_new_tokens", 256))
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
+    stop = params.get("stop", "###")
     echo = params.get("echo", True)
 
     generate_kwargs = {
@@ -23,11 +24,16 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
     if temperature > 1e-5:
         generate_kwargs["temperature"] = temperature
 
+    # TODO: Fix this brittle prompt parsing.
     hist = []
-    for i in range(0, len(messages) - 2, 2):
-        hist.append(messages[i][1], messages[i + 1][1])
-    query = messages[-2][1]
+
+    messages = prompt.split(stop)
+
+    # Rebuild the chat history as (user, assistant) pairs for the model.
+    for i in range(1, len(messages) - 2, 2):
+        hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
+
+    query = messages[-2].split(":")[1]
 
     output = ""
     i = 0
     for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index ea8d8fc6d..2dd2ba9e0 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -364,8 +364,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
+
+            """ TODO: Multi-model output handling; rewrite this to use an adapter pattern.
+            """
             if data["error_code"] == 0:
+
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
+
-                output = data["text"][skip_echo_len:].strip()
                 output = post_process_code(output)
                 state.messages[-1][-1] = output + "▌"
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
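
For reference, the sketch below walks the new history parsing in `chatglm_generate_stream` end to end. The prompt text and the `Human:`/`Assistant:` role labels are assumptions modeled on the `###`-separated `conv_one_shot` template used in `embdserver.py`; only the split logic itself comes from the diff.

```python
# Standalone sketch of the history parsing added in chatglm_llm.py.
# The prompt layout is an assumption; the parsing matches the diff.
stop = "###"
prompt = (
    "A chat between a curious human and an AI assistant."
    "### Human: What is DB-GPT?"
    "### Assistant: A GPT that talks to your local databases."
    "### Human: How do I start the webserver?"
    "### Assistant:"
)

messages = prompt.split(stop)
# messages[0] is the system preamble; alternating Human/Assistant turns
# follow, and the trailing "Assistant:" stub is the slot the model fills.

hist = []
for i in range(1, len(messages) - 2, 2):
    # Note: split(":")[1] keeps only the text between the first and second
    # colon, so a colon inside a message truncates it; hence the TODO above.
    hist.append((messages[i].split(":")[1], messages[i + 1].split(":")[1]))

query = messages[-2].split(":")[1]

print(hist)   # [(' What is DB-GPT?', ' A GPT that talks to your local databases.')]
print(query)  # ' How do I start the webserver?'
```

The `"vicuna" in CFG.LLM_MODEL` branches in `embdserver.py` and `webserver.py` appear to exist for a related reason: the vicuna stream echoes the prompt, so its output needs the `skip_echo_len` trim, while ChatGLM's chat API returns only the newly generated text and is used as-is.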