feat: structured prompt data interaction between dbgpt-server and llm-server

FangYin Cheng 2023-07-18 18:07:09 +08:00
parent cde506385c
commit ae761a3bfa
5 changed files with 151 additions and 79 deletions
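
In short: instead of the llm-server re-parsing a flattened prompt string (split on the stop token and matched against "human:" / "system:" / "ai:" markers), the dbgpt-server now ships the conversation as an explicit list of role/content pairs. A minimal sketch of the new structure, with illustrative contents (the ModelMessage model itself is added in pilot.scene.base_message below):

from pydantic import BaseModel

class ModelMessage(BaseModel):
    # Mirrors the model added in pilot.scene.base_message: role + content,
    # similar to OpenAI's chat message format.
    role: str
    content: str

messages = [
    ModelMessage(role="system", content="You are a helpful DBA assistant."),
    ModelMessage(role="human", content="Which table stores orders?"),
]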

View File

@@ -8,6 +8,11 @@ import copy
import torch
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.scene.base_message import ModelMessage, _parse_model_messages
# TODO move sep to scene prompt of model
_CHATGLM_SEP = "\n"
_CHATGLM2_SEP = "\n\n"
@torch.inference_mode()
@@ -32,42 +37,20 @@ def chatglm_generate_stream(
generate_kwargs["temperature"] = temperature
# TODO, Fix this
print(prompt)
messages = prompt.split(stop)
#
# # Add history conversation
hist = [HistoryEntry()]
system_messages = []
for message in messages[:-2]:
if len(message) <= 0:
continue
if "human:" in message:
hist[-1].add_question(message.split("human:")[1])
elif "system:" in message:
msg = message.split("system:")[1]
hist[-1].add_question(msg)
system_messages.append(msg)
elif "ai:" in message:
hist[-1].add_answer(message.split("ai:")[1])
hist.append(HistoryEntry())
else:
# TODO
# hist[-1].add_question(message.split("system:")[1])
# once_conversation.append(f"""###system:{message} """)
pass
try:
query = messages[-2].split("human:")[1]
except IndexError:
query = messages[-3].split("human:")[1]
hist = build_history(hist)
# print(prompt)
# messages = prompt.split(stop)
messages: List[ModelMessage] = params["messages"]
query, system_messages, hist = _parse_model_messages(messages)
system_messages_str = "".join(system_messages)
if not hist:
# No history conversation, but there are system messages; merge them into the user's query
query = prompt_adaptation(system_messages, query)
query = prompt_adaptation(system_messages_str, query)
else:
# History exists; add the system messages to the head of the history
hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]
print("Query Message: ", query)
print("hist: ", hist)
# output = ""
# i = 0
for i, (response, new_hist) in enumerate(
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
@@ -103,10 +86,10 @@ def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))
def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
if not system_messages:
def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
if not system_messages_str or system_messages_str == "":
return human_message
system_messages_str = " ".join(system_messages)
# TODO Multi-model prompt adaptation
adaptation_rules = [
r"Question:\s*{}\s*", # chat_db scene
r"Goals:\s*{}\s*", # chat_execution
@@ -119,4 +102,4 @@ def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
if re.search(pattern, system_messages_str):
return system_messages_str
# https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
return f"{system_messages_str}\n\n问:{human_message}\n\n答:"
return system_messages_str + _CHATGLM2_SEP + human_message
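
Note on the ChatGLM path above: _parse_model_messages (defined in base_message.py further down) splits the structured messages into the current query, the system messages, and paired [user, assistant] history. When there is no history, prompt_adaptation merges the system text into the query. A sketch of its fall-through branch with illustrative strings, assuming prompt_adaptation and _CHATGLM2_SEP from this module:

system_str = "You are a database expert."
human_msg = "List all tables."
query = prompt_adaptation(system_str, human_msg)
# No adaptation rule ("Question:", "Goals:", ...) matches system_str, so the
# result is simply system_str + _CHATGLM2_SEP + human_msg:
#   "You are a database expert.\n\nList all tables."
# The old code appended the 问/答 wrapper here; ChatGLM2's own prompt builder
# (see the modeling_chatglm.py link above) now takes care of that.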

View File

@@ -3,8 +3,10 @@
import json
import requests
from typing import List
from pilot.configs.config import Config
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.scene.base_message import ModelMessage
CFG = Config()
@@ -20,36 +22,17 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
"Token": CFG.proxy_api_key,
}
messages = prompt.split(stop)
messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
if len(message) <= 0:
continue
if "human:" in message:
history.append(
{"role": "user", "content": message.split("human:")[1]},
)
elif "system:" in message:
history.append(
{
"role": "system",
"content": message.split("system:")[1],
}
)
elif "ai:" in message:
history.append(
{
"role": "assistant",
"content": message.split("ai:")[1],
}
)
if message.role == "human":
history.append({"role": "user", "content": message.content})
elif message.role == "system":
history.append({"role": "system", "content": message.content})
elif message.role == "ai":
history.append({"role": "assistant", "content": message.content})
else:
history.append(
{
"role": "system",
"content": message,
}
)
pass
# Move the last user's information to the end
temp_his = history[::-1]
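
The loop above translates the structured roles into the OpenAI-style roles the proxy endpoint expects ("human" to "user", "ai" to "assistant", "system" unchanged). The same translation could be written as a lookup table; this is a sketch, not the committed implementation:

ROLE_MAP = {"human": "user", "system": "system", "ai": "assistant"}

def to_openai_history(messages):
    # Convert ModelMessage objects into the OpenAI-style dicts built above,
    # silently dropping any role outside the mapping.
    return [
        {"role": ROLE_MAP[m.role], "content": m.content}
        for m in messages
        if m.role in ROLE_MAP
    ]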

View File

@@ -37,6 +37,7 @@ from pilot.scene.base_message import (
HumanMessage,
AIMessage,
ViewMessage,
ModelMessage,
)
from pilot.configs.config import Config
@@ -116,6 +117,7 @@ class BaseChat(ABC):
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
"messages": self.generate_llm_messages(),
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"stop": self.prompt_template.sep,
@@ -244,24 +246,62 @@ class BaseChat(ABC):
text += self.__load_user_message()
return text
def __load_system_message(self):
def generate_llm_messages(self) -> List[ModelMessage]:
"""
Structured prompt messages interaction between dbgpt-server and llm-server
See https://github.com/csunny/DB-GPT/issues/328
"""
messages = []
### Load scene setting or character definition as system message
if self.prompt_template.template_define:
messages.append(
ModelMessage(
role="system",
content=self.prompt_template.template_define,
)
)
### Load prompt
messages += self.__load_system_message(str_message=False)
### Load examples
messages += self.__load_example_messages(str_message=False)
### Load History
messages += self.__load_histroy_messages(str_message=False)
### Load User Input
messages += self.__load_user_message(str_message=False)
return messages
def __load_system_message(self, str_message: bool = True):
system_convs = self.current_message.get_system_conv()
system_text = ""
system_messages = []
for system_conv in system_convs:
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
return system_text
system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content)
)
return system_text if str_message else system_messages
def __load_user_message(self):
def __load_user_message(self, str_message: bool = True):
user_conv = self.current_message.get_user_conv()
user_messages = []
if user_conv:
return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
user_text = (
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
)
user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content)
)
return user_text if str_message else user_messages
else:
raise ValueError("Hi! What do you want to talk about")
def __load_example_messages(self):
def __load_example_messages(self, str_message: bool = True):
example_text = ""
example_messages = []
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv["messages"]:
@@ -269,16 +309,22 @@ class BaseChat(ABC):
SystemMessage.type,
ViewMessage.type,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
example_text += (
round_message["type"]
message_type
+ ":"
+ round_message["data"]["content"]
+ message_content
+ self.prompt_template.sep
)
return example_text
example_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return example_text if str_message else example_messages
def __load_histroy_messages(self):
def __load_histroy_messages(self, str_message: bool = True):
history_text = ""
history_messages = []
if self.prompt_template.need_historical_messages:
if self.history_message:
logger.info(
@@ -290,12 +336,17 @@ class BaseChat(ABC):
ViewMessage.type,
SystemMessage.type,
]:
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
first_message["type"]
message_type
+ ":"
+ first_message["data"]["content"]
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
@@ -304,12 +355,17 @@ class BaseChat(ABC):
SystemMessage.type,
ViewMessage.type,
]:
message_type = round_message["type"]
message_content = round_message["data"]["content"]
history_text += (
round_message["type"]
message_type
+ ":"
+ round_message["data"]["content"]
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
else:
### use all history
@@ -320,14 +376,19 @@ class BaseChat(ABC):
SystemMessage.type,
ViewMessage.type,
]:
message_type = message["type"]
message_content = message["data"]["content"]
history_text += (
message["type"]
message_type
+ ":"
+ message["data"]["content"]
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
)
return history_text
return history_text if str_message else history_messages
def current_ai_response(self) -> str:
for message in self.current_message.messages:
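
generate_llm_messages() assembles the ModelMessage list in the same order the old text path used: the scene definition (template_define) as a system message, then the scene prompt, examples, retained history, and finally the current user input. Both forms travel in the request payload built above; roughly the JSON that reaches the llm-server (all values here are illustrative):

payload = {
    "model": "chatglm2-6b",                 # illustrative model name
    "prompt": "system:...###human:...###",  # legacy flattened text, still sent
    "messages": [                           # new structured form
        {"role": "system", "content": "You are a helpful DBA assistant."},
        {"role": "human", "content": "Which table stores orders?"},
    ],
    "temperature": 0.7,
    "max_new_tokens": 1024,
    "stop": "###",
}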

View File

@@ -6,6 +6,7 @@ from typing import (
Dict,
Generic,
List,
Tuple,
NamedTuple,
Optional,
Sequence,
@@ -80,6 +81,14 @@ class SystemMessage(BaseMessage):
return "system"
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Similar to openai's message format"""
role: str
content: str
class Generation(BaseModel):
"""Output of a single generation."""
@@ -146,3 +155,35 @@ def _message_from_dict(message: dict) -> BaseMessage:
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]
def _parse_model_messages(
messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str]]]:
""" "
Parameters:
messages: List of message from base chat.
Returns:
A tuple contains user prompt, system message list and history message list
str: user prompt
List[str]: system messages
List[List[str]]: history message of user and assistant
"""
user_prompt = ""
system_messages: List[str] = []
history_messages: List[List[str]] = [[]]
for message in messages[:-1]:
if message.role == "human":
history_messages[-1].append(message.content)
elif message.role == "system":
system_messages.append(message.content)
elif message.role == "ai":
history_messages[-1].append(message.content)
history_messages.append([])
if messages[-1].role != "human":
raise ValueError("Hi! What do you want to talk about")
# Keep only complete [user message, assistant message] pairs
history_messages = list(filter(lambda x: len(x) == 2, history_messages))
user_prompt = messages[-1].content
return user_prompt, system_messages, history_messages
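
A worked example of _parse_model_messages with illustrative contents (assuming ModelMessage and _parse_model_messages are imported from this module). The last message must come from the user; earlier messages are split into system text and complete [user, assistant] pairs:

example = [
    ModelMessage(role="system", content="Answer using the sales database."),
    ModelMessage(role="human", content="List all tables."),
    ModelMessage(role="ai", content="orders, users and items."),
    ModelMessage(role="human", content="How many orders were placed today?"),
]
user_prompt, system_messages, history = _parse_model_messages(example)
# user_prompt     == "How many orders were placed today?"
# system_messages == ["Answer using the sales database."]
# history         == [["List all tables.", "orders, users and items."]]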

View File

@@ -5,6 +5,7 @@ import asyncio
import json
import os
import sys
from typing import List
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
@@ -24,6 +25,7 @@ from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.scene.base_message import ModelMessage
CFG = Config()
@@ -128,6 +130,7 @@ app = FastAPI()
class PromptRequest(BaseModel):
messages: List[ModelMessage]
prompt: str
temperature: float
max_new_tokens: int
@@ -170,6 +173,7 @@ async def api_generate_stream(request: Request):
@app.post("/generate")
def generate(prompt_request: PromptRequest) -> str:
params = {
"messages": prompt_request.messages,
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,