mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-04 18:10:02 +00:00)

feat: structured prompt data interaction between dbgpt-server and llm-server

This commit is contained in:
  parent cde506385c
  commit ae761a3bfa
@@ -8,6 +8,11 @@ import copy
 import torch
 
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+from pilot.scene.base_message import ModelMessage, _parse_model_messages
+
+# TODO move sep to scene prompt of model
+_CHATGLM_SEP = "\n"
+_CHATGLM2_SEP = "\n\n"
 
 
 @torch.inference_mode()
@@ -32,42 +37,20 @@ def chatglm_generate_stream(
     generate_kwargs["temperature"] = temperature
 
     # TODO, Fix this
-    print(prompt)
-    messages = prompt.split(stop)
-    #
-    # # Add history conversation
-    hist = [HistoryEntry()]
-    system_messages = []
-    for message in messages[:-2]:
-        if len(message) <= 0:
-            continue
-        if "human:" in message:
-            hist[-1].add_question(message.split("human:")[1])
-        elif "system:" in message:
-            msg = message.split("system:")[1]
-            hist[-1].add_question(msg)
-            system_messages.append(msg)
-        elif "ai:" in message:
-            hist[-1].add_answer(message.split("ai:")[1])
-            hist.append(HistoryEntry())
-        else:
-            # TODO
-            # hist[-1].add_question(message.split("system:")[1])
-            # once_conversation.append(f"""###system:{message} """)
-            pass
-
-    try:
-        query = messages[-2].split("human:")[1]
-    except IndexError:
-        query = messages[-3].split("human:")[1]
-    hist = build_history(hist)
+    # print(prompt)
+    # messages = prompt.split(stop)
+    messages: List[ModelMessage] = params["messages"]
+    query, system_messages, hist = _parse_model_messages(messages)
+    system_messages_str = "".join(system_messages)
     if not hist:
         # No history conversation, but has system messages, merge to user`s query
-        query = prompt_adaptation(system_messages, query)
+        query = prompt_adaptation(system_messages_str, query)
+    else:
+        # history exist, add system message to head of history
+        hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]
+
     print("Query Message: ", query)
     print("hist: ", hist)
-    # output = ""
-    # i = 0
 
     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **generate_kwargs)
@@ -103,10 +86,10 @@ def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
     return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))
 
 
-def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
-    if not system_messages:
+def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
+    if not system_messages_str or system_messages_str == "":
         return human_message
-    system_messages_str = " ".join(system_messages)
+    # TODO Multi-model prompt adaptation
     adaptation_rules = [
         r"Question:\s*{}\s*", # chat_db scene
         r"Goals:\s*{}\s*", # chat_execution
@@ -119,4 +102,4 @@ def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
         if re.search(pattern, system_messages_str):
             return system_messages_str
     # https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
-    return f"{system_messages_str}\n\n问:{human_message}\n\n答:"
+    return system_messages_str + _CHATGLM2_SEP + human_message
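Below is a minimal, self-contained sketch of the message flow the new ChatGLM path relies on. It mirrors the contract of _parse_model_messages (added to pilot.scene.base_message in a later hunk of this commit) using a stand-in dataclass and a made-up message list so the snippet runs on its own; it is an illustration, not the project code itself.

    # Stand-in types and data; the real ModelMessage/_parse_model_messages live in
    # pilot.scene.base_message and are only sketched here.
    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class ModelMessage:  # stand-in for pilot.scene.base_message.ModelMessage
        role: str
        content: str


    _CHATGLM2_SEP = "\n\n"


    def parse(messages: List[ModelMessage]) -> Tuple[str, List[str], List[List[str]]]:
        """Same contract as _parse_model_messages: (user query, system messages, history pairs)."""
        system, hist = [], [[]]
        for m in messages[:-1]:
            if m.role == "system":
                system.append(m.content)
            elif m.role == "human":
                hist[-1].append(m.content)
            elif m.role == "ai":
                hist[-1].append(m.content)
                hist.append([])
        # keep only complete [user, assistant] pairs
        hist = [pair for pair in hist if len(pair) == 2]
        return messages[-1].content, system, hist


    messages = [  # hypothetical conversation
        ModelMessage(role="system", content="You are a helpful SQL assistant."),
        ModelMessage(role="human", content="Which tables exist?"),
        ModelMessage(role="ai", content="users, orders"),
        ModelMessage(role="human", content="Show the first 3 users."),
    ]
    query, system_messages, hist = parse(messages)
    # History exists, so the system text is folded into the first user turn,
    # matching the hist[0][0] assignment in chatglm_generate_stream above.
    hist[0][0] = "".join(system_messages) + _CHATGLM2_SEP + hist[0][0]
    print(query)  # -> "Show the first 3 users."
    print(hist)   # -> [["You are a helpful SQL assistant.\n\nWhich tables exist?", "users, orders"]]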
@@ -3,8 +3,10 @@
 
 import json
 import requests
+from typing import List
 from pilot.configs.config import Config
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+from pilot.scene.base_message import ModelMessage
 
 CFG = Config()
 
@@ -20,36 +22,17 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         "Token": CFG.proxy_api_key,
     }
 
-    messages = prompt.split(stop)
+    messages: List[ModelMessage] = params["messages"]
     # Add history conversation
     for message in messages:
-        if len(message) <= 0:
-            continue
-        if "human:" in message:
-            history.append(
-                {"role": "user", "content": message.split("human:")[1]},
-            )
-        elif "system:" in message:
-            history.append(
-                {
-                    "role": "system",
-                    "content": message.split("system:")[1],
-                }
-            )
-        elif "ai:" in message:
-            history.append(
-                {
-                    "role": "assistant",
-                    "content": message.split("ai:")[1],
-                }
-            )
+        if message.role == "human":
+            history.append({"role": "user", "content": message.content})
+        elif message.role == "system":
+            history.append({"role": "system", "content": message.content})
+        elif message.role == "ai":
+            history.append({"role": "assistant", "content": message.content})
         else:
-            history.append(
-                {
-                    "role": "system",
-                    "content": message,
-                }
-            )
+            pass
 
     # Move the last user's information to the end
     temp_his = history[::-1]
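For the proxy path the conversion is simply a role rename into OpenAI-style entries. Here is a hedged sketch of that mapping, with the if/elif chain above condensed into a lookup table and with invented message contents; the real code imports ModelMessage from pilot.scene.base_message.

    from dataclasses import dataclass
    from typing import Dict, List


    @dataclass
    class ModelMessage:  # stand-in for the real pydantic model
        role: str
        content: str


    # Condensed form of the if/elif branches in proxyllm_generate_stream above.
    ROLE_MAP = {"human": "user", "system": "system", "ai": "assistant"}

    messages = [  # hypothetical input
        ModelMessage(role="system", content="Answer concisely."),
        ModelMessage(role="human", content="What is DB-GPT?"),
    ]
    history: List[Dict[str, str]] = [
        {"role": ROLE_MAP[m.role], "content": m.content}
        for m in messages
        if m.role in ROLE_MAP
    ]
    print(history)
    # -> [{'role': 'system', 'content': 'Answer concisely.'},
    #     {'role': 'user', 'content': 'What is DB-GPT?'}]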
@@ -37,6 +37,7 @@ from pilot.scene.base_message import (
     HumanMessage,
     AIMessage,
     ViewMessage,
+    ModelMessage,
 )
 from pilot.configs.config import Config
 
@@ -116,6 +117,7 @@ class BaseChat(ABC):
         payload = {
             "model": self.llm_model,
             "prompt": self.generate_llm_text(),
+            "messages": self.generate_llm_messages(),
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
@@ -244,24 +246,62 @@ class BaseChat(ABC):
         text += self.__load_user_message()
         return text
 
-    def __load_system_message(self):
+    def generate_llm_messages(self) -> List[ModelMessage]:
+        """
+        Structured prompt messages interaction between dbgpt-server and llm-server
+        See https://github.com/csunny/DB-GPT/issues/328
+        """
+        messages = []
+        ### Load scene setting or character definition as system message
+        if self.prompt_template.template_define:
+            messages.append(
+                ModelMessage(
+                    role="system",
+                    content=self.prompt_template.template_define,
+                )
+            )
+        ### Load prompt
+        messages += self.__load_system_message(str_message=False)
+        ### Load examples
+        messages += self.__load_example_messages(str_message=False)
+
+        ### Load History
+        messages += self.__load_histroy_messages(str_message=False)
+
+        ### Load User Input
+        messages += self.__load_user_message(str_message=False)
+        return messages
+
+    def __load_system_message(self, str_message: bool = True):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
+        system_messages = []
         for system_conv in system_convs:
             system_text += (
                 system_conv.type + ":" + system_conv.content + self.prompt_template.sep
             )
-        return system_text
+            system_messages.append(
+                ModelMessage(role=system_conv.type, content=system_conv.content)
+            )
+        return system_text if str_message else system_messages
 
-    def __load_user_message(self):
+    def __load_user_message(self, str_message: bool = True):
         user_conv = self.current_message.get_user_conv()
+        user_messages = []
         if user_conv:
-            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+            user_text = (
+                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+            )
+            user_messages.append(
+                ModelMessage(role=user_conv.type, content=user_conv.content)
+            )
+            return user_text if str_message else user_messages
         else:
             raise ValueError("Hi! What do you want to talk about?")
 
-    def __load_example_messages(self):
+    def __load_example_messages(self, str_message: bool = True):
         example_text = ""
+        example_messages = []
         if self.prompt_template.example_selector:
             for round_conv in self.prompt_template.example_selector.examples():
                 for round_message in round_conv["messages"]:
@@ -269,16 +309,22 @@ class BaseChat(ABC):
                         SystemMessage.type,
                         ViewMessage.type,
                     ]:
+                        message_type = round_message["type"]
+                        message_content = round_message["data"]["content"]
                         example_text += (
-                            round_message["type"]
+                            message_type
                             + ":"
-                            + round_message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
-        return example_text
+                        example_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
+        return example_text if str_message else example_messages
 
-    def __load_histroy_messages(self):
+    def __load_histroy_messages(self, str_message: bool = True):
         history_text = ""
+        history_messages = []
         if self.prompt_template.need_historical_messages:
             if self.history_message:
                 logger.info(
@@ -290,12 +336,17 @@ class BaseChat(ABC):
                         ViewMessage.type,
                         SystemMessage.type,
                     ]:
+                        message_type = first_message["type"]
+                        message_content = first_message["data"]["content"]
                         history_text += (
-                            first_message["type"]
+                            message_type
                             + ":"
-                            + first_message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
 
                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
@@ -304,12 +355,17 @@ class BaseChat(ABC):
                         SystemMessage.type,
                         ViewMessage.type,
                     ]:
+                        message_type = round_message["type"]
+                        message_content = round_message["data"]["content"]
                         history_text += (
-                            round_message["type"]
+                            message_type
                             + ":"
-                            + round_message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
 
             else:
                 ### user all history
@@ -320,14 +376,19 @@ class BaseChat(ABC):
                         SystemMessage.type,
                         ViewMessage.type,
                     ]:
+                        message_type = message["type"]
+                        message_content = message["data"]["content"]
                         history_text += (
-                            message["type"]
+                            message_type
                             + ":"
-                            + message["data"]["content"]
+                            + message_content
                             + self.prompt_template.sep
                         )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
 
-        return history_text
+        return history_text if str_message else history_messages
 
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
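The loaders above all follow the same dual-return pattern: one data pass, two output shapes, selected by str_message. A small illustrative sketch of that pattern follows; the names and separator are placeholders, not the real BaseChat internals.

    from dataclasses import dataclass
    from typing import List, Union


    @dataclass
    class ModelMessage:  # stand-in for pilot.scene.base_message.ModelMessage
        role: str
        content: str


    SEP = "###"  # stands in for self.prompt_template.sep


    def load_user_message(role: str, content: str, str_message: bool = True) -> Union[str, List[ModelMessage]]:
        # Legacy path: flatten to "role:content<sep>" for the prompt string.
        user_text = role + ":" + content + SEP
        # Structured path: keep role and content separate for the llm-server.
        user_messages = [ModelMessage(role=role, content=content)]
        return user_text if str_message else user_messages


    print(load_user_message("human", "hello"))                     # "human:hello###"
    print(load_user_message("human", "hello", str_message=False))  # [ModelMessage(role='human', content='hello')]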
@@ -6,6 +6,7 @@ from typing import (
     Dict,
     Generic,
     List,
+    Tuple,
     NamedTuple,
     Optional,
     Sequence,
@@ -80,6 +81,14 @@ class SystemMessage(BaseMessage):
         return "system"
 
 
+class ModelMessage(BaseModel):
+    """Type of message that interaction between dbgpt-server and llm-server"""
+
+    """Similar to openai's message format"""
+    role: str
+    content: str
+
+
 class Generation(BaseModel):
     """Output of a single generation."""
 
@@ -146,3 +155,35 @@ def _message_from_dict(message: dict) -> BaseMessage:
 
 def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
+
+
+def _parse_model_messages(
+    messages: List[ModelMessage],
+) -> Tuple[str, List[str], List[List[str, str]]]:
+    """ "
+    Parameters:
+        messages: List of message from base chat.
+    Returns:
+        A tuple contains user prompt, system message list and history message list
+        str: user prompt
+        List[str]: system messages
+        List[List[str]]: history message of user and assistant
+    """
+    user_prompt = ""
+    system_messages: List[str] = []
+    history_messages: List[List[str]] = [[]]
+
+    for message in messages[:-1]:
+        if message.role == "human":
+            history_messages[-1].append(message.content)
+        elif message.role == "system":
+            system_messages.append(message.content)
+        elif message.role == "ai":
+            history_messages[-1].append(message.content)
+            history_messages.append([])
+    if messages[-1].role != "human":
+        raise ValueError("Hi! What do you want to talk about?")
+    # Keep message pair of [user message, assistant message]
+    history_messages = list(filter(lambda x: len(x) == 2, history_messages))
+    user_prompt = messages[-1].content
+    return user_prompt, system_messages, history_messages
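Because ModelMessage is a plain pydantic model, it serializes cleanly when the structured prompt crosses the server boundary. A quick sketch, assuming the pydantic v1-style dict()/parse_obj() API; adjust for pydantic v2 (model_dump/model_validate) if needed.

    from pydantic import BaseModel


    class ModelMessage(BaseModel):  # same shape as the class added above
        role: str
        content: str


    msg = ModelMessage(role="human", content="hello")
    print(msg.dict())  # {'role': 'human', 'content': 'hello'}
    # Round-trips from JSON-like dicts, which is what travels over HTTP:
    print(ModelMessage.parse_obj({"role": "ai", "content": "hi"}))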
@@ -5,6 +5,7 @@ import asyncio
 import json
 import os
 import sys
+from typing import List
 
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
@@ -24,6 +25,7 @@ from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
+from pilot.scene.base_message import ModelMessage
 
 CFG = Config()
 
@@ -128,6 +130,7 @@ app = FastAPI()
 
 
 class PromptRequest(BaseModel):
+    messages: List[ModelMessage]
     prompt: str
     temperature: float
     max_new_tokens: int
@@ -170,6 +173,7 @@ async def api_generate_stream(request: Request):
 @app.post("/generate")
 def generate(prompt_request: PromptRequest) -> str:
     params = {
+        "messages": prompt_request.messages,
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
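Putting it together, the dbgpt-server side now posts both the legacy prompt string and the structured messages to llm-server's /generate endpoint. A hedged sketch of such a request, with a placeholder host, model name, and contents that mirror the PromptRequest and BaseChat payload fields above:

    import requests

    payload = {
        "model": "chatglm-6b",  # illustrative model name
        "messages": [
            {"role": "system", "content": "You are a helpful SQL assistant."},
            {"role": "human", "content": "How many users signed up today?"},
        ],
        # Legacy flattened prompt, kept alongside the structured messages.
        "prompt": "system:You are a helpful SQL assistant.###human:How many users signed up today?###",
        "temperature": 0.7,
        "max_new_tokens": 1024,
        "stop": "###",
    }
    # Host and port are placeholders for wherever llmserver is running.
    resp = requests.post("http://127.0.0.1:8000/generate", json=payload, timeout=120)
    print(resp.text)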