diff --git a/pilot/configs/config.py b/pilot/configs/config.py index ee407e6a2..5b529d447 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -9,6 +9,7 @@ from auto_gpt_plugin_template import AutoGPTPluginTemplate from pilot.singleton import Singleton from pilot.common.sql_database import Database +from pilot.prompts.prompt_registry import PromptTemplateRegistry class Config(metaclass=Singleton): @@ -76,7 +77,7 @@ class Config(metaclass=Singleton): ) self.speak_mode = False - self.prompt_templates = {} + self.prompt_template_registry = PromptTemplateRegistry() ### Related configuration of built-in commands self.command_registry = [] diff --git a/pilot/prompts/prompt_registry.py b/pilot/prompts/prompt_registry.py new file mode 100644 index 000000000..0372e92b3 --- /dev/null +++ b/pilot/prompts/prompt_registry.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from collections import defaultdict +from typing import Dict, List +import json + +_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__" +_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__" + + +class PromptTemplateRegistry: + """ + The PromptTemplateRegistry class is a manager of prompt template of all scenes. + """ + + def __init__(self) -> None: + self.registry = defaultdict(dict) + + def register( + self, + prompt_template, + language: str = "en", + is_default=False, + model_names: List[str] = None, + ) -> None: + """Register prompt template with scene name, language + registry dict format: + { + "": { + _DEFAULT_MODEL_KEY: { + _DEFUALT_LANGUAGE_KEY: , + "": + }, + "": { + "": + } + } + } + """ + scene_name = prompt_template.template_scene + if not scene_name: + raise ValueError("Prompt template scene name cannot be empty") + if not model_names: + model_names: List[str] = [_DEFAULT_MODEL_KEY] + scene_registry = self.registry[scene_name] + _register_scene_prompt_template( + scene_registry, prompt_template, language, model_names + ) + if is_default: + _register_scene_prompt_template( + scene_registry, + prompt_template, + _DEFUALT_LANGUAGE_KEY, + [_DEFAULT_MODEL_KEY], + ) + _register_scene_prompt_template( + scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY] + ) + + def get_prompt_template(self, scene_name: str, language: str, model_name: str): + """Get prompt template with scene name, language and model name""" + scene_registry = self.registry[scene_name] + registry = scene_registry.get(model_name) + print( + f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}" + ) + if not registry: + registry = scene_registry.get(_DEFAULT_MODEL_KEY) + if not registry: + raise ValueError( + f"There is no template with scene name {scene_name}, model name {model_name}, language {language}" + ) + else: + print( + f"scene: {scene_name} has custom prompt template of model: {model_name}, language: {language}" + ) + prompt_template = registry.get(language) + if not prompt_template: + prompt_template = registry.get(_DEFUALT_LANGUAGE_KEY) + return prompt_template + + +def _register_scene_prompt_template( + scene_registry: Dict[str, Dict], + prompt_template, + language: str, + model_names: List[str], +): + for model_name in model_names: + if model_name not in scene_registry: + scene_registry[model_name] = dict() + registry = scene_registry[model_name] + registry[language] = prompt_template diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index e6b9bb9f4..940b3e113 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -73,9 +73,14 @@ class BaseChat(ABC): self.memory = DuckdbHistoryMemory(chat_session_id) ### load prompt template - self.prompt_template: PromptTemplate = CFG.prompt_templates[ - self.chat_mode.value() - ] + # self.prompt_template: PromptTemplate = CFG.prompt_templates[ + # self.chat_mode.value() + # ] + self.prompt_template: PromptTemplate = ( + CFG.prompt_template_registry.get_prompt_template( + self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL + ) + ) self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value()) self.current_tokens_used: int = 0 diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 212d24e96..147794752 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -56,4 +56,4 @@ prompt = PromptTemplate( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 8ea449975..60b748224 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -53,4 +53,4 @@ prompt = PromptTemplate( # example_selector=sql_data_example, temperature=PROMPT_TEMPERATURE, ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_db/professional_qa/prompt.py b/pilot/scene/chat_db/professional_qa/prompt.py index f33b0971f..c46382b8e 100644 --- a/pilot/scene/chat_db/professional_qa/prompt.py +++ b/pilot/scene/chat_db/professional_qa/prompt.py @@ -70,4 +70,4 @@ prompt = PromptTemplate( ), ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index a4d83459f..e3c76d4c3 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -52,4 +52,4 @@ prompt = PromptTemplate( example_selector=plugin_example, ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py index 2b74add93..f3ac94115 100644 --- a/pilot/scene/chat_knowledge/custom/prompt.py +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -49,5 +49,4 @@ prompt = PromptTemplate( ), ) - -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index 5940295a6..5b0e33e62 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -50,5 +50,4 @@ prompt = PromptTemplate( ), ) - -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) diff --git a/pilot/scene/chat_knowledge/inner_db_summary/prompt.py b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py index ea2407329..6153a7b5f 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/prompt.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py @@ -49,5 +49,4 @@ prompt = PromptTemplate( ), ) - -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py index 2a65246ec..4e09f1a82 100644 --- a/pilot/scene/chat_knowledge/url/prompt.py +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -49,5 +49,4 @@ prompt = PromptTemplate( ), ) - -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) diff --git a/pilot/scene/chat_knowledge/v1/prompt.py b/pilot/scene/chat_knowledge/v1/prompt.py index 60a483b01..5dbdbf602 100644 --- a/pilot/scene/chat_knowledge/v1/prompt.py +++ b/pilot/scene/chat_knowledge/v1/prompt.py @@ -50,5 +50,5 @@ prompt = PromptTemplate( ), ) - -CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) +from . import prompt_chatglm diff --git a/pilot/scene/chat_knowledge/v1/prompt_chatglm.py b/pilot/scene/chat_knowledge/v1/prompt_chatglm.py new file mode 100644 index 000000000..44353c9ff --- /dev/null +++ b/pilot/scene/chat_knowledge/v1/prompt_chatglm.py @@ -0,0 +1,58 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. + The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ + + +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatKnowledge.value(), + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + +CFG.prompt_template_registry.register( + prompt, + language=CFG.LANGUAGE, + is_default=False, + model_names=["chatglm-6b-int4", "chatglm-6b", "chatglm2-6b", "chatglm2-6b-int4"], +) diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index 9c4f39a2d..ce0c08234 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -28,4 +28,5 @@ prompt = PromptTemplate( ), ) -CFG.prompt_templates.update({prompt.template_scene: prompt}) +# CFG.prompt_templates.update({prompt.template_scene: prompt}) +CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)