mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-20 09:14:44 +00:00
feat: Multi-model prompt adaptation
This commit is contained in:
parent
c2df1c27ea
commit
f234b30f7a
@ -9,6 +9,7 @@ from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
|||||||
|
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
|
from pilot.prompts.prompt_registry import PromptTemplateRegistry
|
||||||
|
|
||||||
|
|
||||||
class Config(metaclass=Singleton):
|
class Config(metaclass=Singleton):
|
||||||
@ -76,7 +77,7 @@ class Config(metaclass=Singleton):
|
|||||||
)
|
)
|
||||||
self.speak_mode = False
|
self.speak_mode = False
|
||||||
|
|
||||||
self.prompt_templates = {}
|
self.prompt_template_registry = PromptTemplateRegistry()
|
||||||
### Related configuration of built-in commands
|
### Related configuration of built-in commands
|
||||||
self.command_registry = []
|
self.command_registry = []
|
||||||
|
|
||||||
|
94
pilot/prompts/prompt_registry.py
Normal file
94
pilot/prompts/prompt_registry.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Registry mapping (scene, model, language) to a prompt template.

Lets different LLMs (and languages) supply their own prompt template for the
same chat scene, with a scene-wide default used as a fallback when no
model-specific template has been registered.
"""

import logging
from collections import defaultdict
from typing import Dict, List

# Sentinel keys under which the scene-wide fallback template is stored.
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
_DEFAULT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
# Backward-compatible alias for the original (misspelled) constant name.
_DEFUALT_LANGUAGE_KEY = _DEFAULT_LANGUAGE_KEY

logger = logging.getLogger(__name__)


class PromptTemplateRegistry:
    """Manager of the prompt templates of all scenes.

    Registry dict format::

        {
            "<scene_name>": {
                _DEFAULT_MODEL_KEY: {
                    _DEFAULT_LANGUAGE_KEY: <prompt_template>,
                    "<language>": <prompt_template>,
                },
                "<model_name>": {
                    "<language>": <prompt_template>,
                },
            }
        }
    """

    def __init__(self) -> None:
        # scene_name -> model_name -> language -> prompt_template
        self.registry = defaultdict(dict)

    def register(
        self,
        prompt_template,
        language: str = "en",
        is_default: bool = False,
        model_names: List[str] = None,
    ) -> None:
        """Register *prompt_template* for its scene, models and language.

        Args:
            prompt_template: Template object; must expose a non-empty
                ``template_scene`` attribute naming the chat scene.
            language: Language code the template is written for.
            is_default: When True, also register this template as the
                scene-wide fallback (default model key, both for the given
                language and for the default-language key).
            model_names: Models this template targets; when omitted the
                template is stored under the default-model key.

        Raises:
            ValueError: If the template has an empty scene name.
        """
        scene_name = prompt_template.template_scene
        if not scene_name:
            raise ValueError("Prompt template scene name cannot be empty")
        if not model_names:
            model_names = [_DEFAULT_MODEL_KEY]
        scene_registry = self.registry[scene_name]
        _register_scene_prompt_template(
            scene_registry, prompt_template, language, model_names
        )
        if is_default:
            # Fallback for any language of the default model ...
            _register_scene_prompt_template(
                scene_registry,
                prompt_template,
                _DEFAULT_LANGUAGE_KEY,
                [_DEFAULT_MODEL_KEY],
            )
            # ... and for this specific language of the default model.
            _register_scene_prompt_template(
                scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY]
            )

    def get_prompt_template(self, scene_name: str, language: str, model_name: str):
        """Return the best-matching template for scene/model/language.

        Lookup order: exact model, then default-model fallback; within the
        chosen model entry, exact language, then default-language fallback.
        Returns None when the model entry exists but has no template for
        either the language or the default-language key.

        Raises:
            ValueError: If the scene has no entry for the model at all.
        """
        # Use .get() so probing an unknown scene does not grow the
        # defaultdict with an empty entry as a side effect.
        scene_registry = self.registry.get(scene_name, {})
        registry = scene_registry.get(model_name)
        logger.info(
            "Get prompt template of scene_name: %s with model_name: %s language: %s",
            scene_name,
            model_name,
            language,
        )
        if not registry:
            registry = scene_registry.get(_DEFAULT_MODEL_KEY)
            if not registry:
                raise ValueError(
                    f"There is no template with scene name {scene_name}, model name {model_name}, language {language}"
                )
        else:
            logger.info(
                "scene: %s has custom prompt template of model: %s, language: %s",
                scene_name,
                model_name,
                language,
            )
        prompt_template = registry.get(language)
        if not prompt_template:
            prompt_template = registry.get(_DEFAULT_LANGUAGE_KEY)
        return prompt_template


def _register_scene_prompt_template(
    scene_registry: Dict[str, Dict],
    prompt_template,
    language: str,
    model_names: List[str],
):
    """Store *prompt_template* under every model in *model_names* for *language*."""
    for model_name in model_names:
        registry = scene_registry.setdefault(model_name, {})
        registry[language] = prompt_template
|
@ -73,9 +73,14 @@ class BaseChat(ABC):
|
|||||||
self.memory = DuckdbHistoryMemory(chat_session_id)
|
self.memory = DuckdbHistoryMemory(chat_session_id)
|
||||||
|
|
||||||
### load prompt template
|
### load prompt template
|
||||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
# self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||||
self.chat_mode.value()
|
# self.chat_mode.value()
|
||||||
]
|
# ]
|
||||||
|
self.prompt_template: PromptTemplate = (
|
||||||
|
CFG.prompt_template_registry.get_prompt_template(
|
||||||
|
self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL
|
||||||
|
)
|
||||||
|
)
|
||||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
|
@ -56,4 +56,4 @@ prompt = PromptTemplate(
|
|||||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
@ -53,4 +53,4 @@ prompt = PromptTemplate(
|
|||||||
# example_selector=sql_data_example,
|
# example_selector=sql_data_example,
|
||||||
temperature=PROMPT_TEMPERATURE,
|
temperature=PROMPT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
@ -70,4 +70,4 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
|
@ -52,4 +52,4 @@ prompt = PromptTemplate(
|
|||||||
example_selector=plugin_example,
|
example_selector=plugin_example,
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
@ -49,5 +49,4 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
|
||||||
|
@ -50,5 +50,4 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
|
||||||
|
@ -49,5 +49,4 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
|
||||||
|
@ -49,5 +49,4 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
|
||||||
|
@ -50,5 +50,5 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
from . import prompt_chatglm
|
||||||
|
58
pilot/scene/chat_knowledge/v1/prompt_chatglm.py
Normal file
58
pilot/scene/chat_knowledge/v1/prompt_chatglm.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""ChatGLM-family prompt template for the knowledge chat scene.

Registers a knowledge-QA prompt (ZH/EN variant chosen by CFG.LANGUAGE)
specifically for the chatglm model family, via the prompt template registry.
"""
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle

from pilot.scene.chat_normal.out_parser import NormalChatOutputParser


CFG = Config()

# NOTE(review): PROMPT_SCENE_DEFINE is declared but not passed to
# PromptTemplate below (template_define=None) — confirm this is intentional.
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
known information:
{context}
question:
{question}
"""

# Pick the template variant matching the configured language.
_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)


PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.ChatKnowledge.value(),
    input_variables=["context", "question"],
    response_format=None,
    template_define=None,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)

# Register for the chatglm model family only (not as the scene default),
# so these models get this template while others keep the default one.
CFG.prompt_template_registry.register(
    prompt,
    language=CFG.LANGUAGE,
    is_default=False,
    model_names=["chatglm-6b-int4", "chatglm-6b", "chatglm2-6b", "chatglm2-6b-int4"],
)
|
@ -28,4 +28,5 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
# CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||||
|
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user