feat: structured prompt data interaction between dbgpt-server and llm-server

FangYin Cheng 2023-07-18 18:07:09 +08:00
parent cde506385c
commit ae761a3bfa
5 changed files with 151 additions and 79 deletions
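In short: dbgpt-server now serializes the conversation as a list of role/content ModelMessage objects and sends it to llm-server alongside the legacy flattened prompt string. A rough sketch of the resulting request payload as it looks on the wire follows; the field values are illustrative, not taken from the repository.

payload = {
    "model": "chatglm2-6b",  # illustrative model name
    "prompt": "system:You are a helpful assistant.\nhuman:hello\n",  # legacy flattened form, still sent
    "messages": [  # new structured form introduced by this commit
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "hello"},
    ],
    "temperature": 0.7,
    "max_new_tokens": 1024,
    "stop": "\n",
}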

View File

@@ -8,6 +8,11 @@ import copy
 import torch
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+from pilot.scene.base_message import ModelMessage, _parse_model_messages
+
+# TODO move sep to scene prompt of model
+_CHATGLM_SEP = "\n"
+_CHATGLM2_SEP = "\n\n"

 @torch.inference_mode()
@@ -32,42 +37,20 @@ def chatglm_generate_stream(
         generate_kwargs["temperature"] = temperature

     # TODO, Fix this
-    print(prompt)
-    messages = prompt.split(stop)
-    #
-    # Add history conversation
-    hist = [HistoryEntry()]
-    system_messages = []
-    for message in messages[:-2]:
-        if len(message) <= 0:
-            continue
-        if "human:" in message:
-            hist[-1].add_question(message.split("human:")[1])
-        elif "system:" in message:
-            msg = message.split("system:")[1]
-            hist[-1].add_question(msg)
-            system_messages.append(msg)
-        elif "ai:" in message:
-            hist[-1].add_answer(message.split("ai:")[1])
-            hist.append(HistoryEntry())
-        else:
-            # TODO
-            # hist[-1].add_question(message.split("system:")[1])
-            # once_conversation.append(f"""###system:{message} """)
-            pass
-    try:
-        query = messages[-2].split("human:")[1]
-    except IndexError:
-        query = messages[-3].split("human:")[1]
-    hist = build_history(hist)
+    # print(prompt)
+    # messages = prompt.split(stop)
+    messages: List[ModelMessage] = params["messages"]
+    query, system_messages, hist = _parse_model_messages(messages)
+    system_messages_str = "".join(system_messages)
+
     if not hist:
         # No history conversation, but has system messages, merge to user`s query
-        query = prompt_adaptation(system_messages, query)
+        query = prompt_adaptation(system_messages_str, query)
+    else:
+        # history exist, add system message to head of history
+        hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]

     print("Query Message: ", query)
     print("hist: ", hist)
-    # output = ""
-    # i = 0

     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **generate_kwargs)
@@ -103,10 +86,10 @@ def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
     return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))


-def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
-    if not system_messages:
+def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
+    if not system_messages_str or system_messages_str == "":
         return human_message
-    system_messages_str = " ".join(system_messages)
+    # TODO Multi-model prompt adaptation
     adaptation_rules = [
         r"Question:\s*{}\s*",  # chat_db scene
         r"Goals:\s*{}\s*",  # chat_execution
@@ -119,4 +102,4 @@ def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
         if re.search(pattern, system_messages_str):
             return system_messages_str
     # https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
-    return f"{system_messages_str}\n\n问:{human_message}\n\n答:"
+    return system_messages_str + _CHATGLM2_SEP + human_message

View File

@@ -3,8 +3,10 @@
 import json
 import requests
+from typing import List

 from pilot.configs.config import Config
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+from pilot.scene.base_message import ModelMessage

 CFG = Config()
@@ -20,36 +22,17 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         "Token": CFG.proxy_api_key,
     }

-    messages = prompt.split(stop)
+    messages: List[ModelMessage] = params["messages"]
     # Add history conversation
     for message in messages:
-        if len(message) <= 0:
-            continue
-        if "human:" in message:
-            history.append(
-                {"role": "user", "content": message.split("human:")[1]},
-            )
-        elif "system:" in message:
-            history.append(
-                {
-                    "role": "system",
-                    "content": message.split("system:")[1],
-                }
-            )
-        elif "ai:" in message:
-            history.append(
-                {
-                    "role": "assistant",
-                    "content": message.split("ai:")[1],
-                }
-            )
+        if message.role == "human":
+            history.append({"role": "user", "content": message.content})
+        elif message.role == "system":
+            history.append({"role": "system", "content": message.content})
+        elif message.role == "ai":
+            history.append({"role": "assistant", "content": message.content})
         else:
-            history.append(
-                {
-                    "role": "system",
-                    "content": message,
-                }
-            )
+            pass

     # Move the last user's information to the end
     temp_his = history[::-1]
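A small illustration of the role mapping above; the message contents are invented, and the surrounding proxy request code is unchanged by this commit.

from pilot.scene.base_message import ModelMessage

incoming = [
    ModelMessage(role="system", content="Answer in English."),
    ModelMessage(role="human", content="What is DB-GPT?"),
    ModelMessage(role="ai", content="A database-centric GPT project."),
]
# After the loop above, `history` holds OpenAI-style chat messages:
# [
#     {"role": "system", "content": "Answer in English."},
#     {"role": "user", "content": "What is DB-GPT?"},
#     {"role": "assistant", "content": "A database-centric GPT project."},
# ]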

View File

@@ -37,6 +37,7 @@ from pilot.scene.base_message import (
     HumanMessage,
     AIMessage,
     ViewMessage,
+    ModelMessage,
 )
 from pilot.configs.config import Config
@@ -116,6 +117,7 @@ class BaseChat(ABC):
         payload = {
             "model": self.llm_model,
             "prompt": self.generate_llm_text(),
+            "messages": self.generate_llm_messages(),
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
@@ -244,24 +246,62 @@ class BaseChat(ABC):
         text += self.__load_user_message()
         return text

-    def __load_system_message(self):
+    def generate_llm_messages(self) -> List[ModelMessage]:
+        """
+        Structured prompt messages interaction between dbgpt-server and llm-server
+        See https://github.com/csunny/DB-GPT/issues/328
+        """
+        messages = []
+        ### Load scene setting or character definition as system message
+        if self.prompt_template.template_define:
+            messages.append(
+                ModelMessage(
+                    role="system",
+                    content=self.prompt_template.template_define,
+                )
+            )
+        ### Load prompt
+        messages += self.__load_system_message(str_message=False)
+        ### Load examples
+        messages += self.__load_example_messages(str_message=False)
+        ### Load History
+        messages += self.__load_histroy_messages(str_message=False)
+        ### Load User Input
+        messages += self.__load_user_message(str_message=False)
+        return messages
+
+    def __load_system_message(self, str_message: bool = True):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
+        system_messages = []
         for system_conv in system_convs:
             system_text += (
                 system_conv.type + ":" + system_conv.content + self.prompt_template.sep
             )
-        return system_text
+            system_messages.append(
+                ModelMessage(role=system_conv.type, content=system_conv.content)
+            )
+        return system_text if str_message else system_messages

-    def __load_user_message(self):
+    def __load_user_message(self, str_message: bool = True):
         user_conv = self.current_message.get_user_conv()
+        user_messages = []
         if user_conv:
-            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+            user_text = (
+                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+            )
+            user_messages.append(
+                ModelMessage(role=user_conv.type, content=user_conv.content)
+            )
+            return user_text if str_message else user_messages
         else:
             raise ValueError("Hi! What do you want to talk about")

-    def __load_example_messages(self):
+    def __load_example_messages(self, str_message: bool = True):
         example_text = ""
+        example_messages = []
         if self.prompt_template.example_selector:
             for round_conv in self.prompt_template.example_selector.examples():
                 for round_message in round_conv["messages"]:
@@ -269,16 +309,22 @@ class BaseChat(ABC):
                         SystemMessage.type,
                         ViewMessage.type,
                     ]:
+                        message_type = round_message["type"]
+                        message_content = round_message["data"]["content"]
                         example_text += (
-                            round_message["type"]
+                            message_type
                             + ":"
-                            + round_message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
-        return example_text
+                        example_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
+        return example_text if str_message else example_messages

-    def __load_histroy_messages(self):
+    def __load_histroy_messages(self, str_message: bool = True):
         history_text = ""
+        history_messages = []
         if self.prompt_template.need_historical_messages:
             if self.history_message:
                 logger.info(
@@ -290,12 +336,17 @@ class BaseChat(ABC):
                         ViewMessage.type,
                         SystemMessage.type,
                     ]:
+                        message_type = first_message["type"]
+                        message_content = first_message["data"]["content"]
                         history_text += (
-                            first_message["type"]
+                            message_type
                             + ":"
-                            + first_message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )

                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
@@ -304,12 +355,17 @@ class BaseChat(ABC):
                             SystemMessage.type,
                             ViewMessage.type,
                         ]:
+                            message_type = round_message["type"]
+                            message_content = round_message["data"]["content"]
                             history_text += (
-                                round_message["type"]
+                                message_type
                                 + ":"
-                                + round_message["data"]["content"]
+                                + message_content
                                 + self.prompt_template.sep
                             )
+                            history_messages.append(
+                                ModelMessage(role=message_type, content=message_content)
+                            )

             else:
                 ### user all history
@@ -320,14 +376,19 @@ class BaseChat(ABC):
                             SystemMessage.type,
                             ViewMessage.type,
                         ]:
+                            message_type = message["type"]
+                            message_content = message["data"]["content"]
                             history_text += (
-                                message["type"]
+                                message_type
                                 + ":"
-                                + message["data"]["content"]
+                                + message_content
                                 + self.prompt_template.sep
                             )
+                            history_messages.append(
+                                ModelMessage(role=message_type, content=message_content)
+                            )

-        return history_text
+        return history_text if str_message else history_messages

     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
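For orientation, a hedged sketch of the list generate_llm_messages produces for a hypothetical chat scene. The contents are invented; only the ordering mirrors the method above (template definition, system conversations, examples, history, current user input), and example messages are omitted here.

from pilot.scene.base_message import ModelMessage

example_output = [
    ModelMessage(role="system", content="You are a database assistant."),  # template_define
    ModelMessage(role="system", content="Table schema: users(id, name)"),  # system conversations
    ModelMessage(role="human", content="How many users are there?"),       # history, user turn
    ModelMessage(role="ai", content="SELECT COUNT(*) FROM users;"),        # history, assistant turn
    ModelMessage(role="human", content="Only active ones, please."),       # current user input
]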

View File

@@ -6,6 +6,7 @@ from typing import (
     Dict,
     Generic,
     List,
+    Tuple,
     NamedTuple,
     Optional,
     Sequence,
@@ -80,6 +81,14 @@ class SystemMessage(BaseMessage):
         return "system"


+class ModelMessage(BaseModel):
+    """Message type for the interaction between dbgpt-server and llm-server.
+
+    Similar to OpenAI's message format.
+    """
+
+    role: str
+    content: str
+
+
 class Generation(BaseModel):
     """Output of a single generation."""
@@ -146,3 +155,35 @@ def _message_from_dict(message: dict) -> BaseMessage:
 def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
+
+
+def _parse_model_messages(
+    messages: List[ModelMessage],
+) -> Tuple[str, List[str], List[List[str]]]:
+    """
+    Parameters:
+        messages: List of messages from base chat.
+    Returns:
+        A tuple containing the user prompt, the system messages and the history messages:
+            str: user prompt
+            List[str]: system messages
+            List[List[str]]: history messages as [user message, assistant message] pairs
+    """
+    user_prompt = ""
+    system_messages: List[str] = []
+    history_messages: List[List[str]] = [[]]
+
+    for message in messages[:-1]:
+        if message.role == "human":
+            history_messages[-1].append(message.content)
+        elif message.role == "system":
+            system_messages.append(message.content)
+        elif message.role == "ai":
+            history_messages[-1].append(message.content)
+            history_messages.append([])
+    if messages[-1].role != "human":
+        raise ValueError("Hi! What do you want to talk about")
+    # Keep only complete [user message, assistant message] pairs
+    history_messages = list(filter(lambda x: len(x) == 2, history_messages))
+    user_prompt = messages[-1].content
+    return user_prompt, system_messages, history_messages
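A hedged usage sketch of _parse_model_messages with invented inputs: the last message must be a user turn, system messages are collected separately, and any user turn without a matching assistant answer is dropped by the final pair filter.

from pilot.scene.base_message import ModelMessage, _parse_model_messages

query, system_messages, hist = _parse_model_messages(
    [
        ModelMessage(role="system", content="Answer concisely."),
        ModelMessage(role="human", content="ping"),        # never answered, filtered out
        ModelMessage(role="human", content="hello again"),
    ]
)
assert query == "hello again"
assert system_messages == ["Answer concisely."]
assert hist == []  # no complete [user, assistant] pair survives the filter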

View File

@@ -5,6 +5,7 @@ import asyncio
 import json
 import os
 import sys
+from typing import List

 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
@@ -24,6 +25,7 @@ from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
+from pilot.scene.base_message import ModelMessage

 CFG = Config()
@@ -128,6 +130,7 @@ app = FastAPI()
 class PromptRequest(BaseModel):
+    messages: List[ModelMessage]
     prompt: str
     temperature: float
     max_new_tokens: int
@@ -170,6 +173,7 @@ async def api_generate_stream(request: Request):
 @app.post("/generate")
 def generate(prompt_request: PromptRequest) -> str:
     params = {
+        "messages": prompt_request.messages,
         "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
         "max_new_tokens": prompt_request.max_new_tokens,