mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 21:51:25 +00:00
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,16 +1,22 @@
|
||||
"""Indicator Assistant Agent."""
|
||||
|
||||
"""Indicator Agent."""
|
||||
import logging
|
||||
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from .actions.indicator_action import IndicatorAction
|
||||
from ..expand.actions.indicator_action import IndicatorAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger()
|
||||
|
||||
CHECK_RESULT_SYSTEM_MESSAGE = """
|
||||
You are an expert in analyzing the results of a summary task.
|
||||
Your responsibility is to check whether the summary results can summarize the input provided by the user, and then make a judgment. You need to answer according to the following rules:
|
||||
Rule 1: If you think the summary results can summarize the input provided by the user, only return True.
|
||||
Rule 2: If you think the summary results can NOT summarize the input provided by the user, return False and the reason, splitted by | and ended by TERMINATE. For instance: False|Some important concepts in the input are not summarized. TERMINATE
|
||||
""" # noqa
|
||||
|
||||
|
||||
class IndicatorAssistantAgent(ConversableAgent):
|
||||
"""Indicator Assistant Agent."""
|
||||
"""IndicatorAssistantAgent."""
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
@@ -24,39 +30,32 @@ class IndicatorAssistantAgent(ConversableAgent):
|
||||
key="dbgpt_agent_expand_indicator_assistant_agent_profile_role",
|
||||
),
|
||||
goal=DynConfig(
|
||||
"Summarize answer summaries based on user questions from provided "
|
||||
"resource information or from historical conversation memories.",
|
||||
"Summarize answer summaries based on user questions from provided resource information or from historical conversation memories.", # noqa
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_indicator_assistant_agent_profile_goal",
|
||||
),
|
||||
constraints=DynConfig(
|
||||
[
|
||||
"Prioritize the summary of answers to user questions from the "
|
||||
"improved resource text. If no relevant information is found, "
|
||||
"summarize it from the historical dialogue memory given. It is "
|
||||
"forbidden to make up your own.",
|
||||
"You need to first detect user's question that you need to answer "
|
||||
"with your summarization.",
|
||||
"Prioritize the summary of answers to user questions from the improved resource text. If no relevant information is found, summarize it from the historical dialogue memory given. It is forbidden to make up your own.", # noqa
|
||||
"You need to first detect user's question that you need to answer with your summarization.", # noqa
|
||||
"Extract the provided text content used for summarization.",
|
||||
"Then you need to summarize the extracted text content.",
|
||||
"Output the content of summarization ONLY related to user's question. "
|
||||
"The output language must be the same to user's question language.",
|
||||
"If you think the provided text content is not related to user "
|
||||
"questions at all, ONLY output 'Did not find the information you "
|
||||
"want.'!!.",
|
||||
"Output the content of summarization ONLY related to user's question. The output language must be the same to user's question language.", # noqa
|
||||
"""If you think the provided text content is not related to user questions at all, ONLY output "Did not find the information you want."!!.""", # noqa
|
||||
],
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_indicator_assistant_agent_profile_constraints",
|
||||
),
|
||||
desc=DynConfig(
|
||||
"You can summarize provided text content according to user's questions "
|
||||
"and output the summarization.",
|
||||
"You can summarize provided text content according to user's questions and output the summaraization.", # noqa
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_indicator_assistant_agent_profile_desc",
|
||||
),
|
||||
)
|
||||
max_retry_count: int = 3
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new instance."""
|
||||
"""Init indicator AssistantAgent."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([IndicatorAction])
|
||||
|
@@ -66,7 +66,7 @@ class ChartAction(Action[SqlInput]):
|
||||
logger.exception(f"{str(e)}! \n {ai_message}")
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
content="Error:The answer is not output in the required format.",
|
||||
)
|
||||
try:
|
||||
if not self.resource_need:
|
||||
@@ -96,5 +96,5 @@ class ChartAction(Action[SqlInput]):
|
||||
logger.exception("Check your answers, the sql run failed!")
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content=f"Check your answers, the sql run failed!Reason:{str(e)}",
|
||||
content=f"Error:Check your answers, the sql run failed!Reason:{str(e)}",
|
||||
)
|
||||
|
@@ -1,11 +1,11 @@
|
||||
"""Indicator Action."""
|
||||
|
||||
"""Indicator Agent action."""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
|
||||
from dbgpt.vis.tags.vis_api_response import VisApiResponse
|
||||
from dbgpt.vis.tags.vis_plugin import Vis
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...core.schema import Status
|
||||
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndicatorInput(BaseModel):
|
||||
"""Indicator input model."""
|
||||
"""Indicator llm out model."""
|
||||
|
||||
indicator_name: str = Field(
|
||||
...,
|
||||
@@ -31,19 +31,20 @@ class IndicatorInput(BaseModel):
|
||||
)
|
||||
args: dict = Field(
|
||||
default={"arg name1": "", "arg name2": ""},
|
||||
description="The tool selected for the current target, the parameter "
|
||||
"information required for execution",
|
||||
description="The tool selected for the current target, "
|
||||
"the parameter information required for execution",
|
||||
)
|
||||
thought: str = Field(..., description="Summary of thoughts to the user")
|
||||
display: str = Field(None, description="How to display return information")
|
||||
|
||||
|
||||
class IndicatorAction(Action[IndicatorInput]):
|
||||
"""Indicator action class."""
|
||||
"""Indicator Action."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a indicator action."""
|
||||
"""Init Indicator Action."""
|
||||
super().__init__()
|
||||
self._render_protocol = VisPlugin()
|
||||
self._render_protocol = VisApiResponse()
|
||||
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
@@ -64,8 +65,7 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
"""Return the AI output schema."""
|
||||
out_put_schema = {
|
||||
"indicator_name": "The name of a tool that can be used to answer the "
|
||||
"current question or solve the current task.",
|
||||
"indicator_name": "The name of a tool that can be used to answer the current question or solve the current task.", # noqa
|
||||
"api": "",
|
||||
"method": "",
|
||||
"args": {
|
||||
@@ -80,6 +80,10 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
Make sure the response is correct json and can be parsed by Python json.loads.
|
||||
"""
|
||||
|
||||
def build_headers(self):
|
||||
"""Build headers."""
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ai_message: str,
|
||||
@@ -93,40 +97,43 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
try:
|
||||
input_param = self._input_convert(ai_message, IndicatorInput)
|
||||
logger.info(
|
||||
f"_input_convert: {type(self).__name__} ai_message: {ai_message}"
|
||||
)
|
||||
param: IndicatorInput = self._input_convert(ai_message, IndicatorInput)
|
||||
except Exception as e:
|
||||
logger.exception((str(e)))
|
||||
logger.exception(str(e))
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
)
|
||||
if isinstance(input_param, list):
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
)
|
||||
param: IndicatorInput = input_param
|
||||
response_success = True
|
||||
result: Optional[str] = None
|
||||
|
||||
try:
|
||||
status = Status.COMPLETE.value
|
||||
status = Status.RUNNING.value
|
||||
response_success = True
|
||||
response_text = ""
|
||||
err_msg = None
|
||||
try:
|
||||
status = Status.RUNNING.value
|
||||
if param.method.lower() == "get":
|
||||
response = requests.get(param.api, params=param.args)
|
||||
response = requests.get(
|
||||
param.api, params=param.args, headers=self.build_headers()
|
||||
)
|
||||
elif param.method.lower() == "post":
|
||||
response = requests.post(param.api, data=param.args)
|
||||
response = requests.post(
|
||||
param.api, json=param.args, headers=self.build_headers()
|
||||
)
|
||||
else:
|
||||
response = requests.request(
|
||||
param.method.lower(), param.api, data=param.args
|
||||
param.method.lower(),
|
||||
param.api,
|
||||
data=param.args,
|
||||
headers=self.build_headers(),
|
||||
)
|
||||
# Raise an HTTPError if the HTTP request returned an unsuccessful
|
||||
# status code
|
||||
response.raise_for_status()
|
||||
result = response.text
|
||||
response_text = response.text
|
||||
logger.info(f"API:{param.api}\nResult:{response_text}")
|
||||
response.raise_for_status() # 如果请求返回一个错误状态码,则抛出HTTPError异常
|
||||
status = Status.COMPLETE.value
|
||||
except HTTPError as http_err:
|
||||
response_success = False
|
||||
print(f"HTTP error occurred: {http_err}")
|
||||
except Exception as e:
|
||||
response_success = False
|
||||
@@ -139,16 +146,18 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
"args": param.args,
|
||||
"status": status,
|
||||
"logo": None,
|
||||
"result": result,
|
||||
"result": response_text,
|
||||
"err_msg": err_msg,
|
||||
}
|
||||
|
||||
if not self.render_protocol:
|
||||
raise NotImplementedError("The render_protocol should be implemented.")
|
||||
view = await self.render_protocol.display(content=plugin_param)
|
||||
view = (
|
||||
await self.render_protocol.display(content=plugin_param)
|
||||
if self.render_protocol
|
||||
else response_text
|
||||
)
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=response_success, content=result, view=view
|
||||
is_exe_success=response_success, content=response_text, view=view
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Indicator Action Run Failed!")
|
||||
|
@@ -126,10 +126,8 @@ class CodeAssistantAgent(ConversableAgent):
|
||||
"""Verify whether the current execution results meet the target expectations."""
|
||||
task_goal = message.current_goal
|
||||
action_report = message.action_report
|
||||
task_result = ""
|
||||
if action_report:
|
||||
task_result = action_report.get("content", "")
|
||||
|
||||
if not action_report:
|
||||
return False, "No execution solution results were checked"
|
||||
check_result, model = await self.thinking(
|
||||
messages=[
|
||||
AgentMessage(
|
||||
@@ -137,7 +135,7 @@ class CodeAssistantAgent(ConversableAgent):
|
||||
content="Please understand the following task objectives and "
|
||||
f"results and give your judgment:\n"
|
||||
f"Task goal: {task_goal}\n"
|
||||
f"Execution Result: {task_result}",
|
||||
f"Execution Result: {action_report.content}",
|
||||
)
|
||||
],
|
||||
prompt=CHECK_RESULT_SYSTEM_MESSAGE,
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Dashboard Assistant Agent."""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
@@ -58,8 +58,12 @@ class DashboardAssistantAgent(ConversableAgent):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([DashboardAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
def _init_reply_message(
|
||||
self,
|
||||
received_message: AgentMessage,
|
||||
rely_messages: Optional[List[AgentMessage]] = None,
|
||||
) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message, rely_messages)
|
||||
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
|
||||
|
@@ -2,9 +2,8 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..core.action.base import ActionOutput
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
@@ -29,28 +28,34 @@ class DataScientistAgent(ConversableAgent):
|
||||
key="dbgpt_agent_expand_dashboard_assistant_agent_profile_role",
|
||||
),
|
||||
goal=DynConfig(
|
||||
"Use correct {{ dialect }} SQL to analyze and solve tasks based on the data"
|
||||
" structure information of the database given in the resource.",
|
||||
"Use correct {{dialect}} SQL to analyze and resolve user "
|
||||
"input targets based on the data structure information of the "
|
||||
"database given in the resource.",
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_dashboard_assistant_agent_profile_goal",
|
||||
),
|
||||
constraints=DynConfig(
|
||||
[
|
||||
"Please check the generated SQL carefully. Please strictly abide by "
|
||||
"the data structure definition given. It is prohibited to use "
|
||||
"non-existent fields and data values. Do not use fields from table A "
|
||||
"to table B. You can perform multi-table related queries.",
|
||||
"Please ensure that the output is in the required format. "
|
||||
"Please ensure that each analysis only outputs one analysis "
|
||||
"result SQL, including as much analysis target content as possible.",
|
||||
"If there is a recent message record, pay attention to refer to "
|
||||
"the answers and execution results inside when analyzing, "
|
||||
"and do not generate the same wrong answer.Please check carefully "
|
||||
"to make sure the correct SQL is generated. Please strictly adhere "
|
||||
"to the data structure definition given. The use of non-existing "
|
||||
"fields is prohibited. Be careful not to confuse fields from "
|
||||
"different tables, and you can perform multi-table related queries.",
|
||||
"If the data and fields that need to be analyzed in the target are in "
|
||||
"different tables, it is recommended to use multi-table correlation "
|
||||
"queries first, and pay attention to the correlation between multiple "
|
||||
"table structures.",
|
||||
"It is forbidden to construct data by yourself as a query condition. "
|
||||
"If you want to query a specific field, if the value of the field is "
|
||||
"provided, then you can perform a group statistical query on the "
|
||||
"field.",
|
||||
"It is prohibited to construct data yourself as query conditions. "
|
||||
"Only the data values given by the famous songs in the input can "
|
||||
"be used as query conditions.",
|
||||
"Please select an appropriate one from the supported display methods "
|
||||
"for data display. If no suitable display type is found, "
|
||||
"table display is used by default. Supported display types: \n"
|
||||
"use 'response_table' as default value. Supported display types: \n"
|
||||
"{{ display_type }}",
|
||||
],
|
||||
category="agent",
|
||||
@@ -65,14 +70,19 @@ class DataScientistAgent(ConversableAgent):
|
||||
)
|
||||
|
||||
max_retry_count: int = 5
|
||||
language: str = "zh"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new DataScientistAgent instance."""
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([ChartAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
def _init_reply_message(
|
||||
self,
|
||||
received_message: AgentMessage,
|
||||
rely_messages: Optional[List[AgentMessage]] = None,
|
||||
) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message, rely_messages)
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": self.database.dialect,
|
||||
@@ -93,13 +103,13 @@ class DataScientistAgent(ConversableAgent):
|
||||
self, message: AgentMessage
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""Verify whether the current execution results meet the target expectations."""
|
||||
action_reply = message.action_report
|
||||
if action_reply is None:
|
||||
action_out = message.action_report
|
||||
if action_out is None:
|
||||
return (
|
||||
False,
|
||||
f"No executable analysis SQL is generated,{message.content}.",
|
||||
)
|
||||
action_out = cast(ActionOutput, ActionOutput.from_dict(action_reply))
|
||||
|
||||
if not action_out.is_exe_success:
|
||||
return (
|
||||
False,
|
||||
|
@@ -38,7 +38,6 @@ def baidu_search(
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:112.0) "
|
||||
"Gecko/20100101 Firefox/112.0"
|
||||
}
|
||||
num_results = int(num_results)
|
||||
if num_results < 8:
|
||||
num_results = 8
|
||||
url = f"https://www.baidu.com/s?wd={query}&rn={num_results}"
|
||||
|
@@ -1,595 +0,0 @@
|
||||
"""Retrieve Summary Assistant Agent."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
from dbgpt.core import ModelMessageRoleType
|
||||
|
||||
from ..core.action.base import Action, ActionOutput
|
||||
from ..core.agent import Agent, AgentMessage, AgentReviewInfo
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import ProfileConfig
|
||||
from ..resource.base import AgentResource
|
||||
from ..util.cmp import cmp_string_equal
|
||||
|
||||
try:
|
||||
from unstructured.partition.auto import partition
|
||||
|
||||
HAS_UNSTRUCTURED = True
|
||||
except ImportError:
|
||||
HAS_UNSTRUCTURED = False
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
TEXT_FORMATS = [
|
||||
"txt",
|
||||
"json",
|
||||
"csv",
|
||||
"tsv",
|
||||
"md",
|
||||
"html",
|
||||
"htm",
|
||||
"rtf",
|
||||
"rst",
|
||||
"jsonl",
|
||||
"log",
|
||||
"xml",
|
||||
"yaml",
|
||||
"yml",
|
||||
"pdf",
|
||||
]
|
||||
UNSTRUCTURED_FORMATS = [
|
||||
"doc",
|
||||
"docx",
|
||||
"epub",
|
||||
"msg",
|
||||
"odt",
|
||||
"org",
|
||||
"pdf",
|
||||
"ppt",
|
||||
"pptx",
|
||||
"rtf",
|
||||
"rst",
|
||||
"xlsx",
|
||||
] # These formats will be parsed by the 'unstructured' library, if installed.
|
||||
if HAS_UNSTRUCTURED:
|
||||
TEXT_FORMATS += UNSTRUCTURED_FORMATS
|
||||
TEXT_FORMATS = list(set(TEXT_FORMATS))
|
||||
|
||||
VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
|
||||
|
||||
|
||||
def _get_max_tokens(model="gpt-3.5-turbo"):
|
||||
"""Get the maximum number of tokens for a given model."""
|
||||
if "32k" in model:
|
||||
return 32000
|
||||
elif "16k" in model:
|
||||
return 16000
|
||||
elif "gpt-4" in model:
|
||||
return 8000
|
||||
else:
|
||||
return 4000
|
||||
|
||||
|
||||
_NO_RESPONSE = "NO RELATIONSHIP.UPDATE TEXT CONTENT."
|
||||
|
||||
|
||||
class RetrieveSummaryAssistantAgent(ConversableAgent):
|
||||
"""Assistant agent, designed to solve a task with LLM.
|
||||
|
||||
AssistantAgent is a subclass of ConversableAgent configured with a default
|
||||
system message.
|
||||
The default system message is designed to solve a task with LLM,
|
||||
including suggesting python code blocks and debugging.
|
||||
"""
|
||||
|
||||
PROMPT_QA: str = (
|
||||
"You are a great summary writer to summarize the provided text content "
|
||||
"according to user questions.\n"
|
||||
"User's Question is: {input_question}\n\n"
|
||||
"Provided text content is: {input_context}\n\n"
|
||||
"Please complete this task step by step following instructions below:\n"
|
||||
" 1. You need to first detect user's question that you need to answer with "
|
||||
"your summarization.\n"
|
||||
" 2. Then you need to summarize the provided text content that ONLY CAN "
|
||||
"ANSWER user's question and filter useless information as possible as you can. "
|
||||
"YOU CAN ONLY USE THE PROVIDED TEXT CONTENT!! DO NOT CREATE ANY SUMMARIZATION "
|
||||
"WITH YOUR OWN KNOWLEDGE!!!\n"
|
||||
" 3. Output the content of summarization that ONLY CAN ANSWER user's question"
|
||||
" and filter useless information as possible as you can. The output language "
|
||||
"must be the same to user's question language!! You must give as short an "
|
||||
"summarization as possible!!! DO NOT CREATE ANY SUMMARIZATION WITH YOUR OWN "
|
||||
"KNOWLEDGE!!!\n\n"
|
||||
"####Important Notice####\n"
|
||||
"If the provided text content CAN NOT ANSWER user's question, ONLY output "
|
||||
"'NO RELATIONSHIP.UPDATE TEXT CONTENT.'!!."
|
||||
)
|
||||
CHECK_RESULT_SYSTEM_MESSAGE: str = (
|
||||
"You are an expert in analyzing the results of a summary task."
|
||||
"Your responsibility is to check whether the summary results can summarize the "
|
||||
"input provided by the user, and then make a judgment. You need to answer "
|
||||
"according to the following rules:\n"
|
||||
" Rule 1: If you think the summary results can summarize the input provided"
|
||||
" by the user, only return True.\n"
|
||||
" Rule 2: If you think the summary results can NOT summarize the input "
|
||||
"provided by the user, return False and the reason, split by | and ended "
|
||||
"by TERMINATE. For instance: False|Some important concepts in the input are "
|
||||
"not summarized. TERMINATE"
|
||||
)
|
||||
|
||||
DEFAULT_DESCRIBE: str = (
|
||||
"Summarize provided content according to user's questions and "
|
||||
"the provided file paths."
|
||||
)
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name="RetrieveSummarizer",
|
||||
role="Assistant",
|
||||
goal="You're an extraction expert. You need to extract Please complete this "
|
||||
"task step by step following instructions below:\n"
|
||||
" 1. You need to first ONLY extract user's question that you need to answer "
|
||||
"without ANY file paths and URLs. \n"
|
||||
" 2. Extract the provided file paths and URLs.\n"
|
||||
" 3. Construct the extracted file paths and URLs as a list of strings.\n"
|
||||
" 4. ONLY output the extracted results with the following json format: "
|
||||
"{{ response }}.",
|
||||
desc=DEFAULT_DESCRIBE,
|
||||
)
|
||||
|
||||
chunk_token_size: int = 4000
|
||||
chunk_mode: str = "multi_lines"
|
||||
|
||||
_model: str = "gpt-3.5-turbo-16k"
|
||||
_max_tokens: int = _get_max_tokens(_model)
|
||||
context_max_tokens: int = int(_max_tokens * 0.8)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new instance of the agent."""
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
self._init_actions([SummaryAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
json_data = {"user_question": "user's question", "file_list": "file&URL list"}
|
||||
reply_message.context = {"response": json.dumps(json_data, ensure_ascii=False)}
|
||||
return reply_message
|
||||
|
||||
async def generate_reply(
|
||||
self,
|
||||
received_message: AgentMessage,
|
||||
sender: Agent,
|
||||
reviewer: Optional[Agent] = None,
|
||||
rely_messages: Optional[List[AgentMessage]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate a reply based on the received messages."""
|
||||
reply_message: AgentMessage = self._init_reply_message(
|
||||
received_message=received_message
|
||||
)
|
||||
# 1.Think about how to do things
|
||||
llm_reply, model_name = await self.thinking(
|
||||
await self._load_thinking_messages(
|
||||
received_message,
|
||||
sender,
|
||||
rely_messages,
|
||||
context=reply_message.get_dict_context(),
|
||||
)
|
||||
)
|
||||
|
||||
if not llm_reply:
|
||||
raise ValueError("No reply from LLM.")
|
||||
ai_reply_dic = json.loads(llm_reply)
|
||||
user_question = ai_reply_dic["user_question"]
|
||||
file_list = ai_reply_dic["file_list"]
|
||||
|
||||
# 2. Split files and URLs in the file list dictionary into chunks
|
||||
extracted_files = self._get_files_from_dir(file_list)
|
||||
chunks = await self._split_files_to_chunks(files=extracted_files)
|
||||
|
||||
summaries = ""
|
||||
for count, chunk in enumerate(chunks[:]):
|
||||
print(count)
|
||||
temp_sys_message = self.PROMPT_QA.format(
|
||||
input_question=user_question, input_context=chunk
|
||||
)
|
||||
chunk_ai_reply, model = await self.thinking(
|
||||
messages=[
|
||||
AgentMessage(role=ModelMessageRoleType.HUMAN, content=user_question)
|
||||
],
|
||||
prompt=temp_sys_message,
|
||||
)
|
||||
if chunk_ai_reply and not cmp_string_equal(
|
||||
_NO_RESPONSE, chunk_ai_reply, True, True, True
|
||||
):
|
||||
summaries += f"{chunk_ai_reply}\n"
|
||||
|
||||
temp_sys_message = self.PROMPT_QA.format(
|
||||
input_question=user_question, input_context=summaries
|
||||
)
|
||||
|
||||
final_summary_ai_reply, model = await self.thinking(
|
||||
messages=[
|
||||
AgentMessage(role=ModelMessageRoleType.HUMAN, content=user_question)
|
||||
],
|
||||
prompt=temp_sys_message,
|
||||
)
|
||||
reply_message.model_name = model
|
||||
reply_message.content = final_summary_ai_reply
|
||||
|
||||
print("HERE IS THE FINAL SUMMARY!!!!!")
|
||||
print(final_summary_ai_reply)
|
||||
|
||||
approve = True
|
||||
comments = None
|
||||
if reviewer and final_summary_ai_reply:
|
||||
approve, comments = await reviewer.review(final_summary_ai_reply, self)
|
||||
|
||||
reply_message.review_info = AgentReviewInfo(
|
||||
approve=approve,
|
||||
comments=comments,
|
||||
)
|
||||
if approve:
|
||||
# 3.Act based on the results of your thinking
|
||||
act_extent_param = self.prepare_act_param()
|
||||
act_out: Optional[ActionOutput] = await self.act(
|
||||
message=final_summary_ai_reply,
|
||||
sender=sender,
|
||||
reviewer=reviewer,
|
||||
**act_extent_param,
|
||||
)
|
||||
if act_out:
|
||||
reply_message.action_report = act_out.to_dict()
|
||||
# 4.Reply information verification
|
||||
check_pass, reason = await self.verify(reply_message, sender, reviewer)
|
||||
is_success = check_pass
|
||||
# 5.Optimize wrong answers myself
|
||||
if not check_pass:
|
||||
reply_message.content = reason
|
||||
reply_message.success = is_success
|
||||
return reply_message
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""Verify the correctness of the results."""
|
||||
action_report = message.action_report
|
||||
task_result = ""
|
||||
if action_report:
|
||||
task_result = action_report.get("content", "")
|
||||
|
||||
check_result, model = await self.thinking(
|
||||
messages=[
|
||||
AgentMessage(
|
||||
role=ModelMessageRoleType.HUMAN,
|
||||
content=(
|
||||
"Please understand the following user input and summary results"
|
||||
" and give your judgment:\n"
|
||||
f"User Input: {message.current_goal}\n"
|
||||
f"Summary Results: {task_result}"
|
||||
),
|
||||
)
|
||||
],
|
||||
prompt=self.CHECK_RESULT_SYSTEM_MESSAGE,
|
||||
)
|
||||
fail_reason = ""
|
||||
if check_result and (
|
||||
"true" in check_result.lower() or "yes" in check_result.lower()
|
||||
):
|
||||
success = True
|
||||
elif not check_result:
|
||||
success = False
|
||||
fail_reason = (
|
||||
"The summary results cannot summarize the user input. "
|
||||
"Please re-understand and complete the summary task."
|
||||
)
|
||||
else:
|
||||
success = False
|
||||
try:
|
||||
_, fail_reason = check_result.split("|")
|
||||
fail_reason = (
|
||||
"The summary results cannot summarize the user input due"
|
||||
f" to: {fail_reason}. Please re-understand and complete the summary"
|
||||
" task."
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"The model thought the results are irrelevant but did not give the"
|
||||
" correct format of results."
|
||||
)
|
||||
fail_reason = (
|
||||
"The summary results cannot summarize the user input. "
|
||||
"Please re-understand and complete the summary task."
|
||||
)
|
||||
return success, fail_reason
|
||||
|
||||
def _get_files_from_dir(
|
||||
self,
|
||||
dir_path: Union[str, List[str]],
|
||||
types: list = TEXT_FORMATS,
|
||||
recursive: bool = True,
|
||||
):
|
||||
"""Return a list of all the files in a given directory.
|
||||
|
||||
A url, a file path or a list of them.
|
||||
"""
|
||||
if len(types) == 0:
|
||||
raise ValueError("types cannot be empty.")
|
||||
types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
|
||||
types += [t.upper() for t in types]
|
||||
|
||||
files = []
|
||||
# If the path is a list of files or urls, process and return them
|
||||
if isinstance(dir_path, list):
|
||||
for item in dir_path:
|
||||
if os.path.isfile(item):
|
||||
files.append(item)
|
||||
elif self._is_url(item):
|
||||
files.append(self._get_file_from_url(item))
|
||||
elif os.path.exists(item):
|
||||
try:
|
||||
files.extend(self._get_files_from_dir(item, types, recursive))
|
||||
except ValueError:
|
||||
logger.warning(f"Directory {item} does not exist. Skipping.")
|
||||
else:
|
||||
logger.warning(f"File {item} does not exist. Skipping.")
|
||||
return files
|
||||
|
||||
# If the path is a file, return it
|
||||
if os.path.isfile(dir_path):
|
||||
return [dir_path]
|
||||
|
||||
# If the path is a url, download it and return the downloaded file
|
||||
if self._is_url(dir_path):
|
||||
return [self._get_file_from_url(dir_path)]
|
||||
|
||||
if os.path.exists(dir_path):
|
||||
for type in types:
|
||||
if recursive:
|
||||
files += glob.glob(
|
||||
os.path.join(dir_path, f"**/*.{type}"), recursive=True
|
||||
)
|
||||
else:
|
||||
files += glob.glob(
|
||||
os.path.join(dir_path, f"*.{type}"), recursive=False
|
||||
)
|
||||
else:
|
||||
logger.error(f"Directory {dir_path} does not exist.")
|
||||
raise ValueError(f"Directory {dir_path} does not exist.")
|
||||
return files
|
||||
|
||||
def _get_file_from_url(self, url: str, save_path: Optional[str] = None):
|
||||
"""Download a file from a URL."""
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if save_path is None:
|
||||
target_directory = os.path.join(PILOT_PATH, "data")
|
||||
os.makedirs(target_directory, exist_ok=True)
|
||||
save_path = os.path.join(target_directory, os.path.basename(url))
|
||||
else:
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
proxies: Dict[str, Any] = {}
|
||||
if os.getenv("http_proxy"):
|
||||
proxies["http"] = os.getenv("http_proxy")
|
||||
if os.getenv("https_proxy"):
|
||||
proxies["https"] = os.getenv("https_proxy")
|
||||
with requests.get(url, proxies=proxies, timeout=10, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(save_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
with open(save_path, "r", encoding="utf-8") as file:
|
||||
html_content = file.read()
|
||||
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
|
||||
# 可以根据需要从Beautiful Soup对象中提取数据,例如:
|
||||
# title = soup.title.string # 获取网页标题
|
||||
paragraphs = soup.find_all("p") # 获取所有段落文本
|
||||
|
||||
# 将解析后的内容重新写入到相同的save_path
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for paragraph in paragraphs:
|
||||
f.write(paragraph.get_text() + "\n") # 获取段落文本并写入文件
|
||||
|
||||
return save_path
|
||||
|
||||
def _is_url(self, string: str):
    """Return True if the string is a valid URL (has both scheme and netloc)."""
    try:
        parts = urlparse(string)
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. bad IPv6 brackets).
        return False
    return bool(parts.scheme) and bool(parts.netloc)
|
||||
|
||||
async def _split_text_to_chunks(
    self,
    text: str,
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
):
    """Split a long text into chunks of max_tokens.

    Args:
        text: The text to split.
        chunk_mode: One of VALID_CHUNK_MODES; "one_line" aims at roughly one
            line per chunk, otherwise multiple lines are packed per chunk.
        must_break_at_empty_line: When True, a chunk may only end at a blank
            line (forced to False in "one_line" mode).

    Returns:
        List of text chunks; pieces of 10 characters or fewer are dropped.
    """
    # Per-chunk token budget configured on the instance.
    max_tokens = self.chunk_token_size
    if chunk_mode not in VALID_CHUNK_MODES:
        raise AssertionError
    if chunk_mode == "one_line":
        # Line-per-chunk mode never needs blank-line boundaries.
        must_break_at_empty_line = False
    chunks = []
    lines = text.split("\n")
    # Pre-compute each line's token count once; kept in lockstep with `lines`.
    lines_tokens = [await self._count_token(line) for line in lines]
    sum_tokens = sum(lines_tokens)
    while sum_tokens > max_tokens:
        # Estimate how many leading lines might fit into one chunk.
        if chunk_mode == "one_line":
            estimated_line_cut = 2
        else:
            estimated_line_cut = int(max_tokens / sum_tokens * len(lines)) + 1
        cnt = 0
        prev = ""
        # Walk backwards from the estimate to find the largest prefix that
        # fits the budget (and, if required, ends at an empty line).
        for cnt in reversed(range(estimated_line_cut)):
            if must_break_at_empty_line and lines[cnt].strip() != "":
                continue
            if sum(lines_tokens[:cnt]) <= max_tokens:
                prev = "\n".join(lines[:cnt])
                break
        if cnt == 0:
            # No prefix fits: the first line alone exceeds the budget.
            logger.warning(
                f"max_tokens is too small to fit a single line of text. Breaking "
                f"this line:\n\t{lines[0][:100]} ..."
            )
            if not must_break_at_empty_line:
                # Cut the oversized line roughly at the token budget; the 0.9
                # factor assumes tokens are spread evenly over characters.
                # NOTE(review): if split_len computes to 0 the loop may not
                # make progress — confirm inputs cannot trigger this.
                split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
                prev = lines[0][:split_len]
                lines[0] = lines[0][split_len:]
                lines_tokens[0] = await self._count_token(lines[0])
            else:
                # Relax the blank-line constraint and retry on the next pass.
                logger.warning(
                    "Failed to split docs with must_break_at_empty_line being True,"
                    " set to False."
                )
                must_break_at_empty_line = False
        (
            chunks.append(prev) if len(prev) > 10 else None
        )  # don't add chunks less than 10 characters
        # Drop the consumed prefix and continue with the remainder.
        lines = lines[cnt:]
        lines_tokens = lines_tokens[cnt:]
        sum_tokens = sum(lines_tokens)
    # Whatever remains fits the budget and becomes the final chunk.
    text_to_chunk = "\n".join(lines)
    (
        chunks.append(text_to_chunk) if len(text_to_chunk) > 10 else None
    )  # don't add chunks less than 10 characters
    return chunks
|
||||
|
||||
def _extract_text_from_pdf(self, file: str) -> str:
    """Extract plain text from a PDF file.

    Args:
        file: Path to the PDF file.

    Returns:
        str: Concatenated text of all pages; empty when the PDF is encrypted
            and cannot be decrypted with an empty password.
    """
    text = ""
    import pypdf

    with open(file, "rb") as f:
        reader = pypdf.PdfReader(f)
        if reader.is_encrypted:  # Check if the PDF is encrypted
            # Try an empty password; some PDFs are only permission-protected.
            try:
                reader.decrypt("")
            except pypdf.errors.FileNotDecryptedError as e:
                logger.warning(f"Could not decrypt PDF {file}, {e}")
                return text  # Return empty text if PDF could not be decrypted

        # Iterate pages directly instead of indexing by page number.
        for page in reader.pages:
            text += page.extract_text()

    if not text.strip():
        # Fix: the original warning claimed a decryption failure here, but
        # this branch means extraction produced no text (e.g. scanned PDF).
        logger.warning(f"No text could be extracted from PDF {file}")

    return text
|
||||
|
||||
async def _split_files_to_chunks(
    self,
    files: list,
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    custom_text_split_function: Optional[Callable] = None,
):
    """Read each file in *files* and split its text into token-bounded chunks.

    Files yielding no text are skipped with a warning. Splitting is delegated
    to ``custom_text_split_function`` when given, else to
    ``_split_text_to_chunks``.
    """
    chunks = []

    for file_path in files:
        extension = os.path.splitext(file_path)[1].lower()

        # Choose an extraction strategy from the file extension.
        if HAS_UNSTRUCTURED and extension[1:] in UNSTRUCTURED_FORMATS:
            elements = partition(file_path)
            text = "\n".join([el.text for el in elements]) if len(elements) > 0 else ""
        elif extension == ".pdf":
            text = self._extract_text_from_pdf(file_path)
        else:
            # Treat everything else as plain text; drop undecodable bytes.
            with open(file_path, "r", encoding="utf-8", errors="ignore") as fh:
                text = fh.read()

        if not text.strip():
            # Nothing extractable — skip this file.
            logger.warning(f"No text available in file: {file_path}")
            continue

        if custom_text_split_function is not None:
            chunks += custom_text_split_function(text)
        else:
            chunks += await self._split_text_to_chunks(
                text, chunk_mode, must_break_at_empty_line
            )

    return chunks
|
||||
|
||||
async def _count_token(
    self, input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613"
) -> int:
    """Count number of tokens used by an OpenAI model.

    Args:
        input: (str, list, dict): Input to the model. NOTE(review): the
            annotation admits a dict, but a dict is rejected with ValueError
            below — confirm whether dict support was ever intended.
        model: (str): Model name.

    Returns:
        int: Number of tokens from the input.

    Raises:
        ValueError: If input is neither a str nor a list.
    """
    llm_client = self.not_null_llm_client
    if isinstance(input, list):
        total = 0
        for item in input:
            total += await llm_client.count_token(model, item)
        return total
    if isinstance(input, str):
        return await llm_client.count_token(model, input)
    raise ValueError("input must be str or list")
|
||||
|
||||
|
||||
class SummaryAction(Action[None]):
    """Simple Summary Action.

    Validates a summarization answer from the model: rejects a missing
    answer or an explicit "NO RELATIONSHIP." marker, otherwise passes the
    answer through unchanged.
    """

    def __init__(self):
        """Create a new instance of the action."""
        super().__init__()

    async def run(
        self,
        ai_message: str,
        resource: Optional[AgentResource] = None,
        rely_action_out: Optional[ActionOutput] = None,
        need_vis_render: bool = True,
        **kwargs,
    ) -> ActionOutput:
        """Perform the action.

        Args:
            ai_message: The model's summarization answer (may be None).
            resource: Optional bound resource (unused here).
            rely_action_out: Output of a prior action this one relies on (unused).
            need_vis_render: Whether a visualization rendering is wanted (unused).

        Returns:
            ActionOutput: success with the summary as content/view, or a
                failure whose content explains the reason.
        """
        fail_reason = None
        response_success = True
        view = None
        content = None
        if ai_message is None:
            # Answer failed, turn on automatic repair.
            # Bug fix: fail_reason was None here, so the original
            # `fail_reason += ...` raised TypeError instead of reporting.
            fail_reason = "Nothing is summarized, please check your input."
            response_success = False
        else:
            try:
                if "NO RELATIONSHIP." in ai_message:
                    # Model declared the text unrelated to the question.
                    fail_reason = (
                        "Return summarization error, the provided text "
                        "content has no relationship to user's question. TERMINATE."
                    )
                    response_success = False
                else:
                    content = ai_message
                    view = content
            except Exception as e:
                fail_reason = f"Return summarization error, {str(e)}"
                response_success = False

        if not response_success:
            content = fail_reason
        return ActionOutput(is_exe_success=response_success, content=content, view=view)
|
136
dbgpt/agent/expand/simple_assistant_agent.py
Normal file
136
dbgpt/agent/expand/simple_assistant_agent.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Simple Assistant Agent."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt.rag.retriever.rerank import RetrieverNameRanker
|
||||
|
||||
from .. import AgentMessage
|
||||
from ..core.action.blank_action import BlankAction
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleAssistantAgent(ConversableAgent):
    """Simple Assistant Agent.

    A universal question-answering agent: it renders any bound resources
    into the prompt, reranks retrieved candidates, and answers via
    ``BlankAction``.
    """

    # Static agent profile; each field is a DynConfig so deployments can
    # override the default value through configuration by key.
    profile: ProfileConfig = ProfileConfig(
        name=DynConfig(
            "Tom",
            category="agent",
            key="dbgpt_agent_expand_simple_assistant_agent_profile_name",
        ),
        role=DynConfig(
            "AI Assistant",
            category="agent",
            key="dbgpt_agent_expand_simple_assistant_agent_profile_role",
        ),
        goal=DynConfig(
            "Understand user questions and give professional answer",
            category="agent",
            key="dbgpt_agent_expand_simple_assistant_agent_profile_goal",
        ),
        constraints=DynConfig(
            [
                "Please make sure your answer is clear, logical, "
                "friendly, and human-readable."
            ],
            category="agent",
            key="dbgpt_agent_expand_simple_assistant_agent_profile_constraints",
        ),
        desc=DynConfig(
            "I am a universal simple AI assistant.",
            category="agent",
            # NOTE(review): this key says "summary_assistant" — likely a
            # copy-paste from SummaryAssistantAgent; confirm before renaming,
            # since existing configuration may reference this key.
            key="dbgpt_agent_expand_summary_assistant_agent_profile_desc",
        ),
    )

    def __init__(self, **kwargs):
        """Create a new SimpleAssistantAgent instance."""
        super().__init__(**kwargs)
        # Rerankers applied by post_filters() after resource retrieval.
        # NOTE(review): assumes RetrieverNameRanker(5) keeps the top 5
        # candidates — confirm against its implementation.
        self._post_reranks = [RetrieverNameRanker(5)]
        # BlankAction is the agent's only action.
        self._init_actions([BlankAction])

    async def load_resource(self, question: str, is_retry_chat: bool = False):
        """Load agent bind resource.

        For a resource pack, retrieves candidates from each sub-resource,
        applies post filters, and renders one prompt section per resource.
        For a single resource, delegates to the resource's own get_prompt.

        Returns:
            Tuple of (prompt text or None, reference info or None).
        """
        if self.resource:
            if self.resource.is_pack:
                sub_resources = self.resource.sub_resources
                candidates_results: List = []
                resource_candidates_map = {}
                info_map = {}
                prompt_list = []
                # Gather retrieval candidates from every sub-resource.
                for resource in sub_resources:
                    (
                        candidates,
                        prompt_template,
                        resource_reference,
                    ) = await resource.get_resources(question=question)
                    # Map stores (candidates, references, template) per name.
                    resource_candidates_map[resource.name] = (
                        candidates,
                        resource_reference,
                        prompt_template,
                    )
                    candidates_results.extend(candidates)  # type: ignore # noqa
                # Rerank/filter, then render each resource's prompt section.
                new_candidates_map = self.post_filters(resource_candidates_map)
                for resource, (
                    candidates,
                    references,
                    prompt_template,
                ) in new_candidates_map.items():
                    content = "\n".join(
                        [
                            f"--{i}--:" + chunk.content
                            for i, chunk in enumerate(candidates)  # type: ignore # noqa
                        ]
                    )
                    prompt_list.append(
                        prompt_template.format(name=resource, content=content)
                    )
                    info_map.update(references)
                return "\n".join(prompt_list), info_map
            else:
                # Single resource: the resource renders its own prompt.
                resource_prompt, resource_reference = await self.resource.get_prompt(
                    lang=self.language, question=question
                )
                return resource_prompt, resource_reference
        return None, None

    def _init_reply_message(
        self,
        received_message: AgentMessage,
        rely_messages: Optional[List[AgentMessage]] = None,
    ) -> AgentMessage:
        """Build the reply message, exposing the user's question as context."""
        reply_message = super()._init_reply_message(received_message, rely_messages)
        # Make the original question available to prompt templates.
        reply_message.context = {
            "user_question": received_message.content,
        }
        return reply_message

    def post_filters(self, resource_candidates_map: Optional[Dict[str, Tuple]] = None):
        """Post filters for resource candidates.

        Applies each reranker in self._post_reranks to every resource's
        candidates; the first reranker yielding results wins for that
        resource. If any resource hit, a filtered copy of the map is
        returned; otherwise the original map is returned unchanged.
        """
        if resource_candidates_map:
            new_candidates_map = resource_candidates_map.copy()
            filter_hit = False
            for resource, (
                candidates,
                references,
                prompt_template,
            ) in resource_candidates_map.items():
                for rerank in self._post_reranks:
                    filter_candidates = rerank.rank(candidates)
                    # NOTE(review): the entry is emptied before checking the
                    # rerank result, so a resource with no reranked hits ends
                    # up with no candidates in the returned map — confirm
                    # this wipe is intentional.
                    new_candidates_map[resource] = [], [], prompt_template
                    if filter_candidates and len(filter_candidates) > 0:
                        new_candidates_map[resource] = (
                            filter_candidates,
                            references,
                            prompt_template,
                        )
                        filter_hit = True
                        break
            if filter_hit:
                logger.info("Post filters hit, use new candidates.")
                return new_candidates_map
        return resource_candidates_map
|
@@ -1,7 +1,11 @@
|
||||
"""Summary Assistant Agent."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt.rag.retriever.rerank import RetrieverNameRanker
|
||||
|
||||
from .. import AgentMessage
|
||||
from ..core.action.blank_action import BlankAction
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
@@ -59,4 +63,87 @@ class SummaryAssistantAgent(ConversableAgent):
|
||||
def __init__(self, **kwargs):
    """Create a new SummaryAssistantAgent instance."""
    super().__init__(**kwargs)
    # Rerankers applied by post_filters() after resource retrieval.
    # NOTE(review): assumes RetrieverNameRanker(5) keeps the top 5
    # candidates — confirm against its implementation.
    self._post_reranks = [RetrieverNameRanker(5)]
    # BlankAction is the agent's only action.
    self._init_actions([BlankAction])
|
||||
|
||||
async def load_resource(self, question: str, is_retry_chat: bool = False):
    """Load the resources bound to this agent and render them as prompt text.

    For a resource pack, retrieves candidates from each sub-resource,
    applies post filters, and renders one prompt section per resource.
    For a single resource, delegates to the resource's own get_prompt.

    Returns:
        Tuple of (prompt text or None, reference info or None).
    """
    if not self.resource:
        return None, None

    if not self.resource.is_pack:
        # Single resource: it renders its own prompt.
        resource_prompt, resource_reference = await self.resource.get_prompt(
            lang=self.language, question=question
        )
        return resource_prompt, resource_reference

    # Resource pack: first gather candidates from every sub-resource.
    candidates_results: List = []
    resource_candidates_map = {}
    info_map = {}
    prompt_list = []
    for sub_resource in self.resource.sub_resources:
        (
            candidates,
            prompt_template,
            resource_reference,
        ) = await sub_resource.get_resources(question=question)
        # Map stores (candidates, references, template) per resource name.
        resource_candidates_map[sub_resource.name] = (
            candidates,
            resource_reference,
            prompt_template,
        )
        candidates_results.extend(candidates)  # type: ignore # noqa

    # Rerank/filter, then render one prompt section per surviving resource.
    new_candidates_map = self.post_filters(resource_candidates_map)
    for res_name, (
        candidates,
        references,
        prompt_template,
    ) in new_candidates_map.items():
        joined = "\n".join(
            f"--{idx}--:" + chunk.content
            for idx, chunk in enumerate(candidates)  # type: ignore # noqa
        )
        prompt_list.append(prompt_template.format(name=res_name, content=joined))
        info_map.update(references)
    return "\n".join(prompt_list), info_map
|
||||
|
||||
def _init_reply_message(
    self,
    received_message: AgentMessage,
    rely_messages: Optional[List[AgentMessage]] = None,
) -> AgentMessage:
    """Create the reply message, carrying the user's question in context."""
    reply = super()._init_reply_message(received_message, rely_messages)
    # Make the original question available to prompt templates.
    reply.context = {"user_question": received_message.content}
    return reply
|
||||
|
||||
def post_filters(self, resource_candidates_map: Optional[Dict[str, Tuple]] = None):
    """Post filters for resource candidates.

    Applies each reranker in self._post_reranks to every resource's
    candidates; the first reranker yielding results wins for that resource.
    If any resource hit, a filtered copy of the map is returned; otherwise
    the original map is returned unchanged.
    """
    if not resource_candidates_map:
        return resource_candidates_map

    filtered_map = resource_candidates_map.copy()
    any_hit = False
    for name, (candidates, references, template) in resource_candidates_map.items():
        for reranker in self._post_reranks:
            ranked = reranker.rank(candidates)
            # The entry is emptied first; it stays empty when this reranker
            # yields nothing (mirrors the original behavior).
            filtered_map[name] = [], [], template
            if ranked:
                filtered_map[name] = (ranked, references, template)
                any_hit = True
                break

    if any_hit:
        logger.info("Post filters hit, use new candidates.")
        return filtered_map
    return resource_candidates_map
|
||||
|
Reference in New Issue
Block a user