mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
feat: structured prompt data interaction between dbgpt-server and llm-server
This commit is contained in:
parent
cde506385c
commit
ae761a3bfa
@ -8,6 +8,11 @@ import copy
|
||||
import torch
|
||||
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
from pilot.scene.base_message import ModelMessage, _parse_model_messages
|
||||
|
||||
# TODO move sep to scene prompt of model
|
||||
_CHATGLM_SEP = "\n"
|
||||
_CHATGLM2_SEP = "\n\n"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -32,42 +37,20 @@ def chatglm_generate_stream(
|
||||
generate_kwargs["temperature"] = temperature
|
||||
|
||||
# TODO, Fix this
|
||||
print(prompt)
|
||||
messages = prompt.split(stop)
|
||||
#
|
||||
# # Add history conversation
|
||||
hist = [HistoryEntry()]
|
||||
system_messages = []
|
||||
for message in messages[:-2]:
|
||||
if len(message) <= 0:
|
||||
continue
|
||||
if "human:" in message:
|
||||
hist[-1].add_question(message.split("human:")[1])
|
||||
elif "system:" in message:
|
||||
msg = message.split("system:")[1]
|
||||
hist[-1].add_question(msg)
|
||||
system_messages.append(msg)
|
||||
elif "ai:" in message:
|
||||
hist[-1].add_answer(message.split("ai:")[1])
|
||||
hist.append(HistoryEntry())
|
||||
else:
|
||||
# TODO
|
||||
# hist[-1].add_question(message.split("system:")[1])
|
||||
# once_conversation.append(f"""###system:{message} """)
|
||||
pass
|
||||
|
||||
try:
|
||||
query = messages[-2].split("human:")[1]
|
||||
except IndexError:
|
||||
query = messages[-3].split("human:")[1]
|
||||
hist = build_history(hist)
|
||||
# print(prompt)
|
||||
# messages = prompt.split(stop)
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
query, system_messages, hist = _parse_model_messages(messages)
|
||||
system_messages_str = "".join(system_messages)
|
||||
if not hist:
|
||||
# No history conversation, but has system messages, merge to user`s query
|
||||
query = prompt_adaptation(system_messages, query)
|
||||
query = prompt_adaptation(system_messages_str, query)
|
||||
else:
|
||||
# history exist, add system message to head of history
|
||||
hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]
|
||||
|
||||
print("Query Message: ", query)
|
||||
print("hist: ", hist)
|
||||
# output = ""
|
||||
# i = 0
|
||||
|
||||
for i, (response, new_hist) in enumerate(
|
||||
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
|
||||
@ -103,10 +86,10 @@ def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
|
||||
return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))
|
||||
|
||||
|
||||
def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
|
||||
if not system_messages:
|
||||
def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
|
||||
if not system_messages_str or system_messages_str == "":
|
||||
return human_message
|
||||
system_messages_str = " ".join(system_messages)
|
||||
# TODO Multi-model prompt adaptation
|
||||
adaptation_rules = [
|
||||
r"Question:\s*{}\s*", # chat_db scene
|
||||
r"Goals:\s*{}\s*", # chat_execution
|
||||
@ -119,4 +102,4 @@ def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
|
||||
if re.search(pattern, system_messages_str):
|
||||
return system_messages_str
|
||||
# https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
return f"{system_messages_str}\n\n问:{human_message}\n\n答:"
|
||||
return system_messages_str + _CHATGLM2_SEP + human_message
|
||||
|
@ -3,8 +3,10 @@
|
||||
|
||||
import json
|
||||
import requests
|
||||
from typing import List
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -20,36 +22,17 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
||||
"Token": CFG.proxy_api_key,
|
||||
}
|
||||
|
||||
messages = prompt.split(stop)
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
if len(message) <= 0:
|
||||
continue
|
||||
if "human:" in message:
|
||||
history.append(
|
||||
{"role": "user", "content": message.split("human:")[1]},
|
||||
)
|
||||
elif "system:" in message:
|
||||
history.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": message.split("system:")[1],
|
||||
}
|
||||
)
|
||||
elif "ai:" in message:
|
||||
history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.split("ai:")[1],
|
||||
}
|
||||
)
|
||||
if message.role == "human":
|
||||
history.append({"role": "user", "content": message.content})
|
||||
elif message.role == "system":
|
||||
history.append({"role": "system", "content": message.content})
|
||||
elif message.role == "ai":
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
history.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": message,
|
||||
}
|
||||
)
|
||||
pass
|
||||
|
||||
# Move the last user's information to the end
|
||||
temp_his = history[::-1]
|
||||
|
@ -37,6 +37,7 @@ from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ViewMessage,
|
||||
ModelMessage,
|
||||
)
|
||||
from pilot.configs.config import Config
|
||||
|
||||
@ -116,6 +117,7 @@ class BaseChat(ABC):
|
||||
payload = {
|
||||
"model": self.llm_model,
|
||||
"prompt": self.generate_llm_text(),
|
||||
"messages": self.generate_llm_messages(),
|
||||
"temperature": float(self.prompt_template.temperature),
|
||||
"max_new_tokens": int(self.prompt_template.max_new_tokens),
|
||||
"stop": self.prompt_template.sep,
|
||||
@ -244,24 +246,62 @@ class BaseChat(ABC):
|
||||
text += self.__load_user_message()
|
||||
return text
|
||||
|
||||
def __load_system_message(self):
|
||||
def generate_llm_messages(self) -> List[ModelMessage]:
|
||||
"""
|
||||
Structured prompt messages interaction between dbgpt-server and llm-server
|
||||
See https://github.com/csunny/DB-GPT/issues/328
|
||||
"""
|
||||
messages = []
|
||||
### Load scene setting or character definition as system message
|
||||
if self.prompt_template.template_define:
|
||||
messages.append(
|
||||
ModelMessage(
|
||||
role="system",
|
||||
content=self.prompt_template.template_define,
|
||||
)
|
||||
)
|
||||
### Load prompt
|
||||
messages += self.__load_system_message(str_message=False)
|
||||
### Load examples
|
||||
messages += self.__load_example_messages(str_message=False)
|
||||
|
||||
### Load History
|
||||
messages += self.__load_histroy_messages(str_message=False)
|
||||
|
||||
### Load User Input
|
||||
messages += self.__load_user_message(str_message=False)
|
||||
return messages
|
||||
|
||||
def __load_system_message(self, str_message: bool = True):
|
||||
system_convs = self.current_message.get_system_conv()
|
||||
system_text = ""
|
||||
system_messages = []
|
||||
for system_conv in system_convs:
|
||||
system_text += (
|
||||
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
|
||||
)
|
||||
return system_text
|
||||
system_messages.append(
|
||||
ModelMessage(role=system_conv.type, content=system_conv.content)
|
||||
)
|
||||
return system_text if str_message else system_messages
|
||||
|
||||
def __load_user_message(self):
|
||||
def __load_user_message(self, str_message: bool = True):
|
||||
user_conv = self.current_message.get_user_conv()
|
||||
user_messages = []
|
||||
if user_conv:
|
||||
return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
|
||||
user_text = (
|
||||
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
|
||||
)
|
||||
user_messages.append(
|
||||
ModelMessage(role=user_conv.type, content=user_conv.content)
|
||||
)
|
||||
return user_text if str_message else user_messages
|
||||
else:
|
||||
raise ValueError("Hi! What do you want to talk about?")
|
||||
|
||||
def __load_example_messages(self):
|
||||
def __load_example_messages(self, str_message: bool = True):
|
||||
example_text = ""
|
||||
example_messages = []
|
||||
if self.prompt_template.example_selector:
|
||||
for round_conv in self.prompt_template.example_selector.examples():
|
||||
for round_message in round_conv["messages"]:
|
||||
@ -269,16 +309,22 @@ class BaseChat(ABC):
|
||||
SystemMessage.type,
|
||||
ViewMessage.type,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
example_text += (
|
||||
round_message["type"]
|
||||
message_type
|
||||
+ ":"
|
||||
+ round_message["data"]["content"]
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
return example_text
|
||||
example_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
return example_text if str_message else example_messages
|
||||
|
||||
def __load_histroy_messages(self):
|
||||
def __load_histroy_messages(self, str_message: bool = True):
|
||||
history_text = ""
|
||||
history_messages = []
|
||||
if self.prompt_template.need_historical_messages:
|
||||
if self.history_message:
|
||||
logger.info(
|
||||
@ -290,12 +336,17 @@ class BaseChat(ABC):
|
||||
ViewMessage.type,
|
||||
SystemMessage.type,
|
||||
]:
|
||||
message_type = first_message["type"]
|
||||
message_content = first_message["data"]["content"]
|
||||
history_text += (
|
||||
first_message["type"]
|
||||
message_type
|
||||
+ ":"
|
||||
+ first_message["data"]["content"]
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
index = self.chat_retention_rounds - 1
|
||||
for round_conv in self.history_message[-index:]:
|
||||
@ -304,12 +355,17 @@ class BaseChat(ABC):
|
||||
SystemMessage.type,
|
||||
ViewMessage.type,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
history_text += (
|
||||
round_message["type"]
|
||||
message_type
|
||||
+ ":"
|
||||
+ round_message["data"]["content"]
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
else:
|
||||
### user all history
|
||||
@ -320,14 +376,19 @@ class BaseChat(ABC):
|
||||
SystemMessage.type,
|
||||
ViewMessage.type,
|
||||
]:
|
||||
message_type = message["type"]
|
||||
message_content = message["data"]["content"]
|
||||
history_text += (
|
||||
message["type"]
|
||||
message_type
|
||||
+ ":"
|
||||
+ message["data"]["content"]
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
return history_text
|
||||
return history_text if str_message else history_messages
|
||||
|
||||
def current_ai_response(self) -> str:
|
||||
for message in self.current_message.messages:
|
||||
|
@ -6,6 +6,7 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Tuple,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -80,6 +81,14 @@ class SystemMessage(BaseMessage):
|
||||
return "system"
|
||||
|
||||
|
||||
class ModelMessage(BaseModel):
|
||||
"""Type of message that interaction between dbgpt-server and llm-server"""
|
||||
|
||||
"""Similar to openai's message format"""
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class Generation(BaseModel):
|
||||
"""Output of a single generation."""
|
||||
|
||||
@ -146,3 +155,35 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
|
||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
def _parse_model_messages(
|
||||
messages: List[ModelMessage],
|
||||
) -> Tuple[str, List[str], List[List[str, str]]]:
|
||||
""" "
|
||||
Parameters:
|
||||
messages: List of message from base chat.
|
||||
Returns:
|
||||
A tuple contains user prompt, system message list and history message list
|
||||
str: user prompt
|
||||
List[str]: system messages
|
||||
List[List[str]]: history message of user and assistant
|
||||
"""
|
||||
user_prompt = ""
|
||||
system_messages: List[str] = []
|
||||
history_messages: List[List[str]] = [[]]
|
||||
|
||||
for message in messages[:-1]:
|
||||
if message.role == "human":
|
||||
history_messages[-1].append(message.content)
|
||||
elif message.role == "system":
|
||||
system_messages.append(message.content)
|
||||
elif message.role == "ai":
|
||||
history_messages[-1].append(message.content)
|
||||
history_messages.append([])
|
||||
if messages[-1].role != "human":
|
||||
raise ValueError("Hi! What do you want to talk about?")
|
||||
# Keep message pair of [user message, assistant message]
|
||||
history_messages = list(filter(lambda x: len(x) == 2, history_messages))
|
||||
user_prompt = messages[-1].content
|
||||
return user_prompt, system_messages, history_messages
|
||||
|
@ -5,6 +5,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
@ -24,6 +25,7 @@ from pilot.configs.model_config import *
|
||||
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -128,6 +130,7 @@ app = FastAPI()
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
messages: List[ModelMessage]
|
||||
prompt: str
|
||||
temperature: float
|
||||
max_new_tokens: int
|
||||
@ -170,6 +173,7 @@ async def api_generate_stream(request: Request):
|
||||
@app.post("/generate")
|
||||
def generate(prompt_request: PromptRequest) -> str:
|
||||
params = {
|
||||
"messages": prompt_request.messages,
|
||||
"prompt": prompt_request.prompt,
|
||||
"temperature": prompt_request.temperature,
|
||||
"max_new_tokens": prompt_request.max_new_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user