# Files
# DB-GPT/pilot/scene/chat_agent/chat.py
# 2023-11-04 18:08:28 +08:00
#
# 75 lines
# 2.7 KiB
# Python

from typing import List, Dict
import logging
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.base_modules.agent import PluginPromptGenerator
from pilot.common.string_utils import extract_content
from .prompt import prompt
from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent
from pilot.utils.tracer import root_tracer, trace
# Logger for this chat scene, registered under the fixed name "chat_agent".
logger = logging.getLogger("chat_agent")
# Shared project configuration; ChatAgent below reads CFG.command_registry
# and CFG.SYSTEM_APP from it.
CFG: Config = Config()
class ChatAgent(BaseChat):
    """Chat With Agent through plugin.

    The user's comma-separated plugin selection is loaded into a
    ``PluginPromptGenerator``. Its constraints and command list are rendered
    into the prompt inputs, and model text is handed to ``ApiCall.run`` via
    ``stream_plugin_call``.
    """

    chat_scene: str = ChatScene.ChatAgent.value()
    # NOTE(review): presumably the number of earlier rounds BaseChat keeps as
    # history; this scene sets it to 0 -- confirm in BaseChat.
    chat_retention_rounds = 0

    def __init__(self, chat_param: Dict):
        """Chat Agent Module Initialization

        Args:
           - chat_param: Dict
            - chat_session_id: (str) chat session_id
            - current_user_input: (str) current user input
            - model_name:(str) llm model name
            - select_param:(str) comma-separated agent plugin name(s)

        Note:
            ``chat_param`` is modified in place: ``chat_mode`` is set to
            ``ChatScene.ChatAgent`` before the dict is passed to ``BaseChat``.

        Raises:
            ValueError: if ``select_param`` is missing or empty, or if it
                contains no plugin name after splitting on commas and
                stripping whitespace.
        """
        # Use .get so a missing key fails with the same clear message as an
        # empty one, instead of a bare KeyError.
        select_param = chat_param.get("select_param")
        if not select_param:
            raise ValueError("Please select a Plugin!")
        # Tolerate "a, b" and stray commas: strip each entry and drop blank
        # ones, so only real plugin names reach load_select_plugin.
        self.select_plugins = [
            name.strip() for name in select_param.split(",") if name.strip()
        ]
        if not self.select_plugins:
            raise ValueError("Please select a Plugin!")
        chat_param["chat_mode"] = ChatScene.ChatAgent
        super().__init__(chat_param=chat_param)
        self.plugins_prompt_generator = PluginPromptGenerator()
        self.plugins_prompt_generator.command_registry = CFG.command_registry
        # load select plugin
        agent_module = CFG.SYSTEM_APP.get_component(
            ComponentType.AGENT_HUB, ModuleAgent
        )
        # Keep whatever generator the agent hub returns. Presumably this is
        # the same instance, now populated with the selected plugins --
        # confirm in ModuleAgent.load_select_plugin.
        self.plugins_prompt_generator = agent_module.load_select_plugin(
            self.plugins_prompt_generator, self.select_plugins
        )
        self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)

    @trace()
    async def generate_input_values(self) -> Dict[str, str]:
        """Build the values used to fill this scene's prompt.

        Returns:
            A dict with these keys:
            - ``user_goal``: the current user input.
            - ``expand_constraints``: the generator's constraints as numbered
              lines (see ``__list_to_prompt_str``).
            - ``tool_list``: the generator's command descriptions.
            Presumably these keys match placeholders in the ``.prompt``
            template -- confirm there.
        """
        input_values = {
            "user_goal": self.current_user_input,
            "expand_constraints": self.__list_to_prompt_str(
                list(self.plugins_prompt_generator.constraints)
            ),
            "tool_list": self.plugins_prompt_generator.generate_commands_string(),
        }
        return input_values

    def stream_plugin_call(self, text):
        """Execute the plugin call(s) contained in model output ``text``.

        Newlines are replaced with spaces. The text is then passed to
        ``ApiCall.run`` inside a tracing span, and ``run``'s result is
        returned unchanged.
        """
        # NOTE(review): presumably flattened to one line to simplify
        # ApiCall's parsing -- confirm in ApiCall.run.
        text = text.replace("\n", " ")
        with root_tracer.start_span(
            "ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
        ):
            return self.api_call.run(text)

    def __list_to_prompt_str(self, items: List) -> str:
        """Render ``items`` as newline-separated numbered lines, starting at 2.

        For example, ["a", "b"] renders as the two lines "2. a" and "3. b".
        An empty list renders as "".
        """
        # NOTE(review): numbering starts at 2 (originally written ``i + 1 + 1``),
        # presumably to continue after a constraint already numbered in the
        # prompt template -- confirm against .prompt before changing.
        return "\n".join(f"{n}. {item}" for n, item in enumerate(items, start=2))