mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
refactor(agent): Refactor resource of agents (#1518)
@@ -17,7 +17,8 @@ from dbgpt.agent.core.memory.gpts.gpts_memory import GptsMemory
from dbgpt.agent.core.plan import AutoPlanChatManager, DefaultAWELLayoutManager
from dbgpt.agent.core.schema import Status
from dbgpt.agent.core.user_proxy_agent import UserProxyAgent
from dbgpt.agent.resource.resource_loader import ResourceLoader
from dbgpt.agent.resource.base import Resource
from dbgpt.agent.resource.manage import get_resource_manager
from dbgpt.agent.util.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.scene.base import ChatScene
@@ -32,9 +33,6 @@ from dbgpt.util.json_utils import serialize
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
from ..db.gpts_manage_db import GptsInstanceEntity
from ..resource_loader.datasource_load_client import DatasourceLoadClient
from ..resource_loader.knowledge_space_load_client import KnowledgeSpaceLoadClient
from ..resource_loader.plugin_hub_load_client import PluginHubLoadClient
from ..team.base import TeamMode
from .db_gpts_memory import MetaDbGptsMessageMemory, MetaDbGptsPlansMemory

@@ -93,8 +91,6 @@ class MultiAgents(BaseComponent, ABC):
        from dbgpt.agent.core.memory.hybrid import HybridMemory
        from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
        from dbgpt.storage.vector_store.base import VectorStoreConfig
        from dbgpt.storage.vector_store.connector import VectorStoreConnector

        memory_key = f"{dbgpts_name}_{conv_id}"
        if memory_key in self.agent_memory_map:
@@ -105,13 +101,17 @@ class MultiAgents(BaseComponent, ABC):
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        vstore_name = f"_chroma_agent_memory_{dbgpts_name}_{conv_id}"
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=VectorStoreConfig(
                name=vstore_name, embedding_fn=embedding_fn
            ),
        # Just use chroma store now
        # vector_store_connector = VectorStoreConnector(
        #     vector_store_type=CFG.VECTOR_STORE_TYPE,
        #     vector_store_config=VectorStoreConfig(
        #         name=vstore_name, embedding_fn=embedding_fn
        #     ),
        # )
        memory = HybridMemory[AgentMemoryFragment].from_chroma(
            vstore_name=vstore_name,
            embeddings=embedding_fn,
        )
        memory = HybridMemory[AgentMemoryFragment].from_vstore(vector_store_connector)
        agent_memory = AgentMemory(memory, gpts_memory=self.memory)
        self.agent_memory_map[memory_key] = agent_memory
        return agent_memory
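The net effect of this hunk: per-conversation agent memory is now built directly through HybridMemory.from_chroma instead of hand-wiring a VectorStoreConnector. A minimal sketch of the new path, assuming an embedding function has already been created via EmbeddingFactory and that AgentMemory/AgentMemoryFragment are importable from dbgpt.agent (import paths other than HybridMemory are assumptions, not shown in this diff):

from dbgpt.agent import AgentMemory, AgentMemoryFragment  # assumed export path
from dbgpt.agent.core.memory.hybrid import HybridMemory

vstore_name = "_chroma_agent_memory_my_app_conv_1"  # illustrative name
memory = HybridMemory[AgentMemoryFragment].from_chroma(
    vstore_name=vstore_name,
    embeddings=embedding_fn,  # embedding_fn built elsewhere by EmbeddingFactory
)
agent_memory = AgentMemory(memory, gpts_memory=gpts_memory)  # gpts_memory: GptsMemory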
@@ -243,14 +243,7 @@ class MultiAgents(BaseComponent, ABC):
        agent_memory: Optional[AgentMemory] = None,
    ):
        employees: List[Agent] = []
        # Prepare resource loader
        resource_loader = ResourceLoader()
        plugin_hub_loader = PluginHubLoadClient()
        resource_loader.register_resource_api(plugin_hub_loader)
        datasource_loader = DatasourceLoadClient()
        resource_loader.register_resource_api(datasource_loader)
        knowledge_space_loader = KnowledgeSpaceLoadClient()
        resource_loader.register_resource_api(knowledge_space_loader)
        rm = get_resource_manager()
        context: AgentContext = AgentContext(
            conv_id=conv_uid,
            gpts_app_name=gpts_app.app_name,
@@ -264,6 +257,7 @@ class MultiAgents(BaseComponent, ABC):
        ).create()
        self.llm_provider = DefaultLLMClient(worker_manager, auto_convert_message=True)

        depend_resource: Optional[Resource] = None
        for record in gpts_app.details:
            cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
                record.agent_name
@@ -273,12 +267,13 @@ class MultiAgents(BaseComponent, ABC):
                llm_strategy=LLMStrategyType(record.llm_strategy),
                strategy_context=record.llm_strategy_value,
            )
            depend_resource = rm.build_resource(record.resources, version="v1")

            agent = (
                await cls()
                .bind(context)
                .bind(llm_config)
                .bind(record.resources)
                .bind(resource_loader)
                .bind(depend_resource)
                .bind(agent_memory)
                .build()
            )
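Read together with the import hunks above, the per-agent wiring after this refactor reduces to roughly the following sketch (record, context, llm_config, and agent_memory are the same locals as in the surrounding code):

rm = get_resource_manager()
depend_resource = rm.build_resource(record.resources, version="v1")
agent = (
    await cls()
    .bind(context)
    .bind(llm_config)
    .bind(depend_resource)
    .bind(agent_memory)
    .build()
)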
@@ -298,7 +293,7 @@ class MultiAgents(BaseComponent, ABC):
            manager = (
                await manager.bind(context)
                .bind(llm_config)
                .bind(resource_loader)
                .bind(depend_resource)
                .bind(agent_memory)
                .build()
            )
@@ -308,7 +303,7 @@ class MultiAgents(BaseComponent, ABC):
        user_proxy: UserProxyAgent = (
            await UserProxyAgent()
            .bind(context)
            .bind(resource_loader)
            .bind(depend_resource)
            .bind(agent_memory)
            .build()
        )
@@ -4,10 +4,8 @@ from fastapi import APIRouter

from dbgpt._private.config import Config
from dbgpt.agent.core.agent_manage import get_agent_manager
from dbgpt.agent.resource.resource_api import ResourceType
from dbgpt.agent.resource.manage import get_resource_manager
from dbgpt.agent.util.llm.llm import LLMStrategyType
from dbgpt.app.knowledge.api import knowledge_space_service
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.serve.agent.app.gpts_server import available_llms
from dbgpt.serve.agent.db.gpts_app import (
@@ -17,7 +15,6 @@ from dbgpt.serve.agent.db.gpts_app import (
    GptsAppQuery,
    GptsAppResponse,
)
from dbgpt.serve.agent.hub.plugin_hub import plugin_hub
from dbgpt.serve.agent.team.base import TeamMode

CFG = Config()
@@ -109,7 +106,8 @@ async def team_mode_list():
@router.get("/v1/resource-type/list")
async def team_mode_list():
    try:
        return Result.succ([type.value for type in ResourceType])
        resources = get_resource_manager().get_supported_resources(version="v1")
        return Result.succ(list(resources.keys()))
    except Exception as ex:
        return Result.failed(code="E000X", msg=f"query resource type list error: {ex}")

@@ -146,29 +144,8 @@ async def app_resources(
    Get agent resources, such as db, knowledge, internet, plugin.
    """
    try:
        results = []
        match type:
            case ResourceType.DB.value:
                dbs = CFG.local_db_manager.get_db_list()
                results = [db["db_name"] for db in dbs]
                if name:
                    results = [r for r in results if name in r]
            case ResourceType.Knowledge.value:
                knowledge_spaces = knowledge_space_service.get_knowledge_space(
                    KnowledgeSpaceRequest()
                )
                results = [ks.name for ks in knowledge_spaces]
                if name:
                    results = [r for r in results if name in r]
            case ResourceType.Plugin.value:
                plugins = plugin_hub.get_my_plugin(user_code)
                results = [plugin.name for plugin in plugins]
                if name:
                    results = [r for r in results if name in r]
            case ResourceType.Internet.value:
                return Result.succ(None)
            case ResourceType.File.value:
                return Result.succ(None)
        resources = get_resource_manager().get_supported_resources("v1")
        results = resources.get(type, [])
        return Result.succ(results)
    except Exception as ex:
        return Result.failed(code="E000X", msg=f"query app resources error: {ex}")
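Both endpoints above now defer to the resource manager instead of branching per resource type. Judging only from how the result is used here (.keys() and .get(type, [])), get_supported_resources appears to return a mapping from resource type name to its selectable options; a hedged sketch of the calling pattern (the keys shown in the comment are illustrative):

resources = get_resource_manager().get_supported_resources(version="v1")
resource_types = list(resources.keys())   # e.g. ["database", "knowledge", "tool(autogpt_plugins)"]
options = resources.get("database", [])   # options for one type, [] if unknown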
@@ -15,7 +15,7 @@ from dbgpt._private.pydantic import (
    model_validator,
)
from dbgpt.agent.core.plan import AWELTeamContext
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.base import AgentResource
from dbgpt.serve.agent.team.base import TeamMode
from dbgpt.storage.metadata import BaseDao, Model
@@ -4,8 +4,8 @@ from typing import List

from fastapi import APIRouter, Body, File, UploadFile

from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.agent.plugin.plugins_util import scan_plugins
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins
from dbgpt.agent.resource.tool.pack import AutoGPTPluginToolPack
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.configs.model_config import PLUGINS_DIR
@@ -30,25 +30,15 @@ class ModulePlugin(BaseComponent, ABC):

    def __init__(self):
        # load plugins
        self.plugins = scan_plugins(PLUGINS_DIR)
        self.refresh_plugins()

    def init_app(self, system_app: SystemApp):
        system_app.app.include_router(router, prefix="/api", tags=["Agent"])

    def refresh_plugins(self):
        self.plugins = scan_plugins(PLUGINS_DIR)

    def load_select_plugin(
        self, generator: PluginPromptGenerator, select_plugins: List[str]
    ) -> PluginPromptGenerator:
        logger.info(f"load_select_plugin:{select_plugins}")
        # load select plugin
        for plugin in self.plugins:
            if plugin._name in select_plugins:
                if not plugin.can_handle_post_prompt():
                    continue
                generator = plugin.post_prompt(generator)
        return generator
        self.tools = AutoGPTPluginToolPack(PLUGINS_DIR)
        self.tools.preload_resource()


module_plugin = ModulePlugin()
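ModulePlugin no longer scans and post-prompts plugins itself; it just preloads a tool pack. A minimal sketch of the new initialization, using only the names from this hunk:

from dbgpt.agent.resource.tool.pack import AutoGPTPluginToolPack
from dbgpt.configs.model_config import PLUGINS_DIR

tools = AutoGPTPluginToolPack(PLUGINS_DIR)
tools.preload_resource()  # discover the AutoGPT-style plugins under PLUGINS_DIR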
@@ -9,7 +9,7 @@ from typing import Any
from fastapi import UploadFile

from dbgpt.agent.core.schema import PluginStorageType
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins, update_from_git
from dbgpt.configs.model_config import PLUGINS_DIR

from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
dbgpt/serve/agent/resource/datasource.py  (new file, 95 lines)
@@ -0,0 +1,95 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, Union, cast

from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
from dbgpt.util import ParameterDescription

CFG = Config()

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
    """The DB parameters for the datasource."""

    db_name: str = dataclasses.field(metadata={"help": "DB name"})

    @classmethod
    def _resource_version(cls) -> str:
        """Return the resource version."""
        return "v1"

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["DatasourceDBParameters"],
        version: Optional[str] = None,
    ) -> Any:
        """Convert the parameters to configurations."""
        conf: List[ParameterDescription] = cast(
            List[ParameterDescription], super().to_configurations(parameters)
        )
        version = version or cls._resource_version()
        if version != "v1":
            return conf
        # Compatible with old version
        for param in conf:
            if param.param_name == "db_name":
                return param.valid_values or []
        return []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> "DatasourceDBParameters":
        """Create a new instance from a dictionary."""
        copied_data = data.copy()
        if "db_name" not in copied_data and "value" in copied_data:
            copied_data["db_name"] = copied_data.pop("value")
        return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


class DatasourceResource(RDBMSConnectorResource):
    def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
        conn = CFG.local_db_manager.get_connector(db_name)
        super().__init__(name, connector=conn, db_name=db_name, **kwargs)

    @classmethod
    def resource_parameters_class(cls) -> Type[DatasourceDBParameters]:
        dbs = CFG.local_db_manager.get_db_list()
        results = [db["db_name"] for db in dbs]

        @dataclasses.dataclass
        class _DynDBParameters(DatasourceDBParameters):
            db_name: str = dataclasses.field(
                metadata={"help": "DB name", "valid_values": results}
            )

        return _DynDBParameters

    def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the schema link of the database."""
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")
        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        table_infos = None
        try:
            table_infos = client.get_db_summary(
                db,
                question,
                CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
            )
        except Exception as e:
            logger.warning(f"db summary find error!{str(e)}")
        if not table_infos:
            conn = CFG.local_db_manager.get_connector(db)
            table_infos = conn.table_simple_info()

        return table_infos
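A short usage sketch for the new DatasourceResource (names and question are illustrative; CFG.local_db_manager must already know the database):

ds = DatasourceResource("orders-db", db_name="orders")
schema = ds.get_schema_link("orders", question="Which tables store order items?")
print(schema)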
dbgpt/serve/agent/resource/knowledge.py  (new file, 89 lines)
@@ -0,0 +1,89 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast

from dbgpt._private.config import Config
from dbgpt.agent.resource.knowledge import (
    RetrieverResource,
    RetrieverResourceParameters,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util import ParameterDescription

CFG = Config()

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class KnowledgeSpaceLoadResourceParameters(RetrieverResourceParameters):
    space_name: str = dataclasses.field(
        default=None, metadata={"help": "Knowledge space name"}
    )

    @classmethod
    def _resource_version(cls) -> str:
        """Return the resource version."""
        return "v1"

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["KnowledgeSpaceLoadResourceParameters"],
        version: Optional[str] = None,
    ) -> Any:
        """Convert the parameters to configurations."""
        conf: List[ParameterDescription] = cast(
            List[ParameterDescription], super().to_configurations(parameters)
        )
        version = version or cls._resource_version()
        if version != "v1":
            return conf
        # Compatible with old version
        for param in conf:
            if param.param_name == "space_name":
                return param.valid_values or []
        return []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> "KnowledgeSpaceLoadResourceParameters":
        """Create a new instance from a dictionary."""
        copied_data = data.copy()
        if "space_name" not in copied_data and "value" in copied_data:
            copied_data["space_name"] = copied_data.pop("value")
        return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


class KnowledgeSpaceRetrieverResource(RetrieverResource):
    """Knowledge Space retriever resource."""

    def __init__(self, name: str, space_name: str):
        retriever = KnowledgeSpaceRetriever(space_name=space_name)
        super().__init__(name, retriever=retriever)

    @classmethod
    def resource_parameters_class(cls) -> Type[KnowledgeSpaceLoadResourceParameters]:
        from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
        from dbgpt.app.knowledge.service import KnowledgeService

        knowledge_space_service = KnowledgeService()
        knowledge_spaces = knowledge_space_service.get_knowledge_space(
            KnowledgeSpaceRequest()
        )
        results = [ks.name for ks in knowledge_spaces]

        @dataclasses.dataclass
        class _DynamicKnowledgeSpaceLoadResourceParameters(
            KnowledgeSpaceLoadResourceParameters
        ):
            space_name: str = dataclasses.field(
                default=None,
                metadata={
                    "help": "Knowledge space name",
                    "valid_values": results,
                },
            )

        return _DynamicKnowledgeSpaceLoadResourceParameters
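The from_dict override keeps old payloads working: a legacy record that only carried a generic "value" field is mapped onto space_name. A hedged example, assuming the base parameters class requires no other fields:

params = KnowledgeSpaceLoadResourceParameters.from_dict({"value": "my_space"})
assert params.space_name == "my_space"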
dbgpt/serve/agent/resource/plugin.py  (new file, 92 lines)
@@ -0,0 +1,92 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast

from dbgpt._private.config import Config
from dbgpt.agent.resource.pack import PackResourceParameters
from dbgpt.agent.resource.tool.pack import ToolPack
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModulePlugin
from dbgpt.util.parameter_utils import ParameterDescription

CFG = Config()

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class PluginPackResourceParameters(PackResourceParameters):
    tool_name: str = dataclasses.field(metadata={"help": "Tool name"})

    @classmethod
    def _resource_version(cls) -> str:
        """Return the resource version."""
        return "v1"

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["PluginPackResourceParameters"],
        version: Optional[str] = None,
    ) -> Any:
        """Convert the parameters to configurations."""
        conf: List[ParameterDescription] = cast(
            List[ParameterDescription], super().to_configurations(parameters)
        )
        version = version or cls._resource_version()
        if version != "v1":
            return conf
        # Compatible with old version
        for param in conf:
            if param.param_name == "tool_name":
                return param.valid_values or []
        return []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> "PluginPackResourceParameters":
        """Create a new instance from a dictionary."""
        copied_data = data.copy()
        if "tool_name" not in copied_data and "value" in copied_data:
            copied_data["tool_name"] = copied_data.pop("value")
        return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


class PluginToolPack(ToolPack):
    def __init__(self, tool_name: str, **kwargs):
        kwargs.pop("name")
        super().__init__([], name="Plugin Tool Pack", **kwargs)
        # Select tool name
        self._tool_name = tool_name

    @classmethod
    def type_alias(cls) -> str:
        return "tool(autogpt_plugins)"

    @classmethod
    def resource_parameters_class(cls) -> Type[PluginPackResourceParameters]:
        agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
            ComponentType.PLUGIN_HUB, ModulePlugin
        )
        tool_names = []
        for name, sub_tool in agent_module.tools._resources.items():
            tool_names.append(name)

        @dataclasses.dataclass
        class _DynPluginPackResourceParameters(PluginPackResourceParameters):
            tool_name: str = dataclasses.field(
                metadata={"help": "Tool name", "valid_values": tool_names}
            )

        return _DynPluginPackResourceParameters

    def preload_resource(self):
        """Preload the resource."""
        agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
            ComponentType.PLUGIN_HUB, ModulePlugin
        )
        tool = agent_module.tools._resources.get(self._tool_name)
        if not tool:
            raise ValueError(f"Tool {self._tool_name} not found")
        self._resources = {tool.name: tool}
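PluginToolPack resolves one named tool out of the ModulePlugin component at preload time. A rough construction sketch (the tool name is illustrative, the incoming kwargs are expected to carry a name that gets discarded, and CFG.SYSTEM_APP must already hold the plugin hub component):

pack = PluginToolPack(tool_name="google_search", name="ignored")
pack.preload_resource()  # looks up "google_search" in module_plugin.tools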
@@ -1,66 +0,0 @@
import logging
from typing import List, Optional, Union

from dbgpt._private.config import Config
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.resource_db_api import ResourceDbClient
from dbgpt.component import ComponentType
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer

CFG = Config()

logger = logging.getLogger(__name__)


class DatasourceLoadClient(ResourceDbClient):
    def __init__(self):
        super().__init__()
        # The executor to submit blocking function
        self._executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()

    def get_data_type(self, resource: AgentResource) -> str:
        conn = CFG.local_db_manager.get_connector(resource.value)
        return conn.db_type

    async def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError:
            raise ValueError("Could not import DBSummaryClient. ")
        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        table_infos = None
        try:
            with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
                table_infos = await blocking_func_to_async(
                    self._executor,
                    client.get_db_summary,
                    db,
                    question,
                    CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
                )
        except Exception as e:
            print("db summary find error!" + str(e))
        if not table_infos:
            conn = CFG.local_db_manager.get_connector(db)
            table_infos = await blocking_func_to_async(
                self._executor, conn.table_simple_info
            )

        return table_infos

    async def query_to_df(self, db: str, sql: str):
        conn = CFG.local_db_manager.get_connector(db)
        return conn.run_to_df(sql)

    async def query(self, db: str, sql: str):
        conn = CFG.local_db_manager.get_connector(db)
        return conn.query_ex(sql)

    async def run_sql(self, db: str, sql: str):
        conn = CFG.local_db_manager.get_connector(db)
        return conn.run(sql)
@@ -1,35 +0,0 @@
import logging
from typing import Any, List, Optional, Union

from dbgpt._private.config import Config
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.resource_knowledge_api import ResourceKnowledgeClient
from dbgpt.core import Chunk
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever

CFG = Config()

logger = logging.getLogger(__name__)


class KnowledgeSpaceLoadClient(ResourceKnowledgeClient):
    async def get_space_desc(self, space_name) -> str:
        pass

    async def get_kn(
        self, space_name: str, question: Optional[str] = None
    ) -> List[Chunk]:
        kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)
        chunks: List[Chunk] = kn_retriver.retrieve(question)
        return chunks

    async def add_kn(
        self, space_name: str, kn_name: str, type: str, content: Optional[Any]
    ):
        kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)

    async def get_data_introduce(
        self, resource: AgentResource, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        docs = await self.get_kn(resource.value, question)
        return "\n".join([doc.content for doc in docs])
@@ -1,40 +0,0 @@
import json
import logging
from typing import Optional

from dbgpt._private.config import Config
from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.agent.resource.resource_plugin_api import ResourcePluginClient
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModulePlugin
from dbgpt.util.executor_utils import ExecutorFactory

CFG = Config()

logger = logging.getLogger(__name__)


class PluginHubLoadClient(ResourcePluginClient):
    def __init__(self):
        super().__init__()
        # The executor to submit blocking function
        self._executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()

    async def load_plugin(
        self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
    ) -> PluginPromptGenerator:
        logger.info(f"PluginHubLoadClient load plugin:{value}")
        if plugin_generator is None:
            plugin_generator = PluginPromptGenerator()
            plugin_generator.set_command_registry(CFG.command_registry)

        agent_module = CFG.SYSTEM_APP.get_component(
            ComponentType.PLUGIN_HUB, ModulePlugin
        )
        plugin_generator = agent_module.load_select_plugin(
            plugin_generator, json.dumps(value)
        )

        return plugin_generator