From ae761a3bfa01223caf1fa8855054faa6784d8832 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Tue, 18 Jul 2023 18:07:09 +0800 Subject: [PATCH 1/2] feat: structured prompt data interaction between dbgpt-server and llm-server --- pilot/model/llm_out/chatglm_llm.py | 55 ++++++------------ pilot/model/llm_out/proxy_llm.py | 37 ++++-------- pilot/scene/base_chat.py | 93 +++++++++++++++++++++++++----- pilot/scene/base_message.py | 41 +++++++++++++ pilot/server/llmserver.py | 4 ++ 5 files changed, 151 insertions(+), 79 deletions(-) diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py index 9bfcac915..968101c56 100644 --- a/pilot/model/llm_out/chatglm_llm.py +++ b/pilot/model/llm_out/chatglm_llm.py @@ -8,6 +8,11 @@ import copy import torch from pilot.conversation import ROLE_ASSISTANT, ROLE_USER +from pilot.scene.base_message import ModelMessage, _parse_model_messages + +# TODO move sep to scene prompt of model +_CHATGLM_SEP = "\n" +_CHATGLM2_SEP = "\n\n" @torch.inference_mode() @@ -32,42 +37,20 @@ def chatglm_generate_stream( generate_kwargs["temperature"] = temperature # TODO, Fix this - print(prompt) - messages = prompt.split(stop) - # - # # Add history conversation - hist = [HistoryEntry()] - system_messages = [] - for message in messages[:-2]: - if len(message) <= 0: - continue - if "human:" in message: - hist[-1].add_question(message.split("human:")[1]) - elif "system:" in message: - msg = message.split("system:")[1] - hist[-1].add_question(msg) - system_messages.append(msg) - elif "ai:" in message: - hist[-1].add_answer(message.split("ai:")[1]) - hist.append(HistoryEntry()) - else: - # TODO - # hist[-1].add_question(message.split("system:")[1]) - # once_conversation.append(f"""###system:{message} """) - pass - - try: - query = messages[-2].split("human:")[1] - except IndexError: - query = messages[-3].split("human:")[1] - hist = build_history(hist) + # print(prompt) + # messages = prompt.split(stop) + messages: List[ModelMessage] = params["messages"] + query, system_messages, hist = _parse_model_messages(messages) + system_messages_str = "".join(system_messages) if not hist: # No history conversation, but has system messages, merge to user`s query - query = prompt_adaptation(system_messages, query) + query = prompt_adaptation(system_messages_str, query) + else: + # history exist, add system message to head of history + hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0] + print("Query Message: ", query) print("hist: ", hist) - # output = "" - # i = 0 for i, (response, new_hist) in enumerate( model.stream_chat(tokenizer, query, hist, **generate_kwargs) @@ -103,10 +86,10 @@ def build_history(hist: List[HistoryEntry]) -> List[List[str]]: return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist))) -def prompt_adaptation(system_messages: List[str], human_message: str) -> str: - if not system_messages: +def prompt_adaptation(system_messages_str: str, human_message: str) -> str: + if not system_messages_str or system_messages_str == "": return human_message - system_messages_str = " ".join(system_messages) + # TODO Multi-model prompt adaptation adaptation_rules = [ r"Question:\s*{}\s*", # chat_db scene r"Goals:\s*{}\s*", # chat_execution @@ -119,4 +102,4 @@ def prompt_adaptation(system_messages: List[str], human_message: str) -> str: if re.search(pattern, system_messages_str): return system_messages_str # https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 - return f"{system_messages_str}\n\n问:{human_message}\n\n答:" + return system_messages_str + _CHATGLM2_SEP + human_message diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 3ec5d8504..c353426d2 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -3,8 +3,10 @@ import json import requests +from typing import List from pilot.configs.config import Config from pilot.conversation import ROLE_ASSISTANT, ROLE_USER +from pilot.scene.base_message import ModelMessage CFG = Config() @@ -20,36 +22,17 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) "Token": CFG.proxy_api_key, } - messages = prompt.split(stop) + messages: List[ModelMessage] = params["messages"] # Add history conversation for message in messages: - if len(message) <= 0: - continue - if "human:" in message: - history.append( - {"role": "user", "content": message.split("human:")[1]}, - ) - elif "system:" in message: - history.append( - { - "role": "system", - "content": message.split("system:")[1], - } - ) - elif "ai:" in message: - history.append( - { - "role": "assistant", - "content": message.split("ai:")[1], - } - ) + if message.role == "human": + history.append({"role": "user", "content": message.content}) + elif message.role == "system": + history.append({"role": "system", "content": message.content}) + elif message.role == "ai": + history.append({"role": "assistant", "content": message.content}) else: - history.append( - { - "role": "system", - "content": message, - } - ) + pass # Move the last user's information to the end temp_his = history[::-1] diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 7757d7fd6..ad1afbe88 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -37,6 +37,7 @@ from pilot.scene.base_message import ( HumanMessage, AIMessage, ViewMessage, + ModelMessage, ) from pilot.configs.config import Config @@ -116,6 +117,7 @@ class BaseChat(ABC): payload = { "model": self.llm_model, "prompt": self.generate_llm_text(), + "messages": self.generate_llm_messages(), "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), "stop": self.prompt_template.sep, @@ -244,24 +246,62 @@ class BaseChat(ABC): text += self.__load_user_message() return text - def __load_system_message(self): + def generate_llm_messages(self) -> List[ModelMessage]: + """ + Structured prompt messages interaction between dbgpt-server and llm-server + See https://github.com/csunny/DB-GPT/issues/328 + """ + messages = [] + ### Load scene setting or character definition as system message + if self.prompt_template.template_define: + messages.append( + ModelMessage( + role="system", + content=self.prompt_template.template_define, + ) + ) + ### Load prompt + messages += self.__load_system_message(str_message=False) + ### Load examples + messages += self.__load_example_messages(str_message=False) + + ### Load History + messages += self.__load_histroy_messages(str_message=False) + + ### Load User Input + messages += self.__load_user_message(str_message=False) + return messages + + def __load_system_message(self, str_message: bool = True): system_convs = self.current_message.get_system_conv() system_text = "" + system_messages = [] for system_conv in system_convs: system_text += ( system_conv.type + ":" + system_conv.content + self.prompt_template.sep ) - return system_text + system_messages.append( + ModelMessage(role=system_conv.type, content=system_conv.content) + ) + return system_text if str_message else system_messages - def __load_user_message(self): + def __load_user_message(self, str_message: bool = True): user_conv = self.current_message.get_user_conv() + user_messages = [] if user_conv: - return user_conv.type + ":" + user_conv.content + self.prompt_template.sep + user_text = ( + user_conv.type + ":" + user_conv.content + self.prompt_template.sep + ) + user_messages.append( + ModelMessage(role=user_conv.type, content=user_conv.content) + ) + return user_text if str_message else user_messages else: raise ValueError("Hi! What do you want to talk about?") - def __load_example_messages(self): + def __load_example_messages(self, str_message: bool = True): example_text = "" + example_messages = [] if self.prompt_template.example_selector: for round_conv in self.prompt_template.example_selector.examples(): for round_message in round_conv["messages"]: @@ -269,16 +309,22 @@ class BaseChat(ABC): SystemMessage.type, ViewMessage.type, ]: + message_type = round_message["type"] + message_content = round_message["data"]["content"] example_text += ( - round_message["type"] + message_type + ":" - + round_message["data"]["content"] + + message_content + self.prompt_template.sep ) - return example_text + example_messages.append( + ModelMessage(role=message_type, content=message_content) + ) + return example_text if str_message else example_messages - def __load_histroy_messages(self): + def __load_histroy_messages(self, str_message: bool = True): history_text = "" + history_messages = [] if self.prompt_template.need_historical_messages: if self.history_message: logger.info( @@ -290,12 +336,17 @@ class BaseChat(ABC): ViewMessage.type, SystemMessage.type, ]: + message_type = first_message["type"] + message_content = first_message["data"]["content"] history_text += ( - first_message["type"] + message_type + ":" - + first_message["data"]["content"] + + message_content + self.prompt_template.sep ) + history_messages.append( + ModelMessage(role=message_type, content=message_content) + ) index = self.chat_retention_rounds - 1 for round_conv in self.history_message[-index:]: @@ -304,12 +355,17 @@ class BaseChat(ABC): SystemMessage.type, ViewMessage.type, ]: + message_type = round_message["type"] + message_content = round_message["data"]["content"] history_text += ( - round_message["type"] + message_type + ":" - + round_message["data"]["content"] + + message_content + self.prompt_template.sep ) + history_messages.append( + ModelMessage(role=message_type, content=message_content) + ) else: ### user all history @@ -320,14 +376,19 @@ class BaseChat(ABC): SystemMessage.type, ViewMessage.type, ]: + message_type = message["type"] + message_content = message["data"]["content"] history_text += ( - message["type"] + message_type + ":" - + message["data"]["content"] + + message_content + self.prompt_template.sep ) + history_messages.append( + ModelMessage(role=message_type, content=message_content) + ) - return history_text + return history_text if str_message else history_messages def current_ai_response(self) -> str: for message in self.current_message.messages: diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py index 56fbb3b20..20d513c39 100644 --- a/pilot/scene/base_message.py +++ b/pilot/scene/base_message.py @@ -6,6 +6,7 @@ from typing import ( Dict, Generic, List, + Tuple, NamedTuple, Optional, Sequence, @@ -80,6 +81,14 @@ class SystemMessage(BaseMessage): return "system" +class ModelMessage(BaseModel): + """Type of message that interaction between dbgpt-server and llm-server""" + + """Similar to openai's message format""" + role: str + content: str + + class Generation(BaseModel): """Output of a single generation.""" @@ -146,3 +155,35 @@ def _message_from_dict(message: dict) -> BaseMessage: def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: return [_message_from_dict(m) for m in messages] + + +def _parse_model_messages( + messages: List[ModelMessage], +) -> Tuple[str, List[str], List[List[str, str]]]: + """ " + Parameters: + messages: List of message from base chat. + Returns: + A tuple contains user prompt, system message list and history message list + str: user prompt + List[str]: system messages + List[List[str]]: history message of user and assistant + """ + user_prompt = "" + system_messages: List[str] = [] + history_messages: List[List[str]] = [[]] + + for message in messages[:-1]: + if message.role == "human": + history_messages[-1].append(message.content) + elif message.role == "system": + system_messages.append(message.content) + elif message.role == "ai": + history_messages[-1].append(message.content) + history_messages.append([]) + if messages[-1].role != "human": + raise ValueError("Hi! What do you want to talk about?") + # Keep message pair of [user message, assistant message] + history_messages = list(filter(lambda x: len(x) == 2, history_messages)) + user_prompt = messages[-1].content + return user_prompt, system_messages, history_messages diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index a8839b86a..910a97573 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -5,6 +5,7 @@ import asyncio import json import os import sys +from typing import List import uvicorn from fastapi import BackgroundTasks, FastAPI, Request @@ -24,6 +25,7 @@ from pilot.configs.model_config import * from pilot.model.llm_out.vicuna_base_llm import get_embeddings from pilot.model.loader import ModelLoader from pilot.server.chat_adapter import get_llm_chat_adapter +from pilot.scene.base_message import ModelMessage CFG = Config() @@ -128,6 +130,7 @@ app = FastAPI() class PromptRequest(BaseModel): + messages: List[ModelMessage] prompt: str temperature: float max_new_tokens: int @@ -170,6 +173,7 @@ async def api_generate_stream(request: Request): @app.post("/generate") def generate(prompt_request: PromptRequest) -> str: params = { + "messages": prompt_request.messages, "prompt": prompt_request.prompt, "temperature": prompt_request.temperature, "max_new_tokens": prompt_request.max_new_tokens, From b03f3efe97b62bb14ae37d51849f718f92deeb19 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Wed, 19 Jul 2023 12:17:01 +0800 Subject: [PATCH 2/2] feat: add warning message when using generate_llm_text --- pilot/scene/base_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index ad1afbe88..bd660a0cd 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -2,6 +2,7 @@ import time from abc import ABC, abstractmethod import datetime import traceback +import warnings import json from pydantic import BaseModel, Field, root_validator, validator, Extra from typing import ( @@ -229,6 +230,7 @@ class BaseChat(ABC): return self.nostream_call() def generate_llm_text(self) -> str: + warnings.warn("This method is deprecated - please use `generate_llm_messages`.") text = "" ### Load scene setting or character definition if self.prompt_template.template_define: