mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 14:11:14 +00:00
feat(agent): Release agent SDK (#1396)
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
"""Resource module for Agent."""
|
||||
from .resource_api import AgentResource, ResourceClient, ResourceType # noqa: F401
|
||||
from .resource_db_api import ResourceDbClient, SqliteLoadClient # noqa: F401
|
||||
from .resource_knowledge_api import ResourceKnowledgeClient # noqa: F401
|
||||
from .resource_loader import ResourceLoader # noqa: F401
|
||||
from .resource_plugin_api import ( # noqa: F401
|
||||
PluginFileLoadClient,
|
||||
ResourcePluginClient,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentResource",
|
||||
"ResourceClient",
|
||||
"ResourceType",
|
||||
"ResourceDbClient",
|
||||
"SqliteLoadClient",
|
||||
"ResourceKnowledgeClient",
|
||||
"ResourceLoader",
|
||||
"PluginFileLoadClient",
|
||||
"ResourcePluginClient",
|
||||
]
|
||||
|
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
"""Resource API for the agent."""
|
||||
import json
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class ResourceType(Enum):
|
||||
"""Resource type enumeration."""
|
||||
|
||||
DB = "database"
|
||||
Knowledge = "knowledge"
|
||||
Internet = "internet"
|
||||
@@ -17,10 +17,12 @@ class ResourceType(Enum):
|
||||
TextFile = "text_file"
|
||||
ExcelFile = "excel_file"
|
||||
ImageFile = "image_file"
|
||||
AwelFlow = "awel_flow"
|
||||
AWELFlow = "awel_flow"
|
||||
|
||||
|
||||
class AgentResource(BaseModel):
|
||||
"""Agent resource class."""
|
||||
|
||||
type: ResourceType
|
||||
name: str
|
||||
value: str
|
||||
@@ -29,10 +31,12 @@ class AgentResource(BaseModel):
|
||||
)
|
||||
|
||||
def resource_prompt_template(self, **kwargs) -> str:
|
||||
return f"""{{data_type}} --{{data_introduce}}"""
|
||||
"""Get the resource prompt template."""
|
||||
return "{data_type} --{data_introduce}"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> Optional[AgentResource]:
|
||||
def from_dict(d: Dict[str, Any]) -> Optional["AgentResource"]:
|
||||
"""Create an AgentResource object from a dictionary."""
|
||||
if d is None:
|
||||
return None
|
||||
return AgentResource(
|
||||
@@ -44,16 +48,25 @@ class AgentResource(BaseModel):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_josn_list_str(d: Optional[str]) -> Optional[List[AgentResource]]:
|
||||
def from_json_list_str(d: Optional[str]) -> Optional[List["AgentResource"]]:
|
||||
"""Create a list of AgentResource objects from a json string."""
|
||||
if d is None:
|
||||
return None
|
||||
try:
|
||||
json_array = json.loads(d)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
return [AgentResource.from_dict(item) for item in json_array]
|
||||
if not isinstance(json_array, list):
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
json_list = []
|
||||
for item in json_array:
|
||||
r = AgentResource.from_dict(item)
|
||||
if r:
|
||||
json_list.append(r)
|
||||
return json_list
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the AgentResource object to a dictionary."""
|
||||
temp = self.dict()
|
||||
for field, value in temp.items():
|
||||
if isinstance(value, Enum):
|
||||
@@ -62,29 +75,51 @@ class AgentResource(BaseModel):
|
||||
|
||||
|
||||
class ResourceClient(ABC):
|
||||
"""Resource client interface."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> ResourceType:
|
||||
pass
|
||||
"""Return the resource type."""
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> str:
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Get the content introduction prompt of the specified resource
|
||||
Get the content introduction prompt of the specified resource.
|
||||
|
||||
Args:
|
||||
value:
|
||||
resource(AgentResource): The specified resource.
|
||||
question(str): The question to be asked.
|
||||
|
||||
Returns:
|
||||
|
||||
str: The introduction content.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the specified resource.
|
||||
|
||||
Args:
|
||||
resource(AgentResource): The specified resource.
|
||||
|
||||
Returns:
|
||||
str: The data type.
|
||||
"""
|
||||
return ""
|
||||
|
||||
async def get_resource_prompt(
|
||||
self, conv_uid, resource: AgentResource, question: Optional[str] = None
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the resource prompt.
|
||||
|
||||
Args:
|
||||
resource(AgentResource): The specified resource.
|
||||
question(str): The question to be asked.
|
||||
|
||||
Returns:
|
||||
str: The resource prompt.
|
||||
"""
|
||||
return resource.resource_prompt_template().format(
|
||||
data_type=self.get_data_type(resource),
|
||||
data_introduce=await self.get_data_introduce(resource, question),
|
||||
|
@@ -1,52 +1,67 @@
|
||||
"""Database resource client API."""
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
from .resource_api import AgentResource, ResourceClient, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourceDbClient(ResourceClient):
|
||||
"""Database resource client API."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.DB
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the resource."""
|
||||
return super().get_data_type(resource)
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> str:
|
||||
return await self.a_get_schema_link(resource.value, question)
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the data introduce of the resource."""
|
||||
return await self.get_schema_link(resource.value, question)
|
||||
|
||||
async def a_get_schema_link(self, db: str, question: Optional[str] = None) -> str:
|
||||
async def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def a_query_to_df(self, dbe: str, sql: str):
|
||||
async def query_to_df(self, dbe: str, sql: str):
|
||||
"""Return the query result as a DataFrame."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def a_query(self, db: str, sql: str):
|
||||
async def query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def a_run_sql(self, db: str, sql: str):
|
||||
async def run_sql(self, db: str, sql: str):
|
||||
"""Run the SQL."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
|
||||
class SqliteLoadClient(ResourceDbClient):
|
||||
"""SQLite resource client."""
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
def __init__(self):
|
||||
"""Create a SQLite resource client."""
|
||||
super(SqliteLoadClient, self).__init__()
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the resource."""
|
||||
return "sqlite"
|
||||
|
||||
@contextmanager
|
||||
def connect(self, db) -> Session:
|
||||
from sqlalchemy import create_engine, text
|
||||
def connect(self, db) -> Iterator[Session]:
|
||||
"""Connect to the database."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
engine = create_engine("sqlite:///" + db, echo=True)
|
||||
@@ -55,17 +70,20 @@ class SqliteLoadClient(ResourceDbClient):
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except:
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def a_get_schema_link(self, db: str, question: Optional[str] = None) -> str:
|
||||
async def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
from sqlalchemy import text
|
||||
|
||||
with self.connect(db) as connect:
|
||||
_tables_sql = f"""
|
||||
_tables_sql = """
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = connect.execute(text(_tables_sql))
|
||||
@@ -86,13 +104,15 @@ class SqliteLoadClient(ResourceDbClient):
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
async def a_query_to_df(self, db: str, sql: str):
|
||||
async def query_to_df(self, db: str, sql: str):
|
||||
"""Return the query result as a DataFrame."""
|
||||
import pandas as pd
|
||||
|
||||
field_names, result = await self.a_query(db, sql)
|
||||
field_names, result = await self.query(db, sql)
|
||||
return pd.DataFrame(result, columns=field_names)
|
||||
|
||||
async def a_query(self, db: str, sql: str):
|
||||
async def query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
from sqlalchemy import text
|
||||
|
||||
with self.connect(db) as connect:
|
||||
@@ -100,10 +120,7 @@ class SqliteLoadClient(ResourceDbClient):
|
||||
if not sql:
|
||||
return []
|
||||
cursor = connect.execute(text(sql))
|
||||
if cursor.returns_rows:
|
||||
if cursor.returns_rows: # type: ignore
|
||||
result = cursor.fetchall()
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
return field_names, result
|
||||
|
||||
async def a_run_sql(self, db: str, sql: str):
|
||||
pass
|
||||
|
@@ -1,9 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
|
||||
class ResourceFileClient(ResourceClient):
|
||||
@property
|
||||
def type(self) -> ResourceType:
|
||||
return ResourceType.File
|
@@ -1,19 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
"""Knowledge resource API for the agent."""
|
||||
from typing import Any, Optional
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
|
||||
class ResourceKnowledgeClient(ResourceClient):
|
||||
"""Knowledge resource client."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Knowledge
|
||||
|
||||
async def a_get_kn(self, space_name: str, question: Optional[str] = None) -> str:
|
||||
async def get_kn(self, space_name: str, question: Optional[str] = None) -> Any:
|
||||
"""Get the knowledge content."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def add_kn(
|
||||
self, space_name: str, kn_name: str, type: str, content: Optional[Any]
|
||||
):
|
||||
"""Add knowledge content."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
@@ -1,16 +1,26 @@
|
||||
"""Resource loader module."""
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
T = TypeVar("T", bound=ResourceClient)
|
||||
|
||||
|
||||
class ResourceLoader:
|
||||
"""Resource loader."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a new resource loader."""
|
||||
self._resource_api_instance = defaultdict(ResourceClient)
|
||||
|
||||
def get_resesource_api(
|
||||
self, resource_type: ResourceType
|
||||
) -> Optional[ResourceClient]:
|
||||
def get_resource_api(
|
||||
self,
|
||||
resource_type: Optional[ResourceType],
|
||||
cls: Optional[Type[T]] = None,
|
||||
check_instance: bool = True,
|
||||
) -> Optional[T]:
|
||||
"""Get the resource loader for the given resource type."""
|
||||
if not resource_type:
|
||||
return None
|
||||
|
||||
@@ -18,8 +28,13 @@ class ResourceLoader:
|
||||
raise ValueError(
|
||||
f"No loader available for resource of type {resource_type.value}"
|
||||
)
|
||||
inst = self._resource_api_instance[resource_type]
|
||||
if check_instance and cls and not isinstance(inst, cls):
|
||||
raise ValueError(
|
||||
f"Resource loader for {resource_type.value} is not an instance of {cls}"
|
||||
)
|
||||
return inst
|
||||
|
||||
return self._resource_api_instance[resource_type]
|
||||
|
||||
def register_resesource_api(self, api_instance: ResourceClient):
|
||||
def register_resource_api(self, api_instance: ResourceClient):
|
||||
"""Register the resource API instance."""
|
||||
self._resource_api_instance[api_instance.type] = api_instance
|
||||
|
@@ -1,58 +1,72 @@
|
||||
"""Resource plugin client API."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.agent.plugin.commands.command_manage import execute_command
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
from dbgpt.agent.plugin.plugins_util import scan_plugin_file, scan_plugins
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
from ..plugin.commands.command_manage import execute_command
|
||||
from ..plugin.generator import PluginPromptGenerator
|
||||
from ..plugin.plugins_util import scan_plugin_file, scan_plugins
|
||||
from ..resource.resource_api import AgentResource
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourcePluginClient(ResourceClient):
|
||||
"""Resource plugin client."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Plugin
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the specified resource."""
|
||||
return "Tools"
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> str:
|
||||
return await self.a_plugins_prompt(resource.value)
|
||||
) -> Union[str, List[str]]:
|
||||
"""Get the content introduction prompt of the specified resource."""
|
||||
return await self.plugins_prompt(resource.value)
|
||||
|
||||
async def a_load_plugin(
|
||||
async def load_plugin(
|
||||
self,
|
||||
value: str,
|
||||
plugin_generator: Optional[PluginPromptGenerator] = None,
|
||||
) -> PluginPromptGenerator:
|
||||
"""Load the plugin."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def a_plugins_prompt(
|
||||
async def plugins_prompt(
|
||||
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
) -> str:
|
||||
plugin_generator = await self.a_load_plugin(value)
|
||||
"""Get the plugin commands prompt."""
|
||||
plugin_generator = await self.load_plugin(value)
|
||||
return plugin_generator.generate_commands_string()
|
||||
|
||||
async def a_execute_command(
|
||||
async def execute_command(
|
||||
self,
|
||||
command_name: str,
|
||||
arguments: Optional[dict],
|
||||
plugin_generator: Optional[PluginPromptGenerator],
|
||||
):
|
||||
"""Execute the command."""
|
||||
if plugin_generator is None:
|
||||
raise ValueError("No plugin commands loaded into the executable!")
|
||||
return execute_command(command_name, arguments, plugin_generator)
|
||||
|
||||
|
||||
class PluginFileLoadClient(ResourcePluginClient):
|
||||
async def a_load_plugin(
|
||||
"""File plugin load client.
|
||||
|
||||
Load the plugin from the local file.
|
||||
"""
|
||||
|
||||
async def load_plugin(
|
||||
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
) -> PluginPromptGenerator:
|
||||
"""Load the plugin."""
|
||||
logger.info(f"PluginFileLoadClient load plugin:{value}")
|
||||
if plugin_generator is None:
|
||||
plugin_generator = PluginPromptGenerator()
|
||||
@@ -66,10 +80,11 @@ class PluginFileLoadClient(ResourcePluginClient):
|
||||
plugins.extend(scan_plugins(value))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The current mode cannot support plug-in loading with relative paths.{value}"
|
||||
f"The current mode cannot support plug-in loading with relative "
|
||||
f"paths: {value}"
|
||||
)
|
||||
for plugin in plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
plugin_generator = plugin.post_prompt(plugin_generator)
|
||||
return plugin_generator
|
||||
return cast(PluginPromptGenerator, plugin_generator)
|
||||
|
Reference in New Issue
Block a user