refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
Authored by FangYin Cheng on 2023-12-08 14:45:59 +08:00, committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File

@@ -1,12 +0,0 @@
# Old packages
# __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
__all__ = ["embedding_engine"]
def __getattr__(name: str):
import importlib
if name in ["embedding_engine"]:
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,87 +0,0 @@
"""Agentic Workflow Expression Language (AWEL)
Note:
AWEL is still an experimental feature, and only the lowest-level API is currently exposed.
The stability of this API cannot be guaranteed at present.
"""
from pilot.component import SystemApp
from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchOperator,
InputOperator,
BranchFunc,
)
from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
)
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
"initialize_awel",
"DAGContext",
"DAG",
"BaseOperator",
"JoinOperator",
"ReduceStreamOperator",
"MapOperator",
"BranchOperator",
"InputOperator",
"BranchFunc",
"WorkflowRunner",
"TaskState",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
"SimpleTaskOutput",
"SimpleStreamTaskOutput",
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .operator.base import initialize_runner
DAGVar.set_current_system_app(system_app)
system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()
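# Illustrative sketch: a minimal workflow wired from the symbols imported above.
# It assumes a SystemApp has already been set on DAGVar (for example via
# initialize_awel); the doubling lambda and the input value are placeholders.
async def _example_awel_workflow() -> int:
    with DAG("awel_example_dag"):
        source = InputOperator(SimpleInputSource(2))
        double = MapOperator(lambda x: x * 2)
        source >> double
    ctx = await DefaultWorkflowRunner().execute_workflow(double)
    return ctx.current_task_context.task_output.output  # 2 * 2 == 4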

View File

@@ -1,7 +0,0 @@
from abc import ABC, abstractmethod
class Trigger(ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -1,364 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor
from pilot.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
def _is_async_context():
try:
loop = asyncio.get_running_loop()
return asyncio.current_task(loop=loop) is not None
except RuntimeError:
return False
class DependencyMixin(ABC):
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
Example:
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
Example:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
self.__rshift__(nodes)
return self
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
@classmethod
def enter_dag(cls, dag) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
stack.append(dag)
cls._async_local.set(stack)
else:
if not hasattr(cls._thread_local, "current_dag_stack"):
cls._thread_local.current_dag_stack = deque()
cls._thread_local.current_dag_stack.append(dag)
@classmethod
def exit_dag(cls) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
if stack:
stack.pop()
cls._async_local.set(stack)
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
cls._thread_local.current_dag_stack.pop()
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
return stack[-1] if stack else None
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
return cls._thread_local.current_dag_stack[-1]
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
if not cls._system_app:
raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
cls._executor = executor
class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(
self,
dag: Optional["DAG"] = None,
node_id: Optional[str] = None,
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
@property
def node_id(self) -> str:
return self._node_id
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
def __hash__(self) -> int:
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
return self._node_name
@property
def dag(self) -> "DAG":
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
self.set_dependency(nodes)
return self
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
self.set_dependency(nodes, is_upstream=False)
return self
@property
def upstream(self) -> List["DAGNode"]:
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
if self.dag:
dags.add(self.dag)
if not dags:
raise ValueError("set dependency to current node must in a DAG context")
if len(dags) != 1:
raise ValueError(
"set dependency to current node just support in one DAG context"
)
dag = dags.pop()
self._dag = dag
dag._append_node(self)
for node in nodes:
if is_upstream and node not in self.upstream:
node._dag = dag
dag._append_node(node)
self._upstream.append(node)
node._downstream.append(self)
elif node not in self._downstream:
node._dag = dag
dag._append_node(node)
self._downstream.append(node)
node._upstream.append(self)
class DAGContext:
def __init__(self) -> None:
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
@property
def current_task_context(self) -> TaskContext:
return self._curr_task_ctx
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx
async def get_share_data(self, key: str) -> Any:
return self._share_data.get(key)
async def save_to_share_data(self, key: str, data: Any) -> None:
self._share_data[key] = data
class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Optional[List[DAGNode]] = None
self._leaf_nodes: Optional[List[DAGNode]] = None
self._trigger_nodes: Optional[List[DAGNode]] = None
def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
self._trigger_nodes = list(
set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
)
@property
def root_nodes(self) -> List[DAGNode]:
if not self._root_nodes:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
if not self._leaf_nodes:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self):
if not self._trigger_nodes:
self._build()
return self._trigger_nodes
def __enter__(self):
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
def _get_nodes(node: DAGNode, is_upstream: bool = True) -> Set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
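# Illustrative sketch: DAG is a context manager, so nodes created inside the
# `with` block attach to it through DAGVar, and `>>` records the dependency.
# It assumes a SystemApp has already been set via DAGVar.set_current_system_app
# and uses operators defined elsewhere in this package.
def _example_dag_structure() -> DAG:
    from pilot.awel import InputOperator, MapOperator, SimpleInputSource

    with DAG("dag_base_example") as dag:
        source = InputOperator(SimpleInputSource("hello"))
        upper = MapOperator(lambda s: s.upper())
        source >> upper  # source.set_downstream(upper)
    assert dag.root_nodes == [source]
    assert dag.leaf_nodes == [upper]
    return dag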

View File

@@ -1,42 +0,0 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_filepath: str):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def load_dags(self):
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager
trigger_manager: DefaultTriggerManager = self.system_app.get_component(
ComponentType.AWEL_TRIGGER_MANAGER,
DefaultTriggerManager,
default_component=None,
)
if trigger_manager:
for trigger in triggers:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
else:
logger.warn("No trigger manager, not register dag trigger")

View File

@@ -1,93 +0,0 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
@abstractmethod
def load_dags(self) -> List[DAG]:
"""Load dags"""
class LocalFileDAGLoader(DAGLoader):
def __init__(self, filepath: str) -> None:
super().__init__()
self._filepath = filepath
def load_dags(self) -> List[DAG]:
if not os.path.exists(self._filepath):
return []
if os.path.isdir(self._filepath):
return _process_directory(self._filepath)
else:
return _process_file(self._filepath)
def _process_directory(directory: str) -> List[DAG]:
dags = []
for file in os.listdir(directory):
if file.endswith(".py"):
filepath = os.path.join(directory, file)
dags += _process_file(filepath)
return dags
def _process_file(filepath) -> List[DAG]:
mods = _load_modules_from_file(filepath)
results = _process_modules(mods)
return results
def _load_modules_from_file(filepath: str):
import importlib
import importlib.machinery
import importlib.util
logger.info(f"Importing {filepath}")
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules:
del sys.modules[mod_name]
def parse(mod_name, filepath):
try:
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
spec = importlib.util.spec_from_loader(mod_name, loader)
new_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message
return []
return parse(mod_name, filepath)
def _process_modules(mods) -> List[DAG]:
top_level_dags = (
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
)
found_dags = []
for dag, mod in top_level_dags:
try:
# TODO validate dag params
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
found_dags.append(dag)
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to dag file, error message: {msg}")
return found_dags
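# Illustrative sketch: load every DAG defined in the .py files under a
# directory. The path is a placeholder.
def _example_load_dags(dag_dir: str = "/path/to/dag_definitions") -> List[DAG]:
    loader = LocalFileDAGLoader(dag_dir)
    dags = loader.load_dags()
    for dag in dags:
        logger.info(f"Loaded DAG {dag.dag_id} with {len(dag.node_map)} nodes")
    return dags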

View File

@@ -1,51 +0,0 @@
import pytest
import threading
import asyncio
from ..dag.base import DAG, DAGVar
def test_dag_context_sync():
dag1 = DAG("dag1")
dag2 = DAG("dag2")
with dag1:
assert DAGVar.get_current_dag() == dag1
with dag2:
assert DAGVar.get_current_dag() == dag2
assert DAGVar.get_current_dag() == dag1
assert DAGVar.get_current_dag() is None
def test_dag_context_threading():
def thread_function(dag):
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
thread1 = threading.Thread(target=thread_function, args=(dag1,))
thread2 = threading.Thread(target=thread_function, args=(dag2,))
thread1.start()
thread2.start()
thread1.join()
thread2.join()
assert DAGVar.get_current_dag() is None
@pytest.mark.asyncio
async def test_dag_context_async():
async def async_function(dag):
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
await asyncio.gather(async_function(dag1), async_function(dag2))
assert DAGVar.get_current_dag() is None

View File

@@ -1,206 +0,0 @@
from abc import ABC, abstractmethod, ABCMeta
from types import FunctionType
from typing import (
List,
Generic,
TypeVar,
AsyncIterator,
Union,
Any,
Dict,
Optional,
cast,
)
import functools
from inspect import signature
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
TaskContext,
TaskOutput,
TaskState,
OUT,
T,
InputContext,
InputSource,
)
F = TypeVar("F", bound=FunctionType)
CALL_DATA = Union[Dict, Dict[str, Dict]]
class WorkflowRunner(ABC, Generic[T]):
"""Abstract base class representing a runner for executing workflows in a DAG.
This class defines the interface for executing workflows within the DAG,
handling the flow from one DAG node to another.
"""
@abstractmethod
async def execute_workflow(
self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
) -> DAGContext:
"""Execute the workflow starting from a given operator.
Args:
node (BaseOperator): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
default_runner: WorkflowRunner = None
class BaseOperatorMeta(ABCMeta):
"""Metaclass of BaseOperator."""
@classmethod
def _apply_defaults(cls, func: F) -> F:
sig_cache = signature(func)
@functools.wraps(func)
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id")
system_app: Optional[SystemApp] = (
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
executor = kwargs.get("executor") or DAGVar.get_executor()
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
# for arg in sig_cache.parameters:
# if arg not in kwargs:
# kwargs[arg] = default_args[arg]
if not kwargs.get("dag"):
kwargs["dag"] = dag
if not kwargs.get("task_id"):
kwargs["task_id"] = task_id
if not kwargs.get("runner"):
kwargs["runner"] = runner
if not kwargs.get("system_app"):
kwargs["system_app"] = system_app
if not kwargs.get("executor"):
kwargs["executor"] = executor
real_obj = func(self, *args, **kwargs)
return real_obj
return cast(F, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
return new_cls
class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Abstract base class for operator nodes that can be executed within a workflow.
This class extends DAGNode by adding execution capabilities.
"""
def __init__(
self,
task_id: Optional[str] = None,
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
**kwargs,
) -> None:
"""Initializes a BaseOperator with an optional workflow runner.
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
"""
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
if not runner:
from pilot.awel import DefaultWorkflowRunner
runner = DefaultWorkflowRunner()
self._runner: WorkflowRunner = runner
self._dag_ctx: DAGContext = None
@property
def current_dag_context(self) -> DAGContext:
return self._dag_ctx
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
if not self.node_id:
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
self._dag_ctx = dag_ctx
return await self._do_run(dag_ctx)
@abstractmethod
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""
Abstract method to run the task within the DAG node.
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
Args:
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
OUT: The output of the node after execution.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output
async def call_stream(
self, call_data: Optional[CALL_DATA] = None
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
This method is used for nodes where the output is a stream.
Args:
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output_stream
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
return await blocking_func_to_async(self._executor, func, *args, **kwargs)
def initialize_runner(runner: WorkflowRunner):
global default_runner
default_runner = runner
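# Illustrative sketch: call() and call_stream() are the high-level entry points
# of BaseOperator; the runner walks the upstream graph automatically. It assumes
# a SystemApp has already been set on DAGVar, and uses concrete operators from
# this package as stand-ins for any BaseOperator subclass.
async def _example_operator_call() -> int:
    from pilot.awel import DAG, InputOperator, MapOperator, SimpleInputSource

    with DAG("operator_call_example"):
        source = InputOperator(SimpleInputSource(21))
        double = MapOperator(lambda x: x * 2)
        source >> double
    # No explicit runner needed: BaseOperator falls back to DefaultWorkflowRunner.
    return await double.call()  # 21 * 2 == 42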

View File

@@ -1,246 +0,0 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
import logging
from ..dag.base import DAGContext
from ..task.base import (
TaskContext,
TaskOutput,
IN,
OUT,
InputContext,
InputSource,
)
from .base import BaseOperator
logger = logging.getLogger(__name__)
class JoinOperator(BaseOperator, Generic[OUT]):
"""Operator that joins inputs using a custom combine function.
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function, **kwargs):
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
self.combine_function = combine_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
self.combine_function
)
# The join result is stored in the first parent output
join_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(join_output)
return join_output
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
def __init__(self, reduce_function=None, **kwargs):
"""Initializes a ReduceStreamOperator with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
Raises:
ValueError: If the combine_function is not callable.
"""
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
self.reduce_function = reduce_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if not task_input.check_stream():
raise ValueError("ReduceStreamOperator expects stream data")
if not task_input.check_single_parent():
raise ValueError("ReduceStreamOperator expects single parent")
reduce_function = self.reduce_function or self.reduce
input_ctx: InputContext = await task_input.reduce(reduce_function)
# The reduce result is stored in the first parent output
reduce_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(reduce_output)
return reduce_output
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
raise NotImplementedError
class MapOperator(BaseOperator, Generic[IN, OUT]):
"""Map operator that applies a mapping function to its inputs.
This operator transforms its input data using a provided mapping function and
passes the transformed data downstream.
"""
def __init__(self, map_function=None, **kwargs):
"""Initializes a MapDAGNode with a mapping function.
Args:
map_function: A function that defines how to map the input data.
Raises:
ValueError: If the map_function is not callable.
"""
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")
self.map_function = map_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the mapping operation on the DAG context's inputs.
This method applies the mapping function to the input context and updates
the DAG context with the new data.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
Raises:
ValueError: If not a single parent or the map_function is not callable
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
if not curr_task_ctx.task_input.check_single_parent():
num_parents = len(curr_task_ctx.task_input.parent_outputs)
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
)
map_function = self.map_function or self.map
input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
# The map result is stored in the first parent output
map_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(map_output)
return map_output
async def map(self, input_value: IN) -> OUT:
raise NotImplementedError
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
class BranchOperator(BaseOperator, Generic[IN, OUT]):
"""Operator node that branches the workflow based on a provided function.
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
"""
def __init__(
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
):
"""
Initializes a BranchOperator with branch conditions.
Args:
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): A dict mapping branch condition functions to downstream operators or their node names.
Raises:
ValueError: If the branch_function is not callable.
"""
super().__init__(**kwargs)
if branches:
for branch_function, value in branches.items():
if not callable(branch_function):
raise ValueError("branch_function must be callable")
if isinstance(value, BaseOperator):
branches[branch_function] = value.node_name or value.node_id
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the branching operation on the DAG context's inputs.
This method applies the branching function to the input context to determine
the path of execution in the workflow.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if task_input.check_stream():
raise ValueError("BranchDAGNode expects no stream data")
if not task_input.check_single_parent():
raise ValueError("BranchDAGNode expects single parent")
branches = self._branches
if not branches:
branches = await self.branchs()
branch_func_tasks = []
branch_nodes: List[str] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
skip_node_names = []
for i, ctx in enumerate(branch_input_ctxs):
node_name = branch_nodes[i]
branch_out = ctx.parent_outputs[0].task_output
logger.info(
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
)
if ctx.parent_outputs[0].task_output.is_empty:
logger.info(f"Skip node name {node_name}")
skip_node_names.append(node_name)
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
raise NotImplementedError
class InputOperator(BaseOperator, Generic[OUT]):
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
super().__init__(**kwargs)
self._input_source = input_source
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
def __init__(self, **kwargs) -> None:
from ..task.task_impl import SimpleCallDataInputSource
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
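# Illustrative sketch: JoinOperator combines the outputs of several upstream
# nodes with a user supplied function. It assumes a SystemApp has already been
# set on DAGVar; the greeting strings are placeholders.
async def _example_join() -> str:
    from pilot.awel import DAG, DefaultWorkflowRunner, SimpleInputSource

    with DAG("join_example"):
        hello = InputOperator(SimpleInputSource("Hello"))
        world = InputOperator(SimpleInputSource("world"))
        joined = JoinOperator(combine_function=lambda a, b: f"{a}, {b}!")
        hello >> joined
        world >> joined
    ctx = await DefaultWorkflowRunner().execute_workflow(joined)
    return ctx.current_task_context.task_output.output  # "Hello, world!"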

View File

@@ -1,90 +0,0 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
self.streamify
)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
"""Convert a value of IN to an AsyncIterator[OUT]
Args:
input_value (IN): The data of the parent operator's output
Example:
.. code-block:: python
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
"""
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.unstreamify(self.unstreamify)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
"""Convert a value of AsyncIterator[IN] to an OUT.
Args:
input_value (AsyncIterator[IN]): The data of the parent operator's output
Example:
.. code-block:: python
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
"""
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.transform_stream(self.transform_stream)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def transform_stream(
self, input_value: AsyncIterator[IN]
) -> AsyncIterator[OUT]:
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
Args:
input_value (AsyncIterator[IN]): The data of the parent operator's output
Example:
.. code-block:: python
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def transform_stream(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""

View File

@@ -1,8 +0,0 @@
from abc import ABC, abstractmethod
class ResourceGroup(ABC):
@property
@abstractmethod
def name(self) -> str:
"""The name of current resource group"""

View File

@@ -1,82 +0,0 @@
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager:
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
) -> None:
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
self._id2node_data = id2call_data
@staticmethod
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
return JobManager(root_nodes, nodes, end_node, id2call_data)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
return self._id2node_data.get(node_id)
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
if len(root_nodes) == 1:
node = root_nodes[0]
logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
id2call_data[node.node_id] = call_data
else:
for node in root_nodes:
node_id = node.node_id
logger.info(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
nodes.append(end_node)
for node in end_node.upstream:
nodes += _build_from_end_node(node)
return nodes
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
return list(set(filter(lambda x: not x.upstream, nodes)))
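# Illustrative sketch: the runner builds a JobManager from the end operator of a
# workflow; for a single-root DAG the whole call_data dict is attached to that
# root. Here `single_op` stands for an operator with no upstream nodes, so it is
# both the root and the end node.
def _example_job_manager(single_op: BaseOperator) -> Optional[Dict]:
    job_manager = JobManager.build_from_end_node(single_op, call_data={"data": "Hello"})
    # With a single root node, the call data is stored under its node_id.
    return job_manager.get_call_data_by_id(single_op.node_id)  # {"data": "Hello"}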

View File

@@ -1,106 +0,0 @@
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
async def execute_workflow(
self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> DAGContext:
# Create DAG context
dag_ctx = DAGContext()
job_manager = JobManager.build_from_end_node(node, call_data)
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
skip_node_ids = set()
await self._execute_node(
job_manager, node, dag_ctx, node_outputs, skip_node_ids
)
return dag_ctx
async def _execute_node(
self,
job_manager: JobManager,
node: BaseOperator,
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
):
# Skip nodes that have already been executed
if node.node_id in node_outputs:
return
# Run all upstream nodes first
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
)
inputs = [
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
task_ctx.set_current_state(TaskState.RUNNING)
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
node_outputs[node.node_id] = task_ctx
return
try:
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)
if isinstance(node, BranchOperator):
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
logger.debug(
f"Current is branch operator, skip node names: {skip_nodes}"
)
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
except Exception as e:
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
task_ctx.set_current_state(TaskState.FAILED)
raise e
def _skip_current_downstream_by_node_name(
branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
if not skip_nodes:
return
for child in branch_node.downstream:
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
if isinstance(node, JoinOperator):
# Not skip join node
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
_skip_downstream_by_id(child, skip_node_ids)
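# Illustrative sketch: after execute_workflow, the returned DAGContext exposes
# the end operator's final TaskState and output. It assumes a SystemApp has
# already been set on DAGVar; the input list and the `sum` mapping are placeholders.
async def _example_run_and_inspect() -> None:
    from pilot.awel import DAG, InputOperator, MapOperator, SimpleInputSource

    with DAG("runner_example"):
        source = InputOperator(SimpleInputSource([1, 2, 3]))
        total = MapOperator(sum)
        source >> total
    ctx = await DefaultWorkflowRunner().execute_workflow(total)
    task_ctx = ctx.current_task_context
    assert task_ctx.current_state == TaskState.SUCCESS
    assert task_ctx.task_output.output == 6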

View File

@@ -1,367 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
Dict,
List,
)
IN = TypeVar("IN")
OUT = TypeVar("OUT")
T = TypeVar("T")
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
This Enum defines various states a task can be in during its lifecycle in the DAG.
"""
INIT = "init" # Initial state of the task, not yet started
SKIP = "skip" # State indicating the task was skipped
RUNNING = "running" # State indicating the task is currently running
SUCCESS = "success" # State indicating the task completed successfully
FAILED = "failed" # State indicating the task failed during execution
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
"""
@property
def is_stream(self) -> bool:
"""Check if the output is a stream.
Returns:
bool: True if the output is a stream, False otherwise.
"""
return False
@property
def is_empty(self) -> bool:
"""Check if the output is empty.
Returns:
bool: True if the output is empty, False otherwise.
"""
return False
@property
def output(self) -> Optional[T]:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
"""
raise NotImplementedError
@abstractmethod
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
"""Set the output data to current object.
Args:
output_data (Union[T, AsyncIterator[T]]): Output data.
"""
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
async def map(self, map_func) -> "TaskOutput[T]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
"""Apply a reducing function to the task's output.
Stream TaskOutput to Nonstream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if the current output meets the condition, False otherwise.
"""
raise NotImplementedError
class TaskContext(ABC, Generic[T]):
"""Abstract base class representing the context of a task within a DAG.
This class provides the interface for accessing task-related information
and manipulating task output.
"""
@property
@abstractmethod
def task_id(self) -> str:
"""Return the unique identifier of the task.
Returns:
str: The unique identifier of the task.
"""
@property
@abstractmethod
def task_input(self) -> "InputContext":
"""Return the InputContext of current task.
Returns:
InputContext: The InputContext of current task.
"""
@abstractmethod
def set_task_input(self, input_ctx: "InputContext") -> None:
"""Set the InputContext object to current task.
Args:
input_ctx (InputContext): The InputContext of current task
"""
@property
@abstractmethod
def task_output(self) -> TaskOutput[T]:
"""Return the output object of the task.
Returns:
TaskOutput[T]: The output object of the task.
"""
@abstractmethod
def set_task_output(self, task_output: TaskOutput[T]) -> None:
"""Set the output object to current task."""
@property
@abstractmethod
def current_state(self) -> TaskState:
"""Get the current state of the task.
Returns:
TaskState: The current state of the task.
"""
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
Args:
task_state (TaskState): The task state to be set.
"""
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
Returns:
TaskContext: A new instance of a TaskContext.
"""
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
Args:
key (str): The key of metadata
value (Any): The value to be added to the metadata
"""
self.metadata[key] = value
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
return self.metadata.get("call_data")
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
self.update_metadata("call_data", call_data)
class InputContext(ABC):
"""Abstract base class representing the context of inputs to a operator node.
This class defines methods to manipulate and access the inputs for a operator node.
"""
@property
@abstractmethod
def parent_outputs(self) -> List[TaskContext]:
"""Get the outputs from the parent nodes.
Returns:
List[TaskContext]: A list of contexts of the parent nodes' outputs.
"""
@abstractmethod
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a mapping function to the inputs.
Args:
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""Apply a mapping function to all inputs.
Args:
map_func (Callable[..., Any]): A function to be applied to all inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
Returns:
InputContext: A new InputContext instance with the reduced inputs.
"""
@abstractmethod
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
"""
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A predicate function applied to each input.
failed_value (Any): The value to set for an input when the predicate returns False.
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""
def check_single_parent(self) -> bool:
"""Check if there is only a single parent output.
Returns:
bool: True if there is only one parent output, False otherwise.
"""
return len(self.parent_outputs) == 1
def check_stream(self, skip_empty: bool = False) -> bool:
"""Check if all parent outputs are streams.
Args:
skip_empty (bool): Skip empty output or not.
Returns:
bool: True if all parent outputs are streams, False otherwise.
"""
for out in self.parent_outputs:
if out.task_output.is_empty and skip_empty:
continue
if not (out.task_output and out.task_output.is_stream):
return False
return True
class InputSource(ABC, Generic[T]):
"""Abstract base class representing the source of inputs to a DAG node."""
@abstractmethod
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
"""Read the data from current input source.
Returns:
TaskOutput[T]: The output object read from current source
"""

View File

@@ -1,339 +0,0 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
logger = logging.getLogger(__name__)
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
# Init accumulator
try:
accumulator = await stream.__anext__()
except StopAsyncIteration:
raise ValueError("Stream is empty")
is_async = asyncio.iscoroutinefunction(reduce_function)
async for element in stream:
if is_async:
accumulator = await reduce_function(accumulator, element)
else:
accumulator = reduce_function(accumulator, element)
return accumulator
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
@property
def is_empty(self) -> bool:
return not self._data
async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
return True
@property
def is_empty(self) -> bool:
return not self._data
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
async def map(self, map_func) -> TaskOutput[T]:
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
if is_async:
out = await map_func(out)
else:
out = map_func(out)
yield out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleStreamTaskOutput(out)
def _is_async_iterator(obj):
return (
hasattr(obj, "__anext__")
and callable(getattr(obj, "__anext__", None))
and hasattr(obj, "__aiter__")
and callable(getattr(obj, "__aiter__", None))
)
class BaseInputSource(InputSource, ABC):
def __init__(self) -> None:
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
return output
class SimpleInputSource(BaseInputSource):
def __init__(self, data: Any) -> None:
super().__init__()
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:
return self._data
class SimpleCallDataInputSource(BaseInputSource):
def __init__(self) -> None:
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
class DefaultTaskContext(TaskContext, Generic[T]):
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
) -> None:
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
@property
def task_id(self) -> str:
return self._task_id
@property
def task_input(self) -> InputContext:
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
self._output = task_output
@property
def current_state(self) -> TaskState:
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
self._task_state = task_state
def new_ctx(self) -> TaskContext:
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
return self._metadata
class DefaultInputContext(InputContext):
def __init__(self, outputs: List[TaskContext]) -> None:
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
raise ValueError(f"Unsupport apply type {apply_type}")
map_tasks.append(result)
results = await asyncio.gather(*map_tasks)
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
if not self._outputs:
return DefaultInputContext([])
# Some parents may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
continue
not_empty_idx = i
break
# If every output is empty, not_empty_idx stays at 0
is_stream = self._outputs[not_empty_idx].task_output.is_stream
if is_stream:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
outputs.append(out.task_output.output_stream)
else:
outputs.append(out.task_output.output)
if asyncio.iscoroutinefunction(map_func):
map_res = await map_func(*outputs)
else:
map_res = map_func(*outputs)
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
if results[i]:
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:
task_ctx.task_output.set_output(failed_value)
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
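# Illustrative sketch: the Simple* task primitives can be exercised directly,
# outside a DAG run. The number stream and the reduce/map functions are placeholders.
async def _example_task_primitives() -> int:
    async def numbers() -> AsyncIterator[int]:
        for i in range(4):
            yield i

    stream_out = SimpleStreamTaskOutput(numbers())
    summed = await stream_out.reduce(lambda acc, x: acc + x)  # SimpleTaskOutput(6)
    doubled = await summed.map(lambda x: x * 2)  # SimpleTaskOutput(12)
    return doubled.output  # 12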

View File

@@ -1,102 +0,0 @@
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
WorkflowRunner,
InputOperator,
DAGContext,
TaskState,
DefaultWorkflowRunner,
SimpleInputSource,
)
from ..task.task_impl import _is_async_iterator
@pytest.fixture
def runner():
return DefaultWorkflowRunner()
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
iters = []
for _ in range(num_nodes):
async def stream_iter():
for i in range(10):
yield i
stream_iter = stream_iter()
assert _is_async_iterator(stream_iter)
iters.append(stream_iter)
return iters
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
iters = []
for single_stream in output_streams:
async def stream_iter(single_stream=single_stream):  # bind this iteration's stream
for i in single_stream:
yield i
stream_iter = stream_iter()
assert _is_async_iterator(stream_iter)
iters.append(stream_iter)
return iters
@asynccontextmanager
async def _create_input_node(**kwargs):
num_nodes = kwargs.get("num_nodes")
is_stream = kwargs.get("is_stream", False)
if is_stream:
outputs = kwargs.get("output_streams")
if outputs:
if num_nodes and num_nodes != len(outputs):
raise ValueError(
f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
)
outputs = _create_stream_from(outputs)
else:
num_nodes = num_nodes or 1
outputs = _create_stream(num_nodes)
else:
outputs = kwargs.get("outputs", ["Hello."])
nodes = []
for output in outputs:
print(f"output: {output}")
input_source = SimpleInputSource(output)
input_node = InputOperator(input_source)
nodes.append(input_node)
yield nodes
@pytest_asyncio.fixture
async def input_node(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def stream_input_node(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def input_nodes(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes
@pytest_asyncio.fixture
async def stream_input_nodes(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes

View File

@@ -1,51 +0,0 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
def _register_dag_to_fastapi_app(dag):
# TODO
pass
@pytest.mark.asyncio
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
pass
# http_req_task = HttpRequestOperator(endpoint="/api/completions")
# db_task = DBQueryOperator(table_name="user_info")
# prompt_task = PromptTemplateOperator(
# system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
# )
# llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
# output_parser_task = CommonOutputParserOperator()
# http_res_task = HttpResponseOperator()
# (
# http_req_task
# >> db_task
# >> prompt_task
# >> llm_task
# >> output_parser_task
# >> http_res_task
# )
_register_dag_to_fastapi_app(dag)

View File

@@ -1,141 +0,0 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
input_node = InputOperator(SimpleInputSource("hello"))
res: DAGContext[str] = await runner.execute_workflow(input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert res.current_task_context.task_output.output == "hello"
    async def new_stream_iter(n: int):
        for i in range(n):
            yield i
    num_iter = 10
    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
    res: DAGContext[str] = await runner.execute_workflow(stream_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    assert _is_async_iterator(output_stream)
    i = 0
    async for x in output_stream:
        assert x == i
        i += 1
@pytest.mark.asyncio
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
map_node = MapOperator(lambda x: x * 2)
stream_input_node >> map_node
res: DAGContext[int] = await runner.execute_workflow(map_node)
        output_stream = res.current_task_context.task_output.output_stream
        assert output_stream
        i = 0
        async for x in output_stream:
            assert x == i * 2
            i += 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"stream_input_node, expect_sum",
[
({"output_streams": [[0, 1, 2, 3]]}, 6),
({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
],
indirect=["stream_input_node"],
)
async def test_reduce_node(
runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
):
with DAG("test_reduce_node") as dag:
reduce_node = ReduceStreamOperator(lambda x, y: x + y)
stream_input_node >> reduce_node
res: DAGContext[int] = await runner.execute_workflow(reduce_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == expect_sum
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_nodes",
[
({"outputs": [0, 1, 2]}),
],
indirect=["input_nodes"],
)
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
def join_func(p1, p2, p3) -> int:
return p1 + p2 + p3
with DAG("test_join_node") as dag:
join_node = JoinOperator(join_func)
for input_node in input_nodes:
input_node >> join_node
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_node, is_odd",
[
({"outputs": [0]}, False),
({"outputs": [1]}, True),
],
indirect=["input_node"],
)
async def test_branch_node(
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
def join_func(o1, o2) -> int:
print(f"join func result, o1: {o1}, o2: {o2}")
return o1 or o2
with DAG("test_join_node") as dag:
odd_node = MapOperator(
lambda x: 999, task_id="odd_node", task_name="odd_node_name"
)
even_node = MapOperator(
lambda x: 888, task_id="even_node", task_name="even_node_name"
)
join_node = JoinOperator(join_func)
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)
branch_node >> odd_node >> join_node
branch_node >> even_node >> join_node
input_node >> branch_node
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
expect_res = 999 if is_odd else 888
assert res.current_task_context.task_output.output == expect_res

View File

@@ -1,11 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from ..operator.common_operator import TriggerOperator
class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -1,137 +0,0 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from pydantic import BaseModel
import logging
from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
status_code: Optional[int] = 200,
router_tags: Optional[List[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
self._endpoint = endpoint
self._methods = methods
self._req_body = request_body
self._streaming_response = streaming_response
self._response_model = response_model
self._status_code = status_code
self._router_tags = router_tags
self._response_headers = response_headers
self._response_media_type = response_media_type
self._end_node: BaseOperator = None
async def trigger(self) -> None:
pass
def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
methods = self._methods if isinstance(self._methods, list) else [self._methods]
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)
async def route_function(body=Depends(_request_body_dependency)):
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
self._response_headers,
self._response_media_type,
)
route_function.__name__ = name
return route_function
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
self._req_body
if isinstance(self._req_body, type)
and issubclass(self._req_body, BaseModel)
else None
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
router.api_route(
self._endpoint,
methods=methods,
response_model=self._response_model,
status_code=self._status_code,
tags=self._router_tags,
)(dynamic_route_function)
async def _parse_request_body(
request: Request, request_body_cls: Optional[Type[BaseModel]]
):
if not request_body_cls:
return None
if request.method == "POST":
json_data = await request.json()
return request_body_cls(**json_data)
elif request.method == "GET":
return request_body_cls(**request.query_params)
else:
return request
async def _trigger_dag(
body: Any,
dag: DAG,
streaming_response: Optional[bool] = False,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi.responses import StreamingResponse
end_node = dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = response_headers
media_type = response_media_type if response_media_type else "text/event-stream"
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
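
A minimal mounting sketch, assuming the package is importable as pilot.awel, that the EchoBody model and /echo endpoint are purely illustrative, and that the exact shape of the body delivered to the downstream map function depends on the operator wiring (normally registration goes through the trigger manager rather than a direct mount_to_router call):

from fastapi import APIRouter, FastAPI
from pydantic import BaseModel

from pilot.awel import DAG, MapOperator  # assumed import path
from pilot.awel.trigger.http_trigger import HttpTrigger  # assumed import path


class EchoBody(BaseModel):  # hypothetical request body model
    message: str


with DAG("http_trigger_example") as dag:
    trigger = HttpTrigger("/echo", methods="POST", request_body=EchoBody)
    # _trigger_dag() invokes the single leaf node with call_data={"data": body}.
    echo_task = MapOperator(lambda body: {"echo": str(body)})
    trigger >> echo_task

router = APIRouter()
trigger.mount_to_router(router)  # registers POST /echo on the router

app = FastAPI()
app.include_router(router, prefix="/api/v1/awel/trigger")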

View File

@@ -1,74 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
if TYPE_CHECKING:
from fastapi import APIRouter
from pilot.component import SystemApp, BaseComponent, ComponentType
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
@abstractmethod
def register_trigger(self, trigger: Any) -> None:
""" "Register a trigger to current manager"""
class HttpTriggerManager(TriggerManager):
def __init__(
self,
router: Optional["APIRouter"] = None,
router_prefix: Optional[str] = "/api/v1/awel/trigger",
) -> None:
if not router:
from fastapi import APIRouter
router = APIRouter()
self._router_prefix = router_prefix
self._router = router
self._trigger_map = {}
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if not isinstance(trigger, HttpTrigger):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger: HttpTrigger = trigger
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
trigger.mount_to_router(self._router)
self._trigger_map[trigger_id] = trigger
def _init_app(self, system_app: SystemApp):
logger.info(
f"Include router {self._router} to prefix path {self._router_prefix}"
)
system_app.app.include_router(
self._router, prefix=self._router_prefix, tags=["AWEL"]
)
class DefaultTriggerManager(TriggerManager, BaseComponent):
name = ComponentType.AWEL_TRIGGER_MANAGER
    def __init__(self, system_app: Optional[SystemApp] = None):
self.system_app = system_app
self.http_trigger = HttpTriggerManager()
super().__init__(None)
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if isinstance(trigger, HttpTrigger):
logger.info(f"Register trigger {trigger}")
self.http_trigger.register_trigger(trigger)
else:
raise ValueError(f"Unsupport trigger: {trigger}")
def after_register(self) -> None:
self.http_trigger._init_app(self.system_app)

View File

@@ -1,11 +0,0 @@
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
from .commands.command import execute_command, get_command
from .commands.generator import PluginPromptGenerator
from .commands.disply_type.show_chart_gen import static_message_img_path
from .common.schema import Status, PluginStorageType
from .commands.command_mange import ApiCall

View File

@@ -1,61 +0,0 @@
"""Commands for converting audio to text."""
import json
import requests
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
CFG = Config()
@command(
"read_audio_from_file",
"Convert Audio to text",
'"filename": "<filename>"',
CFG.huggingface_audio_to_text_model,
"Configure huggingface_audio_to_text_model.",
)
def read_audio_from_file(filename: str) -> str:
"""
Convert audio to text.
Args:
filename (str): The path to the audio file
Returns:
str: The text from the audio
"""
with open(filename, "rb") as audio_file:
audio = audio_file.read()
return read_audio(audio)
def read_audio(audio: bytes) -> str:
"""
Convert audio to text.
Args:
audio (bytes): The audio to convert
Returns:
str: The text from the audio
"""
model = CFG.huggingface_audio_to_text_model
api_url = f"https://api-inference.huggingface.co/models/{model}"
    api_token = CFG.huggingface_api_token
    if api_token is None:
        raise ValueError(
            "You need to set your Hugging Face API token in the config file."
        )
    headers = {"Authorization": f"Bearer {api_token}"}
response = requests.post(
api_url,
headers=headers,
data=audio,
)
text = json.loads(response.content.decode("utf-8"))["text"]
return f"The audio says: {text}"

View File

@@ -1,124 +0,0 @@
""" Image Generation Module for AutoGPT."""
import io
import uuid
from base64 import b64decode
import logging
import requests
from PIL import Image
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
logger = logging.getLogger(__name__)
CFG = Config()
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
def generate_image(prompt: str, size: int = 256) -> str:
"""Generate an image from a prompt.
Args:
prompt (str): The prompt to use
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
Returns:
str: The filename of the image
"""
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
# HuggingFace
if CFG.image_provider == "huggingface":
return generate_image_with_hf(prompt, filename)
# SD WebUI
elif CFG.image_provider == "sdwebui":
return generate_image_with_sd_webui(prompt, filename, size)
return "No Image Provider Set"
def generate_image_with_hf(prompt: str, filename: str) -> str:
"""Generate an image with HuggingFace's API.
Args:
prompt (str): The prompt to use
filename (str): The filename to save the image to
Returns:
str: The filename of the image
"""
API_URL = (
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
)
if CFG.huggingface_api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
)
headers = {
"Authorization": f"Bearer {CFG.huggingface_api_token}",
"X-Use-Cache": "false",
}
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(filename)
return f"Saved to disk:{filename}"
def generate_image_with_sd_webui(
prompt: str,
filename: str,
size: int = 512,
negative_prompt: str = "",
extra: dict = {},
) -> str:
"""Generate an image with Stable Diffusion webui.
Args:
prompt (str): The prompt to use
filename (str): The filename to save the image to
        size (int, optional): The size of the image. Defaults to 512.
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
Returns:
str: The filename of the image
"""
# Create a session and set the basic auth if needed
s = requests.Session()
if CFG.sd_webui_auth:
username, password = CFG.sd_webui_auth.split(":")
s.auth = (username, password or "")
# Generate the images
response = requests.post(
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
json={
"prompt": prompt,
"negative_prompt": negative_prompt,
"sampler_index": "DDIM",
"steps": 20,
"cfg_scale": 7.0,
"width": size,
"height": size,
"n_iter": 1,
**extra,
},
)
logger.info(f"Image Generated for prompt:{prompt}")
# Save the image to disk
response = response.json()
b64 = b64decode(response["images"][0].split(",", 1)[0])
image = Image.open(io.BytesIO(b64))
image.save(filename)
return f"Saved to disk:{filename}"

View File

@@ -1,153 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from typing import Dict
from .exception_not_commands import NotCommands
from .generator import PluginPromptGenerator
from pilot.configs.config import Config
def _resolve_pathlike_command_args(command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:
# todo
command_args["directory"] = ""
else:
for pathlike in ["filename", "directory", "clone_path"]:
if pathlike in command_args:
# todo
command_args[pathlike] = ""
return command_args
def execute_ai_response_json(
prompt: PluginPromptGenerator,
ai_response,
user_input: str = None,
) -> str:
"""
Args:
command_registry:
ai_response:
prompt:
Returns:
"""
from pilot.speech.say import say_text
cfg = Config()
command_name, arguments = get_command(ai_response)
if cfg.speak_mode:
say_text(f"I want to execute {command_name}")
arguments = _resolve_pathlike_command_args(arguments)
# Execute command
if command_name is not None and command_name.lower().startswith("error"):
result = f"Command {command_name} threw the following error: {arguments}"
elif command_name == "human_feedback":
result = f"Human feedback: {user_input}"
else:
for plugin in cfg.plugins:
if not plugin.can_handle_pre_command():
continue
command_name, arguments = plugin.pre_command(command_name, arguments)
command_result = execute_command(
command_name,
arguments,
prompt,
)
result = f"{command_result}"
return result
def execute_command(
command_name: str,
arguments,
plugin_generator: PluginPromptGenerator,
):
"""Execute the command and return the result
Args:
command_name (str): The name of the command to execute
arguments (dict): The arguments for the command
Returns:
str: The result of the command
"""
cmd = plugin_generator.command_registry.commands.get(command_name)
# If the command is found, call it with the provided arguments
if cmd:
try:
return cmd(**arguments)
except Exception as e:
raise ValueError(f"Error: {str(e)}")
# return f"Error: {str(e)}"
# TODO: Change these to take in a file rather than pasted code, if
# non-file is given, return instructions "Input should be a python
# filepath, write your code to file and try again
else:
for command in plugin_generator.commands:
if (
command_name == command["label"].lower()
or command_name == command["name"].lower()
):
try:
                    # Remove arguments that are not defined for this command
diff_ags = list(
set(arguments.keys()).difference(set(command["args"].keys()))
)
for arg_name in diff_ags:
del arguments[arg_name]
print(str(arguments))
return command["function"](**arguments)
except Exception as e:
return f"Error: {str(e)}"
        raise NotCommands("Unknown command: " + command_name)
def get_command(response_json: Dict):
"""Parse the response and return the command name and arguments
Args:
response_json (json): The response from the AI
Returns:
tuple: The command name and arguments
Raises:
json.decoder.JSONDecodeError: If the response is not valid JSON
Exception: If any other error occurs
"""
try:
if "command" not in response_json:
return "Error:", "Missing 'command' object in JSON"
if not isinstance(response_json, dict):
return "Error:", f"'response_json' object is not dictionary {response_json}"
command = response_json["command"]
if not isinstance(command, dict):
return "Error:", "'command' object is not a dictionary"
if "name" not in command:
return "Error:", "Missing 'name' field in 'command' object"
command_name = command["name"]
# Use an empty dictionary if 'args' field is not present in 'command' object
arguments = command.get("args", {})
return command_name, arguments
except json.decoder.JSONDecodeError:
return "Error:", "Invalid JSON"
# All other errors, return "Error: + error message"
except Exception as e:
return "Error:", str(e)

View File

@@ -1,493 +0,0 @@
import functools
import importlib
import inspect
import time
import json
import logging
import xml.etree.ElementTree as ET
import pandas as pd
from pilot.common.json_utils import serialize
from datetime import datetime
from typing import Any, Callable, Optional, List
from pydantic import BaseModel
from pilot.base_modules.agent.common.schema import Status, ApiTagType
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
from pilot.common.string_utils import extract_content_open_ending, extract_content
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
class Command:
"""A class representing a command.
Attributes:
name (str): The name of the command.
description (str): A brief description of what the command does.
signature (str): The signature of the function that the command executes. Defaults to None.
"""
def __init__(
self,
name: str,
description: str,
method: Callable[..., Any],
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
):
self.name = name
self.description = description
self.method = method
self.signature = signature if signature else str(inspect.signature(self.method))
self.enabled = enabled
self.disabled_reason = disabled_reason
def __call__(self, *args, **kwargs) -> Any:
if not self.enabled:
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
return self.method(*args, **kwargs)
def __str__(self) -> str:
return f"{self.name}: {self.description}, args: {self.signature}"
class CommandRegistry:
"""
The CommandRegistry class is a manager for a collection of Command objects.
It allows the registration, modification, and retrieval of Command objects,
as well as the scanning and loading of command plugins from a specified
directory.
"""
def __init__(self):
self.commands = {}
def _import_module(self, module_name: str) -> Any:
return importlib.import_module(module_name)
def _reload_module(self, module: Any) -> Any:
return importlib.reload(module)
def register(self, cmd: Command) -> None:
self.commands[cmd.name] = cmd
def unregister(self, command_name: str):
if command_name in self.commands:
del self.commands[command_name]
else:
raise KeyError(f"Command '{command_name}' not found in registry.")
def reload_commands(self) -> None:
"""Reloads all loaded command plugins."""
for cmd_name in self.commands:
cmd = self.commands[cmd_name]
module = self._import_module(cmd.__module__)
reloaded_module = self._reload_module(module)
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
def is_valid_command(self, name: str) -> bool:
if name not in self.commands:
return False
else:
return True
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
def call(self, command_name: str, **kwargs) -> Any:
if command_name not in self.commands:
raise KeyError(f"Command '{command_name}' not found in registry.")
command = self.commands[command_name]
return command(**kwargs)
def command_prompt(self) -> str:
"""
Returns a string representation of all registered `Command` objects for use in a prompt
"""
commands_list = [
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
]
return "\n".join(commands_list)
def import_commands(self, module_name: str) -> None:
"""
Imports the specified Python module containing command plugins.
This method imports the associated module and registers any functions or
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
as `Command` objects. The registered `Command` objects are then added to the
`commands` dictionary of the `CommandRegistry` object.
Args:
module_name (str): The name of the module to import for command plugins.
"""
module = importlib.import_module(module_name)
for attr_name in dir(module):
attr = getattr(module, attr_name)
# Register decorated functions
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
attr, AUTO_GPT_COMMAND_IDENTIFIER
):
self.register(attr.command)
# Register command classes
elif (
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
):
cmd_instance = attr()
self.register(cmd_instance)
def command(
name: str,
description: str,
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
"""The command decorator is used to create Command objects from ordinary functions."""
def decorator(func: Callable[..., Any]) -> Command:
cmd = Command(
name=name,
description=description,
method=func,
signature=signature,
enabled=enabled,
disabled_reason=disabled_reason,
)
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
return func(*args, **kwargs)
wrapper.command = cmd
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
return wrapper
return decorator
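# A minimal usage sketch of the decorator and registry above; the "shout"
# command and its argument signature are hypothetical:
@command("shout", "Upper-case the given text", '"text": "<text>"')
def shout(text: str) -> str:
    return text.upper()
registry = CommandRegistry()
registry.register(shout.command)  # the decorator attached a Command object to the wrapper
print(registry.call("shout", text="hi"))  # HI
# registry.import_commands("<module path>") would instead scan a module for
# functions carrying AUTO_GPT_COMMAND_IDENTIFIER and register them automatically.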
class PluginStatus(BaseModel):
name: str
location: List[int]
args: dict
status: Status = Status.TODO.value
logo_url: str = None
api_result: str = None
err_msg: str = None
start_time = datetime.now().timestamp() * 1000
end_time: int = None
df: Any = None
class ApiCall:
agent_prefix = "<api-call>"
agent_end = "</api-call>"
name_prefix = "<name>"
name_end = "</name>"
def __init__(
self,
plugin_generator: Any = None,
display_registry: Any = None,
backend_rendering: bool = False,
):
# self.name: str = ""
# self.status: Status = Status.TODO.value
# self.logo_url: str = None
# self.args = {}
# self.api_result: str = None
# self.err_msg: str = None
self.plugin_status_map = {}
self.plugin_generator = plugin_generator
self.display_registry = display_registry
self.start_time = datetime.now().timestamp() * 1000
        self.backend_rendering: bool = backend_rendering
    def __repr__(self):
        return f"ApiCall(plugin_status_map={self.plugin_status_map})"
def __is_need_wait_plugin_call(self, api_call_context):
start_agent_count = api_call_context.count(self.agent_prefix)
end_agent_count = api_call_context.count(self.agent_end)
if start_agent_count > 0:
return True
else:
            # Check whether the tail of the context is a partially generated <api-call> prefix
            check_len = len(self.agent_prefix)
            last_text = api_call_context[-check_len:]
            for i in range(1, check_len + 1):
                if last_text[-i:] == self.agent_prefix[:i]:
                    return True
        return False
def check_last_plugin_call_ready(self, all_context):
start_agent_count = all_context.count(self.agent_prefix)
end_agent_count = all_context.count(self.agent_end)
if start_agent_count > 0 and start_agent_count == end_agent_count:
return True
return False
def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
error_md_tags = [
"```",
"```python",
"```xml",
"```json",
"```markdown",
"```sql",
]
if include_end == False:
md_tag_end = ""
else:
md_tag_end = "```"
for tag in error_md_tags:
all_context = all_context.replace(
tag + api_context + md_tag_end, api_context
)
all_context = all_context.replace(
tag + "\n" + api_context + "\n" + md_tag_end, api_context
)
all_context = all_context.replace(
tag + " " + api_context + " " + md_tag_end, api_context
)
all_context = all_context.replace(tag + api_context, api_context)
return all_context
def api_view_context(self, all_context: str, display_mode: bool = False):
call_context_map = extract_content_open_ending(
all_context, self.agent_prefix, self.agent_end, True
)
for api_index, api_context in call_context_map.items():
api_status = self.plugin_status_map.get(api_context)
if api_status is not None:
if display_mode:
all_context = self.__deal_error_md_tags(all_context, api_context)
if Status.FAILED.value == api_status.status:
all_context = all_context.replace(
api_context,
f'\n<span style="color:red">Error:</span>{api_status.err_msg}\n'
+ self.to_view_antv_vis(api_status),
)
else:
all_context = all_context.replace(
api_context, self.to_view_antv_vis(api_status)
)
else:
all_context = self.__deal_error_md_tags(
all_context, api_context, False
)
all_context = all_context.replace(
api_context, self.to_view_text(api_status)
)
else:
# not ready api call view change
now_time = datetime.now().timestamp() * 1000
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(
api_context,
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
)
return all_context
def update_from_context(self, all_context):
api_context_map = extract_content(
all_context, self.agent_prefix, self.agent_end, True
)
for api_index, api_context in api_context_map.items():
api_context = api_context.replace("\\n", "").replace("\n", "")
api_call_element = ET.fromstring(api_context)
api_name = api_call_element.find("name").text
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
api_name = api_name.replace("[", "").replace("]", "")
api_args = {}
args_elements = api_call_element.find("args")
for child_element in args_elements.iter():
api_args[child_element.tag] = child_element.text
api_status = self.plugin_status_map.get(api_context)
if api_status is None:
api_status = PluginStatus(
name=api_name, location=[api_index], args=api_args
)
self.plugin_status_map[api_context] = api_status
else:
api_status.location.append(api_index)
def __to_view_param_str(self, api_status):
param = {}
if api_status.name:
param["name"] = api_status.name
param["status"] = api_status.status
if api_status.logo_url:
param["logo"] = api_status.logo_url
if api_status.err_msg:
param["err_msg"] = api_status.err_msg
if api_status.api_result:
param["result"] = api_status.api_result
return json.dumps(param, default=serialize, ensure_ascii=False)
def to_view_text(self, api_status: PluginStatus):
api_call_element = ET.Element("dbgpt-view")
api_call_element.text = self.__to_view_param_str(api_status)
result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8")
def to_view_antv_vis(self, api_status: PluginStatus):
if self.backend_rendering:
html_table = api_status.df.to_html(
index=False, escape=False, sparsify=False
)
table_str = "".join(html_table.split())
table_str = table_str.replace("\n", " ")
html = f""" \n<div><b>[SQL]{api_status.args["sql"]}</b></div><div class="w-full overflow-auto">{table_str}</div>\n """
return html
else:
api_call_element = ET.Element("chart-view")
api_call_element.attrib["content"] = self.__to_antv_vis_param(api_status)
api_call_element.text = "\n"
# api_call_element.set("content", self.__to_antv_vis_param(api_status))
# api_call_element.text = self.__to_antv_vis_param(api_status)
result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8")
# return f'<chart-view content="{self.__to_antv_vis_param(api_status)}">'
def __to_antv_vis_param(self, api_status: PluginStatus):
param = {}
if api_status.name:
param["type"] = api_status.name
if api_status.args:
param["sql"] = api_status.args["sql"]
# if api_status.err_msg:
# param["err_msg"] = api_status.err_msg
if api_status.api_result:
param["data"] = api_status.api_result
else:
param["data"] = []
return json.dumps(param, ensure_ascii=False)
def run(self, llm_text):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
value.status = Status.RUNNING.value
logging.info(f"插件执行:{value.name},{value.args}")
try:
value.api_result = execute_command(
value.name, value.args, self.plugin_generator
)
value.status = Status.COMPLETED.value
except Exception as e:
value.status = Status.FAILED.value
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text)
def run_display_sql(self, llm_text, sql_run_func):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
value.status = Status.RUNNING.value
logging.info(f"sql展示执行:{value.name},{value.args}")
try:
sql = value.args["sql"]
if sql:
param = {
"df": sql_run_func(sql),
}
value.df = param["df"]
if self.display_registry.is_valid_command(value.name):
value.api_result = self.display_registry.call(
value.name, **param
)
else:
value.api_result = self.display_registry.call(
"response_table", **param
)
value.status = Status.COMPLETED.value
except Exception as e:
value.status = Status.FAILED.value
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text, True)
def display_sql_llmvis(self, llm_text, sql_run_func):
"""
Render charts using the Antv standard protocol
Args:
llm_text: LLM response text
sql_run_func: sql run function
Returns:
ChartView protocol text
"""
try:
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
value.status = Status.RUNNING.value
logging.info(f"sql展示执行:{value.name},{value.args}")
try:
sql = value.args["sql"]
if sql is not None and len(sql) > 0:
data_df = sql_run_func(sql)
value.df = data_df
value.api_result = json.loads(
data_df.to_json(
orient="records",
date_format="iso",
date_unit="s",
)
)
value.status = Status.COMPLETED.value
else:
value.status = Status.FAILED.value
value.err_msg = "No executable sql"
except Exception as e:
value.status = Status.FAILED.value
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
except Exception as e:
logging.error("Api parsing exception", e)
raise ValueError("Api parsing exception," + str(e))
return self.api_view_context(llm_text, True)
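
ApiCall expects the model output to wrap tool invocations in <api-call><name>...</name><args>...</args></api-call> blocks. A minimal sketch of the lifecycle, assuming extract_content returns the full <api-call>...</api-call> block and using a hypothetical echo command:

from pilot.base_modules.agent.commands.generator import PluginPromptGenerator

plugin_generator = PluginPromptGenerator()
plugin_generator.command_registry = CommandRegistry()  # no decorated commands registered
plugin_generator.add_command(
    command_label="Echo",
    command_name="echo",
    args={"text": "<text>"},
    function=lambda text: f"echoed: {text}",  # hypothetical command implementation
)

llm_text = (
    "I will call a tool now. "
    "<api-call><name>echo</name><args><text>hello</text></args></api-call>"
)

api_call = ApiCall(plugin_generator=plugin_generator)
# run() only executes the call once the closing </api-call> tag has appeared,
# then replaces the block with a <dbgpt-view> element describing the result.
print(api_call.run(llm_text))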

View File

@@ -1,314 +0,0 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import pandas as pd
import uuid
import os
import matplotlib
import seaborn as sns
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.common.string_utils import is_scientific_notation
import logging
logger = logging.getLogger(__name__)
static_message_img_path = os.path.join(os.getcwd(), "message/img")
def data_pre_classification(df: DataFrame):
## Data pre-classification
columns = df.columns.tolist()
number_columns = []
non_numeric_colums = []
    # Collect the number of distinct values for each column
non_numeric_colums_value_map = {}
numeric_colums_value_map = {}
for column_name in columns:
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
number_columns.append(column_name)
unique_values = df[column_name].unique()
numeric_colums_value_map.update({column_name: len(unique_values)})
else:
non_numeric_colums.append(column_name)
unique_values = df[column_name].unique()
non_numeric_colums_value_map.update({column_name: len(unique_values)})
sorted_numeric_colums_value_map = dict(
sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
)
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
sorted_colums_value_map = dict(
sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
)
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# Analyze x-coordinate
if len(non_numeric_colums_sort_list) > 0:
x_cloumn = non_numeric_colums_sort_list[-1]
non_numeric_colums_sort_list.remove(x_cloumn)
else:
x_cloumn = number_columns[0]
numeric_colums_sort_list.remove(x_cloumn)
# Analyze y-coordinate
if len(numeric_colums_sort_list) > 0:
y_column = numeric_colums_sort_list[0]
numeric_colums_sort_list.remove(y_column)
else:
raise ValueError("Not enough numeric columns for chart")
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
def zh_font_set():
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
def format_axis(value, pos):
    # Check whether the value would be rendered in scientific notation
    if is_scientific_notation(value):
        # If so, format it as a plain decimal instead
        return "{:.2f}".format(value)
return value
@command(
"response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"df":"<data frame>"',
)
def response_line_chart(df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data")
try:
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
x, y, non_num_columns, num_colmns = data_pre_classification(df)
        # Multi-series line chart when there are extra numeric columns
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
df,
id_vars=x,
value_vars=num_colmns,
var_name="line",
value_name="Value",
)
sns.lineplot(
data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
)
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
# ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
except Exception as e:
logging.error("Draw Line Chart Faild!" + str(e), e)
raise ValueError("Draw Line Chart Faild!" + str(e))
@command(
"response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"df":"<data frame>"',
)
def response_bar_chart(df: DataFrame) -> str:
logger.info(f"response_bar_chart")
if df.size <= 0:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
hue = None
x, y, non_num_columns, num_colmns = data_pre_classification(df)
if len(non_num_columns) >= 1:
hue = non_num_columns[0]
if len(num_colmns) >= 1:
if hue:
if len(num_colmns) >= 2:
can_use_columns = num_colmns[:2]
else:
can_use_columns = num_colmns
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
for sub_y_column in can_use_columns:
sns.barplot(
data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
)
else:
if len(num_colmns) > 5:
can_use_columns = num_colmns[:5]
else:
can_use_columns = num_colmns
can_use_columns.append(y)
df_melted = pd.melt(
df,
id_vars=x,
value_vars=can_use_columns,
var_name="line",
value_name="Value",
)
sns.barplot(
data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
)
else:
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
    # Format the y-axis ticks as plain numbers
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
# ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command(
"response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"df":"<data frame>"',
)
def response_pie_chart(df: DataFrame) -> str:
logger.info(f"response_pie_chart")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0])
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img

View File

@@ -1,21 +0,0 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
logger = logging.getLogger(__name__)
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
'"df":"<data frame>"',
)
def response_table(df: DataFrame) -> str:
logger.info(f"response_table")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
table_str = table_str.replace("\n", " ")
html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
return html

View File

@@ -1,37 +0,0 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
logger = logging.getLogger(__name__)
@command(
"response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
'"df":"<data frame>"',
)
def response_data_text(df: DataFrame) -> str:
logger.info(f"response_data_text")
data = df.values
row_size = data.shape[0]
value_str = ""
text_info = ""
if row_size > 1:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
text_info = html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:
if value_str:
value_str = value_str + f", ** {value} **"
else:
value_str = f" ** {value} **"
text_info = f" {value_str}"
else:
text_info = f"##### _没有找到可用的数据_"
return text_info

View File

@@ -1,4 +0,0 @@
class NotCommands(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@@ -1,135 +0,0 @@
""" A module for generating custom prompt strings."""
import json
from typing import Any, Callable, Dict, List, Optional
class PluginPromptGenerator:
"""
A class for generating custom prompt strings based on constraints, commands,
resources, and performance evaluations.
"""
def __init__(self) -> None:
"""
Initialize the PromptGenerator object with empty lists of constraints,
commands, resources, and performance evaluations.
"""
self.constraints = []
self.commands = []
self.resources = []
self.performance_evaluation = []
self.goals = []
self.command_registry = None
self.response_format = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
def add_constraint(self, constraint: str) -> None:
"""
Add a constraint to the constraints list.
Args:
constraint (str): The constraint to be added.
"""
self.constraints.append(constraint)
def add_command(
self,
command_label: str,
command_name: str,
args=None,
function: Optional[Callable] = None,
) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
Args:
command_label (str): The label of the command.
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
"""
if args is None:
args = {}
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
command = {
"label": command_label,
"name": command_name,
"args": command_args,
"function": function,
}
self.commands.append(command)
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Generate a formatted string representation of a command.
Args:
command (dict): A dictionary containing command information.
Returns:
str: The formatted command string.
"""
args_string = ", ".join(
f'"{key}": "{value}"' for key, value in command["args"].items()
)
return f'"{command["name"]}": {command["label"]} , args: {args_string}'
def add_resource(self, resource: str) -> None:
"""
Add a resource to the resources list.
Args:
resource (str): The resource to be added.
"""
self.resources.append(resource)
def add_performance_evaluation(self, evaluation: str) -> None:
"""
Add a performance evaluation item to the performance_evaluation list.
Args:
evaluation (str): The evaluation item to be added.
"""
self.performance_evaluation.append(evaluation)
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = []
if self.command_registry:
command_strings += [
str(item)
for item in self.command_registry.commands.values()
if item.enabled
]
# terminate command is added manually
command_strings += [self._generate_command_string(item) for item in items]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_commands_string(self) -> str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"

View File

@@ -1,18 +0,0 @@
from enum import Enum
class PluginStorageType(Enum):
Git = "git"
Oss = "oss"
class Status(Enum):
TODO = "todo"
RUNNING = "running"
FAILED = "failed"
COMPLETED = "completed"
class ApiTagType(Enum):
API_VIEW = "dbgpt_view"
API_CALL = "dbgpt_call"

View File

@@ -1,163 +0,0 @@
import json
import time
import logging
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from abc import ABC, abstractmethod
from typing import List
from pilot.configs.model_config import LOGDIR
from pilot.openapi.api_view_model import (
Result,
)
from .model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
MyPluginFilter,
)
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .plugins_util import scan_plugins
from .commands.generator import PluginPromptGenerator
from pilot.configs.model_config import PLUGINS_DIR
from pilot.component import BaseComponent, ComponentType, SystemApp
router = APIRouter()
logger = logging.getLogger(__name__)
class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
branch = (
update_param.branch
if update_param.branch is not None and len(update_param.branch) > 0
else "main"
)
        authorization = (
            update_param.authorization
            if update_param.authorization is not None
            and len(update_param.authorization) > 0
            else None
        )
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
    filter_entity: PluginHubEntity = PluginHubEntity()
    if filter.filter:
        attrs = vars(filter.filter)  # attribute dict of the incoming filter object
        for attr, value in attrs.items():
            setattr(filter_entity, attr, value)  # copy each filter value onto the query entity
    datas, total_pages, total_count = agent_hub.hub_dao.list(
        filter_entity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
return Result.succ(agent_dicts)
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -1,156 +0,0 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func
from sqlalchemy import UniqueConstraint
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
class MyPluginEntity(Base):
__tablename__ = "my_plugin"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(
Integer, nullable=True, default=0, comment="plugin total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
UniqueConstraint("user_code", "name", name="uk_name")
class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
    def add(self, entity: MyPluginEntity):
        session = self.get_session()
        my_plugin = MyPluginEntity(
            tenant=entity.tenant,
            user_code=entity.user_code,
            user_name=entity.user_name,
            name=entity.name,
            type=entity.type,
            version=entity.version,
            use_count=entity.use_count or 0,
            succ_count=entity.succ_count or 0,
            sys_code=entity.sys_code,
gmt_created=datetime.now(),
)
session.add(my_plugin)
session.commit()
id = my_plugin.id
session.close()
return id
def update(self, entity: MyPluginEntity):
session = self.get_session()
updated = session.merge(entity)
session.commit()
return updated.id
def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
result = my_plugins.all()
session.close()
return result
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
result = my_plugins.first()
session.close()
return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
result = my_plugins.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.get_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
count = my_plugins.scalar()
session.close()
return count
def delete(self, plugin_id: int):
session = self.get_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
my_plugins.delete()
session.commit()
session.close()
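A minimal usage sketch of the DAO above; the absolute import path and all field values are illustrative assumptions, not part of this module.

# Hypothetical usage; the import path below is assumed.
from pilot.base_modules.agent.db.my_plugin_db import MyPluginDao, MyPluginEntity

dao = MyPluginDao()

# Record an installed plugin for one user (placeholder values).
new_id = dao.add(
    MyPluginEntity(
        tenant="",
        user_code="u_1001",
        user_name="alice",
        name="db-gpt-sql-plugin",
        file_name="db-gpt-sql-plugin.zip",
        type="Personal",
        version="0.1.0",
    )
)

# Page through that user's plugins; list() returns (rows, total_pages, total_count).
rows, total_pages, total_count = dao.list(
    MyPluginEntity(user_code="u_1001"), page=1, page_size=20
)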

View File

@@ -1,159 +0,0 @@
from datetime import datetime
import pytz
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, DDL
from sqlalchemy import UniqueConstraint
from pilot.base_modules.meta_data.meta_data import Base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
class PluginHubEntity(Base):
__tablename__ = "plugin_hub"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
email = Column(String(255), nullable=True, comment="plugin author email")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
storage_channel = Column(String(255), comment="plugin storage channel")
storage_url = Column(String(255), comment="plugin download url")
download_param = Column(String(255), comment="plugin download param")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin upload time"
)
installed = Column(Integer, default=0, comment="plugin already installed count")
UniqueConstraint("name", name="uk_name")
Index("idx_q_type", "type")
class PluginHubDao(BaseDao[PluginHubEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def add(self, entity: PluginHubEntity):
session = self.get_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=entity.name,
description=entity.description,
author=entity.author,
email=entity.email,
type=entity.type,
version=entity.version,
storage_channel=entity.storage_channel,
storage_url=entity.storage_url,
gmt_created=timezone.localize(datetime.now()),
)
session.add(plugin_hub)
session.commit()
id = plugin_hub.id
session.close()
return id
def update(self, entity: PluginHubEntity):
session = self.get_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result
def count(self, query: PluginHubEntity):
session = self.get_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
count = plugin_hubs.scalar()
session.close()
return count
def delete(self, plugin_id: int):
session = self.get_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
plugin_hubs.delete()
session.commit()
session.close()
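A brief usage sketch of the hub DAO above; the import path and names are assumptions for illustration.

# Hypothetical usage; the import path below is assumed.
from pilot.base_modules.agent.db.plugin_hub_db import PluginHubDao, PluginHubEntity

hub_dao = PluginHubDao()

# Look up a single hub entry by name, then page through entries of one type.
entry = hub_dao.get_by_name("db-gpt-sql-plugin")
rows, total_pages, total_count = hub_dao.list(
    PluginHubEntity(type="sql"), page=1, page_size=10
)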

View File

@@ -1,217 +0,0 @@
import json
import logging
import os
import glob
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from ..common.schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger(__name__)
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_plugin_dao = MyPluginDao()
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
branch_name = None
authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = user_name
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
else:
my_plugin_entity.user_code = Default_User
with self.hub_dao.get_session() as session:
try:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
session.close()
except Exception as e:
logger.error("install merge roll back!" + str(e))
session.rollback()
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else:
raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q = my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
session.commit()
except Exception:
session.rollback()
if plugin_entity is not None:
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = (
plugin_entity.storage_url.replace(".git", "")
.strip("/")
.split("/")[-1]
)
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
else:
files = glob.glob(
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
my_plugin_entity.type = hub_plugin.type
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path)
try:
for git_plugin in git_plugins:
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
if old_hub_info:
plugin_hub_info = old_hub_info
else:
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {}
if branch_name:
download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0:
download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# On Windows we cannot move the temp file while it is still open inside a `with` block
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <= 0:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user, my_plugin._name
)
if my_plugin_entity is None:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = my_plugin._name
my_plugin_entity.version = my_plugin._version
my_plugin_entity.type = "Personal"
my_plugin_entity.user_code = user
my_plugin_entity.user_name = user
my_plugin_entity.tenant = ""
my_plugin_entity.file_name = doc_file.filename
self.my_plugin_dao.update(my_plugin_entity)
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
if not user:
user = Default_User
return self.my_plugin_dao.get_by_user(user)
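A sketch of how a caller might drive the hub above; the import path, directory and plugin name are assumptions.

# Hypothetical usage; import path, directory and plugin name are assumed.
from pilot.base_modules.agent.hub.agent_hub import AgentHub

hub = AgentHub(plugin_dir="/data/dbgpt/plugins")

# Refresh the hub index from the default Git repository, then install one plugin for a user.
hub.refresh_hub_from_git()
hub.install_plugin("db-gpt-sql-plugin", user_name="alice")

# Inspect what the user has installed and reload the zip packages from disk.
installed = hub.get_my_plugin("alice")
loaded_plugins = hub.reload_my_plugins()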

View File

@@ -1,69 +0,0 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
T = TypeVar("T")
class PagenationFilter(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
data_dicts = []
for item in self.datas:
data_dicts.append(item.__dict__)
return {
"page_index": self.page_index,
"page_size": self.page_size,
"total_page": self.total_page,
"total_row_count": self.total_row_count,
"datas": data_dicts,
}
@dataclass
class PluginHubFilter(BaseModel):
name: str
description: str
author: str
email: str
type: str
version: str
storage_channel: str
storage_url: str
@dataclass
class MyPluginFilter(BaseModel):
tenant: str
user_code: str
user_name: str
name: str
file_name: str
type: str
version: str
class PluginHubParam(BaseModel):
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field(
"https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
description="Plugin storage url",
)
branch: Optional[str] = Field(
"main", description="github download branch", nullable=True
)
authorization: Optional[str] = Field(
None, description="github download authorization", nullable=True
)
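A small sketch of the request model above; when channel and url are omitted they fall back to their declared defaults (import path assumed).

# Hypothetical usage; the import path below is assumed.
from pilot.base_modules.agent.common.schema import PluginHubParam

param = PluginHubParam(branch="v0.4.0")
print(param.dict())
# {'channel': 'git',
#  'url': 'https://github.com/eosphoros-ai/DB-GPT-Plugins.git',
#  'branch': 'v0.4.0',
#  'authorization': None}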

View File

@@ -1,260 +0,0 @@
"""加载组件"""
import json
import os
import glob
import zipfile
import fnmatch
import requests
import git
import threading
import datetime
import logging
from pathlib import Path
from typing import List
from urllib.parse import urlparse
from zipimport import zipimporter
import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
logger = logging.getLogger(__name__)
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
Load modules from a zip plugin file. Natively supports Auto-GPT plugin packages.
Args:
zip_path (str): Path to the zipfile.
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
list[str]: The list of module names found or empty list if none were found.
"""
result = []
with zipfile.ZipFile(zip_path, "r") as zfile:
for name in zfile.namelist():
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
result.append(name)
if len(result) == 0:
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
return result
def write_dict_to_json_file(data: dict, file_path: str) -> None:
"""
Write a dictionary to a JSON file.
Args:
data (dict): Dictionary to write.
file_path (str): Path to the file.
"""
with open(file_path, "w") as file:
json.dump(data, file, indent=4)
def create_directory_if_not_exists(directory_path: str) -> bool:
"""
Create a directory if it does not exist.
Args:
directory_path (str): Path to the directory.
Returns:
bool: True if the directory was created, else False.
"""
if not os.path.exists(directory_path):
try:
os.makedirs(directory_path)
logger.debug(f"Created directory: {directory_path}")
return True
except OSError as e:
logger.warn(f"Error creating directory {directory_path}: {e}")
return False
else:
logger.info(f"Directory {directory_path} already exists")
return True
def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
native_plugin_repo = "DB-GPT-Plugins"
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
response = session.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
print("save file")
cfg.set_plugins(scan_plugins(cfg.debug_mode))
else:
print("get file failedresponse code", response.status_code)
except Exception as e:
print("load plugin from git exception!" + str(e))
t = threading.Thread(target=load_from_git, args=(cfg,))
t.start()
def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]:
logger.info(f"__scan_plugin_file:{file_path},{debug}")
loaded_plugins = []
if moduleList := inspect_zip_for_modules(str(file_path), debug):
for module in moduleList:
plugin = Path(file_path)
module = Path(module)
logger.debug(f"Plugin: {plugin} Module: {module}")
zipped_package = zipimporter(str(plugin))
zipped_module = zipped_package.load_module(str(module.parent))
for key in dir(zipped_module):
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
return loaded_plugins
def scan_plugins(
plugins_file_path: str, file_name: str = "", debug: bool = False
) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
plugins_file_path (str): Path of the directory containing plugin zip packages.
file_name (str, optional): If given, only this zip package is loaded; otherwise all zip files are scanned.
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[AutoGPTPluginTemplate]: List of loaded plugins.
"""
loaded_plugins = []
# Generic plugins
plugins_path = Path(plugins_file_path)
if file_name:
plugin_path = Path(plugins_path, file_name)
loaded_plugins = __scan_plugin_file(plugin_path)
else:
for plugin_path in plugins_path.glob("*.zip"):
loaded_plugins.extend(__scan_plugin_file(plugin_path))
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
for plugin in loaded_plugins:
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
return loaded_plugins
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
"""Check if the plugin is in the allowlist or denylist.
Args:
plugin_name (str): Name of the plugin.
cfg (Config): Config object.
Returns:
True or False
"""
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
if plugin_name in cfg.plugins_denylist:
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
return False
if plugin_name in cfg.plugins_allowlist:
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
return True
ack = input(
f"WARNING: Plugin {plugin_name} found. But not in the"
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
)
return ack.lower() == cfg.authorise_key
def update_from_git(
download_path: str,
github_repo: str = "",
branch_name: str = "main",
authorization: str = None,
):
os.makedirs(download_path, exist_ok=True)
if github_repo:
if github_repo.index("github.com") <= 0:
raise ValueError("Not a correct Github repository address" + github_repo)
github_repo = github_repo.replace(".git", "")
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
plugin_repo_name = github_repo.strip("/").split("/")[-1]
else:
url = (
"https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
)
plugin_repo_name = "DB-GPT-Plugins"
try:
session = requests.Session()
headers = {}
if authorization and len(authorization) > 0:
headers = {"Authorization": authorization}
response = session.get(
url,
headers=headers,
)
if response.status_code == 200:
plugins_path_path = Path(download_path)
files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = (
f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
)
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
return plugin_repo_name
else:
logger.error("update plugins faildresponse code", response.status_code)
raise ValueError("download plugin faild!" + response.status_code)
except Exception as e:
logger.error("update plugins from git exception!" + str(e))
raise ValueError("download plugin exception!", e)
def __fetch_from_git(local_path, git_url):
logger.info("fetch plugins from git to local path:{}", local_path)
os.makedirs(local_path, exist_ok=True)
repo = git.Repo(local_path)
if repo.is_repo():
repo.remotes.origin.pull()
else:
git.Repo.clone_from(git_url, local_path)
# if repo.head.is_valid():
# clone succ fetch plugins info
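A short sketch of the two public helpers above; the target directory is a placeholder.

# Hypothetical usage; the import path and directory are assumed.
from pilot.base_modules.agent.plugins_util import update_from_git, scan_plugins

plugin_dir = "/data/dbgpt/plugins"

# Download the default plugin repository as a zip archive into plugin_dir,
# then load every plugin package found there.
repo_name = update_from_git(plugin_dir)
for plugin in scan_plugins(plugin_dir):
    print(plugin._name, plugin._version, plugin._description)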

View File

@@ -1,3 +0,0 @@
flask_sqlalchemy==3.0.5
flask==2.3.2
gitpython==3.1.36

View File

@@ -1,6 +0,0 @@
class ModuleMangeApi:
def module_name(self):
pass
def register(self):
pass

View File

@@ -1,25 +0,0 @@
from typing import TypeVar, Generic, List, Any
from sqlalchemy.orm import sessionmaker
T = TypeVar("T")
class BaseDao(Generic[T]):
def __init__(
self,
orm_base=None,
database: str = None,
db_engine: Any = None,
session: Any = None,
) -> None:
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
self._orm_base = orm_base
self._database = database
self._db_engine = db_engine
self._session = session
def get_session(self):
Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
session = Session()
return session
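The DAO classes earlier in this commit subclass BaseDao as in this sketch; ExampleDao is a made-up name, while the imported globals come from the meta_data module shown next.

# Illustrative subclass; ExampleDao is hypothetical.
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import (
    Base,
    engine,
    session,
    META_DATA_DATABASE,
)

class ExampleDao(BaseDao):
    def __init__(self):
        super().__init__(
            database=META_DATA_DATABASE,
            orm_base=Base,
            db_engine=engine,
            session=session,
        )

dao = ExampleDao()
db_session = dao.get_session()  # a fresh Session bound to the shared engine
try:
    pass  # run queries here
finally:
    db_session.close()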

View File

@@ -1,93 +0,0 @@
import os
import sqlite3
import logging
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from pilot.configs.config import Config
from urllib.parse import quote_plus as urlquote
logger = logging.getLogger(__name__)
# DB-GPT meta_data database config, now support mysql and sqlite
CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")
os.makedirs(default_db_path, exist_ok=True)
# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
)
# check and auto create mysqldatabase
try:
# try to connect
with engine_temp.connect() as conn:
# TODO We should consider that the production environment does not have permission to execute the DDL
conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
print(f"Already connect '{db_name}'")
except OperationalError as e:
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
engine = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
)
else:
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
Base = declarative_base()
# Base.metadata.create_all()
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session
def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
"""Initialize and upgrade database metadata
Args:
disable_alembic_upgrade (bool): Whether to skip using alembic to initialize and upgrade database metadata
"""
if disable_alembic_upgrade:
logger.info(
"disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
)
return
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))
command.revision(alembic_cfg, "dbgpt ddl upate", True)
command.upgrade(alembic_cfg, "head")

View File

@@ -1 +0,0 @@

View File

@@ -1,10 +0,0 @@
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from pilot.cache.manager import CacheManager, initialize_cache
__all__ = [
"LLMCacheKey",
"LLMCacheValue",
"LLMCacheClient",
"CacheManager",
"initialize_cache",
]

pilot/cache/base.py vendored
View File

@@ -1,161 +0,0 @@
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, TypeVar, Generic, Optional, Type, Dict
from dataclasses import dataclass
from enum import Enum
T = TypeVar("T", bound="Serializable")
K = TypeVar("K")
V = TypeVar("V")
class Serializable(ABC):
@abstractmethod
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission.
Returns:
bytes: The byte array after serialization
"""
@abstractmethod
def to_dict(self) -> Dict:
"""Convert the object's state to a dictionary."""
# @staticmethod
# @abstractclassmethod
# def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
# """Deserialize a dictionary to an Serializable object.
# """
class RetrievalPolicy(str, Enum):
EXACT_MATCH = "exact_match"
SIMILARITY_MATCH = "similarity_match"
class CachePolicy(str, Enum):
LRU = "lru"
FIFO = "fifo"
@dataclass
class CacheConfig:
retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
cache_policy: Optional[CachePolicy] = CachePolicy.LRU
class CacheKey(Serializable, ABC, Generic[K]):
"""The key of the cache. Must be hashable and comparable.
Supported cache keys:
- The LLM cache key: Include user prompt and the parameters to LLM.
- The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
"""
@abstractmethod
def __hash__(self) -> int:
"""Return the hash value of the key."""
@abstractmethod
def __eq__(self, other: Any) -> bool:
"""Check equality with another key."""
@abstractmethod
def get_hash_bytes(self) -> bytes:
"""Return the byte array of hash value."""
@abstractmethod
def get_value(self) -> K:
"""Get the underlying value of the cache key.
Returns:
K: The real object of current cache key
"""
class CacheValue(Serializable, ABC, Generic[V]):
"""Cache value abstract class."""
@abstractmethod
def get_value(self) -> V:
"""Get the underlying real value."""
class Serializer(ABC):
"""The serializer abstract class for serializing cache keys and values."""
@abstractmethod
def serialize(self, obj: Serializable) -> bytes:
"""Serialize a cache object.
Args:
obj (Serializable): The object to serialize
"""
@abstractmethod
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
"""Deserialize data back into a cache object of the specified type.
Args:
data (bytes): The byte array to deserialize
cls (Type[Serializable]): The type of current object
Returns:
Serializable: The serializable object
"""
class CacheClient(ABC, Generic[K, V]):
"""The cache client interface."""
@abstractmethod
async def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[CacheValue[V]]:
"""Retrieve a value from the cache using the provided key.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
"""
@abstractmethod
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
@abstractmethod
async def exists(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> bool:
"""Check if a key exists in the cache.
Args:
key (CacheKey[K]): The key to set to cache
cache_config (Optional[CacheConfig]): Cache config
Return:
bool: True if the key in the cache, otherwise is False
"""
@abstractmethod
def new_key(self, **kwargs) -> CacheKey[K]:
"""Create a cache key with params"""
@abstractmethod
def new_value(self, **kwargs) -> CacheValue[K]:
"""Create a cache key with params"""

View File

View File

@@ -1,148 +0,0 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import json
import hashlib
from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
from pilot.cache.manager import CacheManager
from pilot.cache.storage.base import CacheStorage
from pilot.model.base import ModelType, ModelOutput
@dataclass
class LLMCacheKeyData:
prompt: str
model_name: str
temperature: Optional[float] = 0.7
max_new_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
model_type: Optional[str] = ModelType.HF
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@dataclass
class LLMCacheValueData:
output: CacheOutputType
user: Optional[str] = None
_is_list: Optional[bool] = False
@staticmethod
def from_dict(**kwargs) -> "LLMCacheValueData":
output = kwargs.get("output")
if not output:
raise ValueError("Can't new LLMCacheValueData object, output is None")
if isinstance(output, dict):
output = ModelOutput(**output)
elif isinstance(output, list):
kwargs["_is_list"] = True
output_list = []
for out in output:
if isinstance(out, dict):
out = ModelOutput(**out)
output_list.append(out)
output = output_list
kwargs["output"] = output
return LLMCacheValueData(**kwargs)
def to_dict(self) -> Dict:
output = self.output
is_list = False
if isinstance(output, list):
output_list = []
is_list = True
for out in output:
output_list.append(out.to_dict())
output = output_list
else:
output = output.to_dict()
return {"output": output, "_is_list": is_list, "user": self.user}
@property
def is_list(self) -> bool:
return self._is_list
def __str__(self) -> str:
if not isinstance(self.output, list):
return f"user: {self.user}, output: {self.output}"
else:
return f"user: {self.user}, output(last two item): {self.output[-2:]}"
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
def __init__(self, serializer: Serializer = None, **kwargs) -> None:
super().__init__()
self._serializer = serializer
self.config = LLMCacheKeyData(**kwargs)
def __hash__(self) -> int:
serialize_bytes = self.serialize()
return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, LLMCacheKey):
return False
return self.config == other.config
def get_hash_bytes(self) -> bytes:
serialize_bytes = self.serialize()
return hashlib.sha256(serialize_bytes).digest()
def to_dict(self) -> Dict:
return asdict(self.config)
def serialize(self) -> bytes:
return self._serializer.serialize(self)
def get_value(self) -> LLMCacheKeyData:
return self.config
class LLMCacheValue(CacheValue[LLMCacheValueData]):
def __init__(self, serializer: Serializer = None, **kwargs) -> None:
super().__init__()
self._serializer = serializer
self.value = LLMCacheValueData.from_dict(**kwargs)
def to_dict(self) -> Dict:
return self.value.to_dict()
def serialize(self) -> bytes:
return self._serializer.serialize(self)
def get_value(self) -> LLMCacheValueData:
return self.value
def __str__(self) -> str:
return f"vaue: {str(self.value)}"
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
def __init__(self, cache_manager: CacheManager) -> None:
super().__init__()
self._cache_manager: CacheManager = cache_manager
async def get(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
) -> Optional[LLMCacheValue]:
return await self._cache_manager.get(key, LLMCacheValue, cache_config)
async def set(
self,
key: LLMCacheKey,
value: LLMCacheValue,
cache_config: Optional[CacheConfig] = None,
) -> None:
return await self._cache_manager.set(key, value, cache_config)
async def exists(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
) -> bool:
return await self.get(key, cache_config) is not None
def new_key(self, **kwargs) -> LLMCacheKey:
return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
def new_value(self, **kwargs) -> LLMCacheValue:
return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
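A usage sketch of the client above, assuming a CacheManager instance is already available (see manager.py later in this commit); the ModelOutput field names are an assumption.

# Hypothetical usage; ModelOutput field names are assumed.
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from pilot.model.base import ModelOutput

async def cached_completion(cache_manager, prompt: str) -> None:
    client = LLMCacheClient(cache_manager)
    key: LLMCacheKey = client.new_key(prompt=prompt, model_name="vicuna-13b-v1.5")
    cached = await client.get(key)
    if cached is None:
        # On a miss, call the model (elided here) and store the result.
        output = ModelOutput(text="...", error_code=0)
        await client.set(key, client.new_value(output=output))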

pilot/cache/manager.py vendored
View File

@@ -1,126 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor
from pilot.cache.storage.base import CacheStorage, StorageItem
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheConfig,
Serializer,
Serializable,
)
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
logger = logging.getLogger(__name__)
class CacheManager(BaseComponent, ABC):
name = ComponentType.MODEL_CACHE_MANAGER
def __init__(self, system_app: SystemApp | None = None):
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
self.system_app = system_app
@abstractmethod
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
"""Set cache"""
@abstractmethod
async def get(
self,
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
"""Get cache with key"""
@property
@abstractmethod
def serializer(self) -> Serializer:
"""Get cache serializer"""
class LocalCacheManager(CacheManager):
def __init__(
self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
) -> None:
super().__init__(system_app)
self._serializer = serializer
self._storage = storage
@property
def executor(self) -> Executor:
"""Return executor to submit task"""
self._executor = self.system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
return self._executor
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
if self._storage.support_async():
await self._storage.aset(key, value, cache_config)
else:
await blocking_func_to_async(
self.executor, self._storage.set, key, value, cache_config
)
async def get(
self,
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
if self._storage.support_async():
item_bytes = await self._storage.aget(key, cache_config)
else:
item_bytes = await blocking_func_to_async(
self.executor, self._storage.get, key, cache_config
)
if not item_bytes:
return None
return self._serializer.deserialize(item_bytes.value_data, cls)
@property
def serializer(self) -> Serializer:
return self._serializer
def initialize_cache(
system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
from pilot.cache.protocal.json_protocal import JsonSerializer
from pilot.cache.storage.base import MemoryCacheStorage
cache_storage = None
if storage_type == "disk":
try:
from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
cache_storage = DiskCacheStorage(
persist_dir, mem_table_buffer_mb=max_memory_mb
)
except ImportError as e:
logger.warn(
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
)
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
else:
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
system_app.register(
LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
)
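A minimal wiring sketch for the factory above; the persist_dir value is a placeholder, and constructing SystemApp with no arguments is assumed to be valid.

# Hypothetical wiring sketch.
from pilot.component import SystemApp
from pilot.cache.manager import initialize_cache

system_app = SystemApp()  # assumed to accept no arguments
initialize_cache(
    system_app,
    storage_type="memory",  # "disk" selects DiskCacheStorage when rocksdict is importable
    max_memory_mb=256,
    persist_dir="/tmp/dbgpt_model_cache",
)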

View File

View File

@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Type
import json
from pilot.cache.base import Serializable, Serializer
JSON_ENCODING = "utf-8"
class JsonSerializable(Serializable, ABC):
@abstractmethod
def to_dict(self) -> Dict:
"""Return the dict of current serializable object"""
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission."""
return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
class JsonSerializer(Serializer):
"""The serializer abstract class for serializing cache keys and values."""
def serialize(self, obj: Serializable) -> bytes:
"""Serialize a cache object.
Args:
obj (Serializable): The object to serialize
"""
return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
"""Deserialize data back into a cache object of the specified type.
Args:
data (bytes): The byte array to deserialize
cls (Type[Serializable]): The type of current object
Returns:
Serializable: The serializable object
"""
# Convert bytes back to JSON and then to the specified class
json_data = json.loads(data.decode(JSON_ENCODING))
# Assume that the cls has an __init__ that accepts a dictionary
return cls(**json_data)
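A round-trip sketch for the serializer above; the Point dataclass is a made-up stand-in for a real cache key or value.

# Illustrative round trip; Point is hypothetical.
from dataclasses import dataclass, asdict
from typing import Dict

from pilot.cache.protocal.json_protocal import JsonSerializable, JsonSerializer

@dataclass
class Point(JsonSerializable):
    x: int
    y: int

    def to_dict(self) -> Dict:
        return asdict(self)

serializer = JsonSerializer()
data = serializer.serialize(Point(1, 2))        # b'{"x": 1, "y": 2}'
restored = serializer.deserialize(data, Point)  # Point(x=1, y=2)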

View File

View File

@@ -1,252 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import msgpack
import logging
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheClient,
CacheConfig,
RetrievalPolicy,
CachePolicy,
)
from pilot.utils.memory_utils import _get_object_bytes
logger = logging.getLogger(__name__)
@dataclass
class StorageItem:
"""
A class representing a storage item.
This class encapsulates data related to a storage item, such as its length,
the hash of the key, and the data for both the key and value.
Parameters:
length (int): The bytes length of the storage item.
key_hash (bytes): The hash value of the storage item's key.
key_data (bytes): The data of the storage item's key, represented in bytes.
value_data (bytes): The data of the storage item's value, also in bytes.
"""
length: int # The bytes length of the storage item
key_hash: bytes # The hash value of the storage item's key
key_data: bytes # The data of the storage item's key
value_data: bytes # The data of the storage item's value
@staticmethod
def build_from(
key_hash: bytes, key_data: bytes, value_data: bytes
) -> "StorageItem":
length = (
32
+ _get_object_bytes(key_hash)
+ _get_object_bytes(key_data)
+ _get_object_bytes(value_data)
)
return StorageItem(
length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
)
@staticmethod
def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
key_hash = key.get_hash_bytes()
key_data = key.serialize()
value_data = value.serialize()
return StorageItem.build_from(key_hash, key_data, value_data)
def serialize(self) -> bytes:
"""Serialize the StorageItem into a byte stream using MessagePack.
This method packs the object data into a dictionary, marking the
key_data and value_data fields as raw binary data to avoid re-serialization.
Returns:
bytes: The serialized bytes.
"""
obj = {
"length": self.length,
"key_hash": msgpack.ExtType(1, self.key_hash),
"key_data": msgpack.ExtType(2, self.key_data),
"value_data": msgpack.ExtType(3, self.value_data),
}
return msgpack.packb(obj)
@staticmethod
def deserialize(data: bytes) -> "StorageItem":
"""Deserialize bytes back into a StorageItem using MessagePack.
This extracts the fields from the MessagePack dict back into
a StorageItem object.
Args:
data (bytes): Serialized bytes
Returns:
StorageItem: Deserialized StorageItem object.
"""
obj = msgpack.unpackb(data)
key_hash = obj["key_hash"].data
key_data = obj["key_data"].data
value_data = obj["value_data"].data
return StorageItem(
length=obj["length"],
key_hash=key_hash,
key_data=key_data,
value_data=value_data,
)
class CacheStorage(ABC):
@abstractmethod
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
"""Check whether the CacheConfig is legal.
Args:
cache_config (Optional[CacheConfig]): Cache config.
raise_error (Optional[bool]): Whether raise error if illegal.
Returns:
bool: True if the config is legal, otherwise False.
Raises:
ValueError: When raise_error is True and the config is illegal.
"""
def support_async(self) -> bool:
return False
@abstractmethod
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
"""
async def aget(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key asynchronously.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
"""
raise NotImplementedError
@abstractmethod
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key asynchronously.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
async def aset(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key asynchronously.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
raise NotImplementedError
class MemoryCacheStorage(CacheStorage):
def __init__(self, max_memory_mb: int = 256):
self.cache = OrderedDict()
self.max_memory = max_memory_mb * 1024 * 1024
self.current_memory_usage = 0
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
):
if raise_error:
raise ValueError(
"MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
)
return False
return True
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
key_hash = hash(key)
item: StorageItem = self.cache.get(key_hash)
logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
if not item:
return None
# Move the item to the end of the OrderedDict to signify recent use.
self.cache.move_to_end(key_hash)
return item
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
key_hash = hash(key)
item = StorageItem.build_from_kv(key, value)
# Calculate memory size of the new entry
new_entry_size = _get_object_bytes(item)
# Evict entries if necessary, keeping the memory accounting in sync
while self.current_memory_usage + new_entry_size > self.max_memory and self.cache:
evicted_item = self._apply_cache_policy(cache_config)
self.current_memory_usage -= _get_object_bytes(evicted_item)
# Store the item in the cache.
self.cache[key_hash] = item
self.current_memory_usage += new_entry_size
logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")
def exists(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> bool:
return self.get(key, cache_config) is not None
def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None) -> StorageItem:
# Evict one entry according to the cache policy and return it so the caller
# can keep the memory accounting in sync.
if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
# FIFO: drop the oldest inserted entry.
_, evicted_item = self.cache.popitem(last=False)
else:
# Default is LRU: get() moves accessed entries to the end of the OrderedDict,
# so the least recently used entry sits at the front.
_, evicted_item = self.cache.popitem(last=False)
return evicted_item
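A small sketch of the policy handling above; both policy enums come from pilot.cache.base.

# Illustrative policy configuration for the in-memory storage above.
from pilot.cache.base import CacheConfig, CachePolicy, RetrievalPolicy
from pilot.cache.storage.base import MemoryCacheStorage

storage = MemoryCacheStorage(max_memory_mb=64)

# Exact-match retrieval is the only supported policy; similarity match is rejected.
fifo_config = CacheConfig(
    retrieval_policy=RetrievalPolicy.EXACT_MATCH,
    cache_policy=CachePolicy.FIFO,
)
assert storage.check_config(fifo_config) is True
assert (
    storage.check_config(
        CacheConfig(retrieval_policy=RetrievalPolicy.SIMILARITY_MATCH),
        raise_error=False,
    )
    is False
)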

View File

View File

@@ -1,93 +0,0 @@
from typing import Optional
import logging
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheConfig,
RetrievalPolicy,
CachePolicy,
)
from pilot.cache.storage.base import StorageItem, CacheStorage
from rocksdict import Rdict
from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions
logger = logging.getLogger(__name__)
def db_options(
mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
opt = Options()
# create table
opt.create_if_missing(True)
# config to more jobs, default 2
opt.set_max_background_jobs(background_threads)
# configure mem-table to a large value
opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
# opt.set_write_buffer_size(1024)
# opt.set_level_zero_file_num_compaction_trigger(4)
# configure l0 and l1 size, let them have the same size (1 GB)
# opt.set_max_bytes_for_level_base(0x40000000)
# 256 MB file size
# opt.set_target_file_size_base(0x10000000)
# use a smaller compaction multiplier
# opt.set_max_bytes_for_level_multiplier(4.0)
# use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
# opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
# set to plain-table
# opt.set_plain_table_factory(PlainTableFactoryOptions())
return opt
class DiskCacheStorage(CacheStorage):
def __init__(
self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
) -> None:
super().__init__()
self.db: Rdict = Rdict(
persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
)
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
):
if raise_error:
raise ValueError(
"DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
)
return False
return True
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
key_hash = key.get_hash_bytes()
item_bytes = self.db.get(key_hash)
if not item_bytes:
return None
item = StorageItem.deserialize(item_bytes)
logger.debug(f"Read file cache, key: {key}, storage item: {item}")
return item
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
item = StorageItem.build_from_kv(key, value)
key_hash = item.key_hash
self.db[key_hash] = item.serialize()
logger.debug(f"Save file cache, key: {key}, value: {value}")

View File

@@ -1,53 +0,0 @@
import pytest
from ..base import StorageItem
from pilot.utils.memory_utils import _get_object_bytes
def test_build_from():
key_hash = b"key_hash"
key_data = b"key_data"
value_data = b"value_data"
item = StorageItem.build_from(key_hash, key_data, value_data)
assert item.key_hash == key_hash
assert item.key_data == key_data
assert item.value_data == value_data
assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
key_data
) + _get_object_bytes(value_data)
def test_build_from_kv():
class MockCacheKey:
def get_hash_bytes(self):
return b"key_hash"
def serialize(self):
return b"key_data"
class MockCacheValue:
def serialize(self):
return b"value_data"
key = MockCacheKey()
value = MockCacheValue()
item = StorageItem.build_from_kv(key, value)
assert item.key_hash == key.get_hash_bytes()
assert item.key_data == key.serialize()
assert item.value_data == value.serialize()
def test_serialize_deserialize():
key_hash = b"key_hash"
key_data = b"key_data"
value_data = b"value_data"
item = StorageItem.build_from(key_hash, key_data, value_data)
serialized = item.serialize()
deserialized = StorageItem.deserialize(serialized)
assert deserialized.key_hash == item.key_hash
assert deserialized.key_data == item.key_data
assert deserialized.value_data == item.value_data
assert deserialized.length == item.length

View File

@@ -1,47 +0,0 @@
import asyncio
from typing import Coroutine, List, Any
from starlette.responses import StreamingResponse
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory
chat_factory = ChatFactory()
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
"""llm_chat_response_nostream"""
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
res = await chat.get_llm_response()
return res
async def llm_chat_response(chat_scene: str, **chat_param):
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
return chat.stream_call()
async def run_async_tasks(
tasks: List[Coroutine],
concurrency_limit: int = None,
) -> List[Any]:
"""Run a list of async tasks."""
tasks_to_execute: List[Any] = tasks
async def _gather() -> List[Any]:
if concurrency_limit:
semaphore = asyncio.Semaphore(concurrency_limit)
async def _execute_task(task):
async with semaphore:
return await task
# Execute tasks with semaphore limit
return await asyncio.gather(
*[_execute_task(task) for task in tasks_to_execute]
)
else:
return await asyncio.gather(*tasks_to_execute)
# outputs: List[Any] = asyncio.run(_gather())
return await _gather()
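A short sketch of the concurrency helper above; fetch() is a made-up coroutine and the import path is assumed.

# Hypothetical usage; fetch() and the import path are assumed.
import asyncio

from pilot.common.chat_util import run_async_tasks  # import path assumed

async def fetch(i: int) -> int:
    await asyncio.sleep(0.01)
    return i * i

async def main():
    # Run ten coroutines with at most three in flight at any time.
    results = await run_async_tasks([fetch(i) for i in range(10)], concurrency_limit=3)
    print(results)

asyncio.run(main())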

View File

@@ -1,34 +0,0 @@
from collections import OrderedDict
from collections import deque
class FixedSizeDict(OrderedDict):
def __init__(self, max_size):
super().__init__()
self.max_size = max_size
def __setitem__(self, key, value):
if len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
class FixedSizeList:
def __init__(self, max_size):
self.max_size = max_size
self.list = deque(maxlen=max_size)
def append(self, value):
self.list.append(value)
def __getitem__(self, index):
return self.list[index]
def __setitem__(self, index, value):
self.list[index] = value
def __len__(self):
return len(self.list)
def __str__(self):
return str(list(self.list))
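A quick sketch of the two bounded containers above (import path assumed).

# Hypothetical usage; the import path below is assumed.
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList

d = FixedSizeDict(max_size=2)
d["a"] = 1
d["b"] = 2
d["c"] = 3                 # exceeds max_size, so the oldest key "a" is evicted
print(list(d.keys()))      # ['b', 'c']

lst = FixedSizeList(max_size=3)
for i in range(5):
    lst.append(i)          # the underlying deque keeps only the last three values
print(lst)                 # [2, 3, 4]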

View File

@@ -1,61 +0,0 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
class NoStrictFormatter(StrictFormatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Not check unused args"""
pass
formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
elif hasattr(obj, "__dict__"):
return obj.__dict__
else:
return json.JSONEncoder.default(self, obj)
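A brief sketch of the two formatters and the JSON encoder above (import path assumed).

# Hypothetical usage; the import path below is assumed.
import json

from pilot.common.formatting import formatter, no_strict_formatter, MyEncoder

# StrictFormatter rejects unused keyword arguments ...
try:
    formatter.format("Hello, {name}!", name="DB-GPT", extra="unused")
except KeyError as e:
    print("unused keys:", e)

# ... while NoStrictFormatter silently ignores them.
print(no_strict_formatter.format("Hello, {name}!", name="DB-GPT", extra="unused"))

# MyEncoder serializes sets (and objects with __dict__) that json can't handle by default.
print(json.dumps({"tags": {"sql", "chat"}}, cls=MyEncoder))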

View File

@@ -1,448 +0,0 @@
"""General utils functions."""
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Type,
Union,
cast,
)
class GlobalsHelper:
"""Helper to retrieve globals.
Helpful for global caching of certain variables that can be expensive to load.
(e.g. tokenization)
"""
_tokenizer: Optional[Callable[[str], List]] = None
_stopwords: Optional[List[str]] = None
@property
def tokenizer(self) -> Callable[[str], List]:
"""Get tokenizer."""
if self._tokenizer is None:
tiktoken_import_err = (
"`tiktoken` package not found, please run `pip install tiktoken`"
)
try:
import tiktoken
except ImportError:
raise ImportError(tiktoken_import_err)
enc = tiktoken.get_encoding("gpt2")
self._tokenizer = cast(Callable[[str], List], enc.encode)
self._tokenizer = partial(self._tokenizer, allowed_special="all")
return self._tokenizer # type: ignore
@property
def stopwords(self) -> List[str]:
"""Get stopwords."""
if self._stopwords is None:
try:
import nltk
from nltk.corpus import stopwords
except ImportError:
raise ImportError(
"`nltk` package not found, please run `pip install nltk`"
)
from llama_index.utils import get_cache_dir
cache_dir = get_cache_dir()
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
# update nltk path for nltk so that it finds the data
if nltk_data_dir not in nltk.data.path:
nltk.data.path.append(nltk_data_dir)
try:
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("stopwords", download_dir=nltk_data_dir)
self._stopwords = stopwords.words("english")
return self._stopwords
globals_helper = GlobalsHelper()
def get_new_id(d: Set) -> str:
"""Get a new ID."""
while True:
new_id = str(uuid.uuid4())
if new_id not in d:
break
return new_id
def get_new_int_id(d: Set) -> int:
"""Get a new integer ID."""
while True:
new_id = random.randint(0, sys.maxsize)
if new_id not in d:
break
return new_id
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
"""Temporary setter.
Utility class for setting a temporary value for an attribute on a class.
Taken from: https://tinyurl.com/2p89xymh
"""
prev_values = {k: getattr(obj, k) for k in kwargs}
for k, v in kwargs.items():
setattr(obj, k, v)
try:
yield
finally:
for k, v in prev_values.items():
setattr(obj, k, v)
@dataclass
class ErrorToRetry:
"""Exception types that should be retried.
Args:
exception_cls (Type[Exception]): Class of exception.
check_fn (Optional[Callable[[Any]], bool]]):
A function that takes an exception instance as input and returns
whether to retry.
"""
exception_cls: Type[Exception]
check_fn: Optional[Callable[[Any], bool]] = None
def retry_on_exceptions_with_backoff(
lambda_fn: Callable,
errors_to_retry: List[ErrorToRetry],
max_tries: int = 10,
min_backoff_secs: float = 0.5,
max_backoff_secs: float = 60.0,
) -> Any:
"""Execute lambda function with retries and exponential backoff.
Args:
lambda_fn (Callable): Function to be called and output we want.
errors_to_retry (List[ErrorToRetry]): List of errors to retry.
At least one needs to be provided.
max_tries (int): Maximum number of tries, including the first. Defaults to 10.
min_backoff_secs (float): Minimum amount of backoff time between attempts.
Defaults to 0.5.
max_backoff_secs (float): Maximum amount of backoff time between attempts.
Defaults to 60.
"""
if not errors_to_retry:
raise ValueError("At least one error to retry needs to be provided")
error_checks = {
error_to_retry.exception_cls: error_to_retry.check_fn
for error_to_retry in errors_to_retry
}
exception_class_tuples = tuple(error_checks.keys())
backoff_secs = min_backoff_secs
tries = 0
while True:
try:
return lambda_fn()
except exception_class_tuples as e:
traceback.print_exc()
tries += 1
if tries >= max_tries:
raise
check_fn = error_checks.get(e.__class__)
if check_fn and not check_fn(e):
raise
time.sleep(backoff_secs)
backoff_secs = min(backoff_secs * 2, max_backoff_secs)
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""Iterate over an iterable in batches.
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
while source_iter:
b = list(islice(source_iter, size))
if len(b) == 0:
break
yield b
def concat_dirs(dirname: str, basename: str) -> str:
"""
Append basename to dirname, avoiding backslashes when running on windows.
os.path.join(dirname, basename) will add a backslash before dirname if
basename does not end with a slash, so we make sure it does.
"""
dirname += "/" if dirname[-1] != "/" else ""
return os.path.join(dirname, basename)
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
"""
Optionally get a tqdm iterable. Ensures tqdm.auto is used.
"""
_iterator = items
if show_progress:
try:
from tqdm.auto import tqdm
return tqdm(items, desc=desc)
except ImportError:
pass
return _iterator
def count_tokens(text: str) -> int:
    """Count the number of tokens in the text using the global tokenizer."""
    tokens = globals_helper.tokenizer(text)
    return len(tokens)
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
"""
Args:
model_name(str): the model name of the tokenizer.
For instance, fxmarty/tiny-llama-fast-tokenizer.
"""
try:
from transformers import AutoTokenizer
except ImportError:
raise ValueError(
"`transformers` package not found, please run `pip install transformers`"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer.tokenize
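# A minimal usage sketch for get_transformer_tokenizer_fn, assuming the
# `transformers` package is installed and the checkpoint named in the docstring
# above can be downloaded.
def _example_transformer_tokenizer() -> List[str]:
    tokenize = get_transformer_tokenizer_fn("fxmarty/tiny-llama-fast-tokenizer")
    return tokenize("AWEL is an agentic workflow expression language.")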
def get_cache_dir() -> str:
"""Locate a platform-appropriate cache directory for llama_index,
and create it if it doesn't yet exist.
"""
# User override
if "LLAMA_INDEX_CACHE_DIR" in os.environ:
path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])
# Linux, Unix, AIX, etc.
elif os.name == "posix" and sys.platform != "darwin":
path = Path("/tmp/llama_index")
# Mac OS
elif sys.platform == "darwin":
path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
# Windows (hopefully)
else:
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
"~\\AppData\\Local"
)
path = Path(local, "llama_index")
if not os.path.exists(path):
os.makedirs(
path, exist_ok=True
) # prevents https://github.com/jerryjliu/llama_index/issues/7362
return str(path)
def add_sync_version(func: Any) -> Any:
"""Decorator for adding sync version of an async function. The sync version
is added as a function attribute to the original function, func.
Args:
func(Any): the async function for which a sync variant will be built.
"""
assert asyncio.iscoroutinefunction(func)
@wraps(func)
def _wrapper(*args: Any, **kwds: Any) -> Any:
return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))
func.sync = _wrapper
return func
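# A minimal usage sketch for add_sync_version; `_async_double` is illustrative.
# The decorated coroutine keeps its async behaviour, while a synchronous
# variant is exposed through the `.sync` attribute.
@add_sync_version
async def _async_double(x: int) -> int:
    return x * 2


def _example_add_sync_version() -> int:
    # Runs the coroutine to completion on the current event loop.
    return _async_double.sync(21)  # returns 42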
# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.
Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:
Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""
_LLAMA_INDEX_COLORS = {
"llama_pink": "38;2;237;90;200",
"llama_blue": "38;2;90;149;237",
"llama_turquoise": "38;2;11;159;203",
"llama_lavender": "38;2;155;135;227",
}
_ANSI_COLORS = {
"red": "31",
"green": "32",
"yellow": "33",
"blue": "34",
"magenta": "35",
"cyan": "36",
"pink": "38;5;200",
}
def get_color_mapping(
items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
"""
Get a mapping of items to colors.
Args:
items (List[str]): List of items to be mapped to colors.
use_llama_index_colors (bool, optional): Flag to indicate
whether to use LlamaIndex colors or ANSI colors.
Defaults to True.
Returns:
Dict[str, str]: Mapping of items to colors.
"""
if use_llama_index_colors:
color_palette = _LLAMA_INDEX_COLORS
else:
color_palette = _ANSI_COLORS
colors = list(color_palette.keys())
return {item: colors[i % len(colors)] for i, item in enumerate(items)}
def _get_colored_text(text: str, color: str) -> str:
"""
Get the colored version of the input text.
Args:
text (str): Input text.
color (str): Color to be applied to the text.
Returns:
str: Colored version of the input text.
"""
all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
if color not in all_colors:
return f"\033[1;3m{text}\033[0m" # just bolded and italicized
color = all_colors[color]
return f"\033[1;3;{color}m{text}\033[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""
Print the text with the specified color.
Args:
text (str): Text to be printed.
color (str, optional): Color to be applied to the text. Supported colors are:
llama_pink, llama_blue, llama_turquoise, llama_lavender,
red, green, yellow, blue, magenta, cyan, pink.
end (str, optional): String appended after the last character of the text.
Returns:
None
"""
text_to_print = _get_colored_text(text, color) if color is not None else text
print(text_to_print, end=end)
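# A minimal usage sketch combining get_color_mapping and print_text; the item
# names below are illustrative.
def _example_colored_output() -> None:
    color_mapping = get_color_mapping(["retriever", "llm", "ranker"])
    for name, color in color_mapping.items():
        print_text(f"[{name}] ready", color=color, end="\n")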
def infer_torch_device() -> str:
"""Infer the input to torch.device."""
# Import torch lazily; it is only required when this function is called.
import torch

has_cuda = torch.cuda.is_available()
if has_cuda:
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def unit_generator(x: Any) -> Generator[Any, None, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x

View File

@@ -1,7 +0,0 @@
import json
from datetime import date
def serialize(obj):
if isinstance(obj, date):
return obj.isoformat()

View File

@@ -1,43 +0,0 @@
from pydantic import Field, BaseModel
DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256
class LLMMetadata(BaseModel):
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=(
"Total number of tokens the model can be input and output for one response."
),
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="Number of tokens the model can output when generating a response.",
)
is_chat_model: bool = Field(
default=False,
description=(
"Set True if the model exposes a chat interface (i.e. can be passed a"
" sequence of messages, rather than text), like OpenAI's"
" /v1/chat/completions endpoint."
),
)
is_function_calling_model: bool = Field(
default=False,
# SEE: https://openai.com/blog/function-calling-and-other-api-updates
description=(
"Set True if the model supports function calling messages, similar to"
" OpenAI's function calling API. For example, converting 'Email Anya to"
" see if she wants to get coffee next Friday' to a function call like"
" `send_email(to: string, body: string)`."
),
)
model_name: str = Field(
default="unknown",
description=(
"The model's name used for logging, testing, and sanity checking. For some"
" models this can be automatically discerned. For other models, like"
" locally loaded models, this must be manually specified."
),
)

View File

@@ -1,58 +0,0 @@
import markdown2
def datas_to_table_html(data):
import pandas as pd
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
return html.replace("\n", " ")
def generate_markdown_table(data):
"""\n 生成 Markdown 表格\n data: 一个包含表头和表格内容的二维列表\n"""
# Get the number of columns
num_cols = len(data[0])
# Build the header row
header = "| "
for i in range(num_cols):
header += data[0][i] + " | "
# Build the separator row
separator = "| "
for i in range(num_cols):
separator += "--- | "
# Build the table body
content = ""
for row in data[1:]:
content += "| "
for i in range(num_cols):
content += str(row[i]) + " | "
content += "\n"
# Combine the header, separator, and body
table = header + "\n" + separator + "\n" + content
return table
def generate_htm_table(data):
markdown_text = generate_markdown_table(data)
html_table = markdown2.markdown(markdown_text, extras=["tables"])
return html_table
if __name__ == "__main__":
# mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
# print(generate_htm_table(mk_text))
table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
print(table_style.replace("\n", " "))

View File

@@ -1,6 +0,0 @@
import os
def has_path(filename):
directory = os.path.dirname(filename)
return bool(directory)

View File

@@ -1,6 +0,0 @@
def csv_colunm_foramt(val):
if str(val).find("$") >= 0:
return float(val.replace("$", "").replace(",", ""))
if str(val).find("¥") >= 0:
return float(val.replace("¥", "").replace(",", ""))
return val

View File

@@ -1,239 +0,0 @@
"""General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates the available context size by starting with the context window
size of an LLM and reserving token space for the prompt template and the output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""
import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence
from pydantic import Field, PrivateAttr, BaseModel
from pilot.common.global_helper import globals_helper
from pilot.common.llm_metadata import LLMMetadata
from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter
DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
DEFAULT_CONTEXT_WINDOW = 3000 # tokens
DEFAULT_NUM_OUTPUTS = 256 # tokens
logger = logging.getLogger(__name__)
class PromptHelper(BaseModel):
"""Prompt helper.
General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates the available context size by starting with the
context window size of an LLM and reserving token space for the prompt
template and the output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM
calls needed), or truncating them so that they fit in a single LLM call.
Args:
context_window (int): Context window for the LLM.
num_output (int): Number of outputs for the LLM.
chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
chunk_size_limit (Optional[int]): Maximum chunk size to use.
tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
separator (str): Separator for text splitter
"""
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The maximum context size that will get sent to the LLM.",
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="The amount of token-space to leave in input for generation.",
)
chunk_overlap_ratio: float = Field(
default=DEFAULT_CHUNK_OVERLAP_RATIO,
description="The percentage token amount that each chunk should overlap.",
)
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
separator: str = Field(
default=" ", description="The separator when chunking tokens."
)
_tokenizer: Callable[[str], List] = PrivateAttr()
def __init__(
self,
context_window: int = DEFAULT_CONTEXT_WINDOW,
num_output: int = DEFAULT_NUM_OUTPUTS,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> None:
"""Init params."""
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
# TODO: make configurable
self._tokenizer = tokenizer or globals_helper.tokenizer
super().__init__(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
separator=separator,
)
@classmethod
def from_llm_metadata(
cls,
llm_metadata: LLMMetadata,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> "PromptHelper":
"""Create from llm predictor.
This will autofill values like context_window and num_output.
"""
context_window = llm_metadata.context_window
if llm_metadata.num_output == -1:
num_output = DEFAULT_NUM_OUTPUTS
else:
num_output = llm_metadata.num_output
return cls(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
tokenizer=tokenizer,
separator=separator,
)
@classmethod
def class_name(cls) -> str:
return "PromptHelper"
def _get_available_context_size(self, template: str) -> int:
"""Get available context size.
This is calculated as:
available context window = total context window
- input (partially filled prompt)
- output (room reserved for response)
Notes:
- If the available context size is negative, a ValueError is raised.
"""
empty_prompt_txt = get_empty_prompt_txt(template)
num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
context_size_tokens = (
self.context_window - num_empty_prompt_tokens - self.num_output
)
if context_size_tokens < 0:
raise ValueError(
f"Calculated available context size {context_size_tokens} was"
" not non-negative."
)
return context_size_tokens
def _get_available_chunk_size(
self, prompt_template: str, num_chunks: int = 1, padding: int = 5
) -> int:
"""Get available chunk size.
This is calculated as:
available chunk size = available context window // number_chunks
- padding
Notes:
- By default, we use padding of 5 (to save space for formatting needs).
- Available chunk size is further clamped to chunk_size_limit if specified.
"""
available_context_size = self._get_available_context_size(prompt_template)
result = available_context_size // num_chunks - padding
if self.chunk_size_limit is not None:
result = min(result, self.chunk_size_limit)
return result
def get_text_splitter_given_prompt(
self,
prompt_template: str,
num_chunks: int = 1,
padding: int = DEFAULT_PADDING,
) -> TokenTextSplitter:
"""Get text splitter configured to maximally pack available context window,
taking into account the given prompt and the desired number of chunks.
"""
chunk_size = self._get_available_chunk_size(
prompt_template, num_chunks, padding=padding
)
if chunk_size <= 0:
raise ValueError(f"Chunk size {chunk_size} is not positive.")
chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
return TokenTextSplitter(
separator=self.separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
tokenizer=self._tokenizer,
)
def repack(
self,
prompt_template: str,
text_chunks: Sequence[str],
padding: int = DEFAULT_PADDING,
) -> List[str]:
"""Repack text chunks to fit available context window.
This will combine text chunks into consolidated chunks
that more fully "pack" the prompt template given the context_window.
"""
text_splitter = self.get_text_splitter_given_prompt(
prompt_template, padding=padding
)
combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
return text_splitter.split_text(combined_str)
def get_empty_prompt_txt(template: str) -> str:
"""Get empty prompt text.
Substitute empty strings in parts of the prompt that have
not yet been filled out. Skip variables that have already
been partially formatted. This is used to compute the initial tokens.
"""
# partial_kargs = prompt.kwargs
partial_kargs = {}
template_vars = get_template_vars(template)
empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
all_kwargs = {**partial_kargs, **empty_kwargs}
prompt = template.format(**all_kwargs)
return prompt
def get_template_vars(template_str: str) -> List[str]:
"""Get template variables from a template string."""
variables = []
formatter = Formatter()
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.append(variable_name)
return variables
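# A minimal usage sketch for PromptHelper.repack, assuming the TokenTextSplitter
# dependency imported above is available; the whitespace tokenizer, template,
# and chunks below are illustrative.
def _example_repack() -> List[str]:
    helper = PromptHelper(
        context_window=64,
        num_output=16,
        tokenizer=lambda text: text.split(),
    )
    chunks = ["first retrieved chunk", "second retrieved chunk"]
    # Combine the chunks so they maximally fill the remaining context window.
    return helper.repack("Answer using the context: {context}", chunks)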

View File

@@ -1,56 +0,0 @@
from enum import auto, Enum
import os
class SeparatorStyle(Enum):
SINGLE = "###"
TWO = "</s>"
THREE = auto()
FOUR = auto()
class ExampleType(Enum):
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"
class DbInfo:
def __init__(self, name, is_file_db: bool = False):
self.name = name
self.is_file_db = is_file_db
class DBType(Enum):
Mysql = DbInfo("mysql")
OCeanBase = DbInfo("oceanbase")
DuckDb = DbInfo("duckdb", True)
SQLite = DbInfo("sqlite", True)
Oracle = DbInfo("oracle")
MSSQL = DbInfo("mssql")
Postgresql = DbInfo("postgresql")
Clickhouse = DbInfo("clickhouse")
StarRocks = DbInfo("starrocks")
Spark = DbInfo("spark", True)
Doris = DbInfo("doris")
def value(self):
return self._value_.name
def is_file_db(self):
return self._value_.is_file_db
@staticmethod
def of_db_type(db_type: str):
for item in DBType:
if item.value() == db_type:
return item
return None
@staticmethod
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
"""Parse out the database name of the embedded database from the file path"""
base_name = os.path.basename(local_db_path)
db_name = os.path.splitext(base_name)[0]
if "." in db_name:
db_name = os.path.splitext(db_name)[0]
return db_type + "_" + db_name
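# A minimal usage sketch for DBType; the file path below is illustrative.
def _example_db_type() -> str:
    db_type = DBType.of_db_type("sqlite")
    assert db_type is DBType.SQLite and db_type.is_file_db()
    # Derive the logical database name ("sqlite_default_sqlite") from the path.
    return DBType.parse_file_db_name_from_path("sqlite", "/data/default_sqlite.db")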

View File

@@ -1,491 +0,0 @@
from __future__ import annotations
import sqlparse
import regex as re
import warnings
from typing import Any, Iterable, List, Optional
import sqlalchemy
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
class Database:
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
self._db_sessions = Session
self._all_tables = set()
self.view_support = view_support
self._usable_tables = set()
self._include_tables = set(include_tables) if include_tables else set()
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
self._custom_table_info = custom_table_info
self._sample_rows_in_table_info = sample_rows_in_table_info
self._indexes_in_table_info = indexes_in_table_info
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> Database:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
return cls(create_engine(database_uri, **_engine_args), **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return self._all_tables - self._ignore_tables
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
def get_session_db(self, connect):
sql = text(f"select DATABASE()")
cursor = connect.execute(sql)
result = cursor.fetchone()[0]
return result
def get_session(self, db_name: str):
session = self._db_sessions()
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f"use `{db_name}`")
session.execute(sql)
# Reflect table metadata for the target database
self._metadata.reflect(bind=self._engine, schema=db_name)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
+ (
self._inspector.get_view_names(schema=db_name)
if self.view_support
else []
)
)
return session
def get_current_db_name(self, session) -> str:
return session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
"""
cursor = session.execute(text(_sql))
results = cursor.fetchall()
return results
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
for table in meta_tables:
if self._custom_table_info and table.name in self._custom_table_info:
tables.append(self._custom_table_info[table.name])
continue
# add create table command
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
if self._indexes_in_table_info:
table_info += f"\n{self._get_table_indexes(table)}\n"
if self._sample_rows_in_table_info:
table_info += f"\n{self._get_sample_rows(table)}\n"
if has_extra_info:
table_info += "*/"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def _get_sample_rows(self, table: Table) -> str:
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
# save the columns in string format
columns_str = "\t".join([col.name for col in table.columns])
try:
# get the sample rows
with self._engine.connect() as connection:
sample_rows_result: CursorResult = connection.execute(command)
# shorten values in the sample rows
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
)
# save the sample rows in string format
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
# in some dialects when there are no rows in the table a
# 'ProgrammingError' is returned
except ProgrammingError:
sample_rows_str = ""
return (
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}"
)
def _get_table_indexes(self, table: Table) -> str:
indexes = self._inspector.get_indexes(table.name)
indexes_formatted = "\n".join(map(_format_index, indexes))
return f"Table Indexes:\n{indexes_formatted}"
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
try:
return self.get_table_info(table_names)
except ValueError as e:
"""Format the error message"""
return f"Error: {e}"
def __write(self, session, write_sql):
print(f"Write[{write_sql}]")
db_cache = self.get_session_db(session)
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
return result
def query_ex(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
result = list(result)
return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
if not command:
return []
parsed, ttype, sql_type = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(session, command, fetch)
else:
self.__write(session, command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(session, select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command))
session.commit()
if cursor.returns_rows:
result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(session, "SHOW COLUMNS FROM test")
return result
else:
return self.__query(session, "SHOW COLUMNS FROM test")
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement throws an error, the error message is returned.
"""
try:
return self.run(session, command, fetch)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0]
not in [
"information_schema",
"performance_schema",
"sys",
"mysql",
"knowledge_management",
]
]
def convert_sql_write_to_select(self, write_sql):
"""
SQL classification processing
author:xiangh8
Args:
sql:
Returns:
"""
# Convert the SQL command to lowercase and split it on whitespace
parts = write_sql.lower().split()
# Get the command type (insert, delete, update)
cmd_type = parts[0]
# Handle each command type
if cmd_type == "insert":
match = re.match(
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
)
if match:
table_name, columns, values = match.groups()
# Split the column list and value list into individual columns and values
columns = columns.split(",")
values = values.split(",")
# Build the WHERE clause
where_clause = " AND ".join(
[
f"{col.strip()}={val.strip()}"
for col, val in zip(columns, values)
]
)
return f"SELECT * FROM {table_name} WHERE {where_clause}"
elif cmd_type == "delete":
table_name = parts[2] # delete from <table_name> ...
# Return a SELECT statement that reads all data from the table
return f"SELECT * FROM {table_name} "
elif cmd_type == "update":
table_name = parts[1]
set_idx = parts.index("set")
where_idx = parts.index("where")
# Extract the column name from the `set` clause
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
# Extract the condition after `where`
where_clause = " ".join(parts[where_idx + 1 :])
# Return a SELECT statement that reads the updated data
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else:
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
return parsed, ttype, sql_type
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW INDEXES FROM `{table_name}`"))
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE `{table_name}`"))
ans = cursor.fetchall()
res = ans[0][1]
res = re.sub(r"\s*ENGINE\s*=\s*InnoDB\s*", " ", res, flags=re.IGNORECASE)
res = re.sub(
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", res, flags=re.IGNORECASE
)
res = re.sub(r"\s*COLLATE\s*=\s*\w+\s*", " ", res, flags=re.IGNORECASE)
return res
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
table_name
)
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@character_set_database"))
character_set = cursor.fetchone()[0]
return character_set
def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@collation_database"))
collation = cursor.fetchone()[0]
return collation
def get_grants(self):
"""Get grant info."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW GRANTS"))
grants = cursor.fetchall()
return grants
def get_users(self):
"""Get user info."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT user, host FROM mysql.user"))
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
def get_table_comments(self, database):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{database}'""".format(
database
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
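# A minimal usage sketch for Database, assuming a reachable MySQL instance and
# an installed driver; the URI, database name, and query below are illustrative.
def _example_database_usage() -> None:
    db = Database.from_uri("mysql+pymysql://root:password@127.0.0.1:3306/dbgpt")
    session = db.get_session("dbgpt")
    # Read-only statements are routed to the internal query path and the rows
    # are returned with the field names inserted as the first element.
    rows = db.run(session, "SELECT user, host FROM mysql.user")
    print(rows)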

View File

@@ -1,81 +0,0 @@
import re
def is_all_chinese(text):
### Determine whether the string is pure Chinese
pattern = re.compile(r"^[一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_number_chinese(text):
### Determine whether the string is numbers and Chinese
pattern = re.compile(r"^[\d一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_chinese_include_number(text):
### Determine whether the string is pure Chinese or Chinese containing numbers
pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
match = re.match(pattern, text)
return match is not None
def is_scientific_notation(string):
# Regular expression for scientific notation
pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
# Match the string against the pattern and return whether it succeeded
match = re.match(pattern, str(string))
return match is not None
def extract_content(long_string, s1, s2, is_include: bool = False):
# extract text
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
extracted_content = long_string[start_index : end_index + len(s2)]
else:
end_index = long_string.find(s2, start_index + len(s1))
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
# extract text open ending
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if long_string.find(s2, start_index) <= 0:
end_index = len(long_string)
else:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
else:
end_index = long_string.find(s2, start_index + len(s1))
if is_include:
extracted_content = long_string[start_index : end_index + len(s2)]
else:
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
if __name__ == "__main__":
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
s1 = "123"
s2 = "456"
print(extract_content_open_ending(s, s1, s2, True))

View File

@@ -1,221 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
import asyncio
# Checking for type hints during runtime
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
class LifeCycle:
"""This class defines hooks for lifecycle events of a component."""
def before_start(self):
"""Called before the component starts."""
pass
async def async_before_start(self):
"""Asynchronous version of before_start."""
pass
def after_start(self):
"""Called after the component has started."""
pass
async def async_after_start(self):
"""Asynchronous version of after_start."""
pass
def before_stop(self):
"""Called before the component stops."""
pass
async def async_before_stop(self):
"""Asynchronous version of before_stop."""
pass
class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
MODEL_REGISTRY = "dbgpt_model_registry"
MODEL_API_SERVER = "dbgpt_model_api_server"
MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_component"
def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None:
self.init_app(system_app)
@abstractmethod
def init_app(self, system_app: SystemApp):
"""Initialize the component with the main application.
This method needs to be implemented by every component to define how it integrates
with the main system app.
"""
T = TypeVar("T", bound=BaseComponent)
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.components: Dict[
str, BaseComponent
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
@property
def app(self) -> Optional["FastAPI"]:
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
"""Register a new component by its type.
Args:
component (Type[BaseComponent]): The component class to register
Returns:
T: The instance of registered component
"""
instance = component(self, *args, **kwargs)
self.register_instance(instance)
return instance
def register_instance(self, instance: T) -> T:
"""Register an already initialized component.
Args:
instance (T): The component instance to register
Returns:
T: The instance of registered component
"""
name = instance.name
if isinstance(name, ComponentType):
name = name.value
if name in self.components:
raise RuntimeError(
f"Componse name {name} already exists: {self.components[name]}"
)
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
return instance
def get_component(
self,
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Type[BaseComponent] = None,
*args,
**kwargs,
) -> T:
"""Retrieve a registered component by its name and type.
Args:
name (Union[str, ComponentType]): Component name
component_type (Type[T]): The expected type of the retrieved component
default_component : The default component instance to return if no component is registered under the name
or_register_component (Type[BaseComponent]): A component class to register and return if no component is registered under the name
Returns:
T: The instance retrieved by component name
"""
if isinstance(name, ComponentType):
name = name.value
component = self.components.get(name)
if not component:
if or_register_component:
return self.register(or_register_component, *args, **kwargs)
if default_component != _EMPTY_DEFAULT_COMPONENT:
return default_component
raise ValueError(f"No component found with name {name}")
if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {component_type}")
return component
def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.components.items():
v.before_start()
async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def after_start(self):
"""Invoke the after_start hooks for all registered components."""
for _, v in self.components.items():
v.after_start()
async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def before_stop(self):
"""Invoke the before_stop hooks for all registered components."""
for _, v in self.components.items():
try:
    v.before_stop()
except Exception as e:
    # Log but keep stopping the remaining components.
    logger.warning(f"Error stopping component {v.name}: {e}")
async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def _build(self):
"""Integrate lifecycle events with the internal ASGI app if available."""
if not self.app:
return
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
async def _startup_func():
try:
await self.async_after_start()
except Exception as e:
logger.error(f"Error starting system app: {e}")
sys.exit(1)
asyncio.create_task(_startup_func())
self.after_start()
@self.app.on_event("shutdown")
async def shutdown_event():
"""ASGI app shutdown event handler."""
await self.async_before_stop()
self.before_stop()
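# A minimal usage sketch for SystemApp and BaseComponent; `_EchoComponent` is
# an illustrative component, not part of the original module.
class _EchoComponent(BaseComponent):
    name = "echo_component"

    def init_app(self, system_app: SystemApp):
        # This toy component does not need any wiring.
        pass

    def after_start(self):
        logger.info("echo_component started")


def _example_system_app() -> None:
    system_app = SystemApp()
    system_app.register(_EchoComponent)
    # Retrieve the registered instance by name and expected type.
    echo = system_app.get_component("echo_component", _EchoComponent)
    assert isinstance(echo, _EchoComponent)
    system_app.before_start()
    system_app.after_start()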

View File

@@ -1,16 +0,0 @@
import os
import random
import sys
from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42")
random.seed(42)
# Load the users .env file into environment variables
load_dotenv(verbose=True, override=True)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
load_dotenv(os.path.join(ROOT_PATH, ".plugin_env"))
del load_dotenv

View File

@@ -1,284 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
from typing import List, Optional, TYPE_CHECKING
from pilot.singleton import Singleton
if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.component import SystemApp
class Config(metaclass=Singleton):
"""Configuration class to store the state of bools for different scripts access"""
def __init__(self) -> None:
"""Initialize the Config class"""
self.NEW_SERVER_MODE = False
self.SERVER_LIGHT_MODE = False
# Gradio language version: en, zh
self.LANGUAGE = os.getenv("LANGUAGE", "en")
self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
self.debug_mode = False
self.skip_reprompt = False
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
# self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
)
# User agent header to use when making HTTP requests
# Some websites might just completely deny request with an error code if
# no user agent was found.
self.user_agent = os.getenv(
"USER_AGENT",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
)
# This is a proxy server, just for test_py. we will remove this later.
self.proxy_api_key = os.getenv("PROXY_API_KEY")
self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")
# In order to be compatible with the new and old model parameter design
if self.bard_proxy_api_key:
os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key
# tongyi
self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY")
if self.tongyi_proxy_api_key:
os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key
# zhipu
self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
if self.zhipu_proxy_api_key:
os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
"ZHIPU_MODEL_VERSION"
)
# wenxin
self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_API_SECRET")
self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
os.environ[
"wenxin_proxyllm_proxy_api_secret"
] = self.wenxin_proxy_api_secret
os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
# xunfei spark
self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
if self.spark_proxy_api_key and self.spark_proxy_api_secret:
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid
# baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME")
if self.bc_proxy_api_key and self.bc_proxy_api_secret:
os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")
self.use_mac_os_tts = False
self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
self.exit_key = os.getenv("EXIT_KEY", "n")
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
self.image_provider = os.getenv("IMAGE_PROVIDER")
self.image_size = int(os.getenv("IMAGE_SIZE", 256))
self.huggingface_image_model = os.getenv(
"HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
)
self.huggingface_audio_to_text_model = os.getenv(
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
)
self.speak_mode = False
from pilot.prompts.prompt_registry import PromptTemplateRegistry
self.prompt_template_registry = PromptTemplateRegistry()
### Related configuration of built-in commands
self.command_registry = []
### Related configuration of display commands
self.command_disply = []
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
if disabled_command_categories:
self.disabled_command_categories = disabled_command_categories.split(",")
else:
self.disabled_command_categories = []
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
)
### Message history storage directory
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
### Plug-in configuration parameters that control plug-in loading and use
self.plugins: List["AutoGPTPluginTemplate"] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
if plugins_allowlist:
self.plugins_allowlist = plugins_allowlist.split(",")
else:
self.plugins_allowlist = []
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
if plugins_denylist:
self.plugins_denylist = plugins_denylist.split(",")
else:
self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
)
self.NATIVE_SQL_CAN_RUN_WRITE = (
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)
self.LOCAL_DB_MANAGE = None
### dbgpt meta info database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
self.LOCAL_DB_HOST = "127.0.0.1"
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME", "dbgpt")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
self.PROXYLLM_BACKEND = None
if self.LLM_MODEL == "proxyllm":
self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
self.MODEL_SERVER = os.getenv(
"MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
)
### Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
self.IS_LOAD_8BIT = False
# In order to be compatible with the new and old model parameter design
os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
# default recall similarity score, between 0 and 1
self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
)
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
)
# Whether to enable Chat Knowledge Search Rewrite Mode
self.KNOWLEDGE_SEARCH_REWRITE = (
os.getenv("KNOWLEDGE_SEARCH_REWRITE", "False").lower() == "true"
)
# Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
)
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
self.SYSTEM_APP: Optional["SystemApp"] = None
### Temporary configuration
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
self.MODEL_CACHE_ENABLE: bool = (
os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
)
self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
"MODEL_CACHE_STORAGE_TYPE", "disk"
)
self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
)
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
def set_templature(self, value: int) -> None:
"""Set the temperature value."""
self.temperature = value
def set_speak_mode(self, value: bool) -> None:
"""Set the speak mode value."""
self.speak_mode = value
def set_last_plugin_return(self, value: bool) -> None:
"""Set the speak mode value."""
self.last_plugin_return = value
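# A minimal usage sketch for Config, assuming the pilot package and its
# environment variables are available; because of the Singleton metaclass,
# every call site receives the same instance.
def _example_config() -> None:
    cfg = Config()
    print(cfg.LLM_MODEL, cfg.MODEL_SERVER)
    cfg.set_debug_mode(True)
    assert Config().debug_mode is True  # the same singleton instance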

View File

@@ -1,176 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
from functools import cache
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
current_directory = os.getcwd()
new_directory = PILOT_PATH
os.chdir(new_directory)
@cache
def get_device() -> str:
try:
import torch
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
except ModuleNotFoundError:
return "cpu"
LLM_MODEL_CONFIG = {
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
"vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
# https://huggingface.co/THUDM/chatglm3-6b
"chatglm3-6b": os.path.join(MODEL_PATH, "chatglm3-6b"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "chatgpt_proxyllm",
"chatgpt_proxyllm": "chatgpt_proxyllm",
"bard_proxyllm": "bard_proxyllm",
"claude_proxyllm": "claude_proxyllm",
"wenxin_proxyllm": "wenxin_proxyllm",
"tongyi_proxyllm": "tongyi_proxyllm",
"zhipu_proxyllm": "zhipu_proxyllm",
"bc_proxyllm": "bc_proxyllm",
"spark_proxyllm": "spark_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
# https://huggingface.co/Qwen/Qwen-7B-Chat
"qwen-7b-chat": os.path.join(MODEL_PATH, "Qwen-7B-Chat"),
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int8
"qwen-7b-chat-int8": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int4
"qwen-7b-chat-int4": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int4"),
# https://huggingface.co/Qwen/Qwen-14B-Chat
"qwen-14b-chat": os.path.join(MODEL_PATH, "Qwen-14B-Chat"),
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int8
"qwen-14b-chat-int8": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int4
"qwen-14b-chat-int4": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int4"),
# https://huggingface.co/Qwen/Qwen-72B-Chat
"qwen-72b-chat": os.path.join(MODEL_PATH, "Qwen-72B-Chat"),
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int8
"qwen-72b-chat-int8": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int4
"qwen-72b-chat-int4": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int4"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat
"qwen-1.8b-chat": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int8
"qwen-1.8b-chat-int8": os.path.join(MODEL_PATH, "wen-1_8B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int4
"qwen-1.8b-chat-int4": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat-Int4"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
# wget https://huggingface.co/TheBloke/vicuna-13B-v1.5-GGUF/resolve/main/vicuna-13b-v1.5.Q4_K_M.gguf -O models/ggml-model-q4_0.gguf
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.gguf"),
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
"codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
# https://huggingface.co/microsoft/Orca-2-7b
"orca-2-7b": os.path.join(MODEL_PATH, "Orca-2-7b"),
# https://huggingface.co/microsoft/Orca-2-13b
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
# https://huggingface.co/openchat/openchat_3.5
"openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
# https://huggingface.co/hfl/chinese-alpaca-2-7b
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
# https://huggingface.co/hfl/chinese-alpaca-2-13b
"chinese-alpaca-2-13b": os.path.join(MODEL_PATH, "chinese-alpaca-2-13b"),
# https://huggingface.co/THUDM/codegeex2-6b
"codegeex2-6b": os.path.join(MODEL_PATH, "codegeex2-6b"),
# https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
"zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
# https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
"mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
# https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
"xwin-lm-7b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-7B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-13B-V0.1
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
# https://huggingface.co/01-ai/Yi-34B-Chat-8bits
"yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
# https://huggingface.co/01-ai/Yi-34B-Chat-4bits
"yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}
EMBEDDING_MODEL_CONFIG = {
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
# https://huggingface.co/moka-ai/m3e-base
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
# https://huggingface.co/moka-ai/m3e-large
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
# https://huggingface.co/BAAI/bge-large-en
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
}
# Load model config
ISDEBUG = False
VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "data"
)
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
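
The mappings above are plain name-to-path lookups; a caller resolves either a local model directory or a proxy identifier by key. A minimal sketch of such a lookup, assuming this module is importable as pilot.configs.model_config (the module path is an assumption, not shown in this diff):

import os

from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG


def resolve_embedding_model(name: str) -> str:
    """Return the local path or proxy identifier configured for an embedding model."""
    try:
        target = EMBEDDING_MODEL_CONFIG[name]
    except KeyError:
        raise ValueError(f"Unknown embedding model: {name}")
    # Proxy entries ("proxy_openai", "proxy_azure") are identifiers, not local paths.
    if not target.startswith("proxy_") and not os.path.isdir(target):
        print(f"Warning: model directory {target} does not exist yet")
    return target


print(resolve_embedding_model("text2vec"))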

View File

@@ -1 +0,0 @@
from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao

View File

@@ -1,59 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Base class for database connectors. Other connectors should implement this interface."""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional


class BaseConnect(ABC):
    def get_connect(self, db_name: str):
        pass

    def get_table_names(self) -> Iterable[str]:
        pass

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

    def get_example_data(self, table: str, count: int = 3):
        pass

    def get_database_list(self):
        pass

    def get_database_names(self):
        pass

    def get_table_comments(self, db_name):
        pass

    def run(self, session, command: str, fetch: str = "all") -> List:
        pass

    def run_to_df(self, command: str, fetch: str = "all"):
        pass

    def get_users(self):
        pass

    def get_grants(self):
        pass

    def get_collation(self):
        pass

    def get_charset(self):
        pass

    def get_fields(self, table_name):
        pass

    def get_show_create_table(self, table_name):
        pass

    def get_indexes(self, table_name):
        pass
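
Concrete connectors subclass BaseConnect and are discovered by subclass scanning (see ConnectManager later in this diff), keyed by their db_type attribute. A hypothetical, minimal connector sketch; the class name CsvFileConnect and the db_type "csvfile" are illustrative and not part of the project:

from typing import Iterable

from pilot.connections.base import BaseConnect


class CsvFileConnect(BaseConnect):
    db_type: str = "csvfile"
    driver: str = "csvfile"
    dialect: str = "csv"

    def __init__(self, file_path: str):
        self.file_path = file_path

    @classmethod
    def from_file_path(cls, file_path: str, **kwargs):
        # File-based connectors are constructed from a path (see ConnectManager.get_connect).
        return cls(file_path)

    def get_table_names(self) -> Iterable[str]:
        return ["csv_table"]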

View File

@@ -1,119 +0,0 @@
from typing import Optional, Any

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
    """
    Spark Connect supports operating on a variety of data sources through the DataFrame interface.
    A DataFrame can be operated on using relational transformations and can also be used to create a temporary view.
    Registering a DataFrame as a temporary view allows you to run SQL queries over its data.
    Datasources currently supported: parquet, jdbc, orc, libsvm, csv, text and json.
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional["SparkSession"] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Spark DataFrame from the datasource path.

        return: Spark DataFrame
        """
        from pyspark.sql import SparkSession

        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            print("load spark datasource error: " + str(e))

    def create_df(self, path):
        """Create a Spark DataFrame from the datasource path (parquet, jdbc, orc, libsvm, csv, text and json are supported).

        return: Spark DataFrame
        reference: https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
        """
        extension = (
            "text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1]
        )
        return self.spark_session.read.load(
            path, format=extension, inferSchema="true", header="true"
        )

    def run(self, sql):
        print(f"spark sql to run is {sql}")
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        first_row = df.first()
        rows = [first_row.asDict().keys()]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        return "ans"

    def get_fields(self):
        """Get column meta about dataframe."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""

View File

@@ -1,17 +0,0 @@
from pydantic import BaseModel, Field


class DBConfig(BaseModel):
    db_type: str
    db_name: str
    file_path: str = ""
    db_host: str = ""
    db_port: int = 0
    db_user: str = ""
    db_pwd: str = ""
    comment: str = ""


class DbTypeInfo(BaseModel):
    db_type: str
    is_file_db: bool = False
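
These models are the payloads consumed elsewhere in this diff (ConnectManager.add_db, test_connect and edit_db). Two illustrative instances, one file-based and one URL-based; all values are placeholders:

from pilot.connections.db_conn_info import DBConfig

file_db = DBConfig(db_type="duckdb", db_name="my_duck", file_path="/data/my_duck.db")
url_db = DBConfig(
    db_type="mysql",
    db_name="orders",
    db_host="127.0.0.1",
    db_port=3306,
    db_user="root",
    db_pwd="secret",
    comment="demo connection",
)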

View File

@@ -1,244 +0,0 @@
from sqlalchemy import Column, Integer, String, Index, Text, text
from sqlalchemy import UniqueConstraint

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import (
    Base,
    engine,
    session,
    META_DATA_DATABASE,
)


class ConnectConfigEntity(Base):
    """db connect config entity"""

    __tablename__ = "connect_config"
    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )
    db_type = Column(String(255), nullable=False, comment="db type")
    db_name = Column(String(255), nullable=False, comment="db name")
    db_path = Column(String(255), nullable=True, comment="file db path")
    db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
    db_port = Column(String(255), nullable=True, comment="db connect port(not file db)")
    db_user = Column(String(255), nullable=True, comment="db user")
    db_pwd = Column(String(255), nullable=True, comment="db password")
    comment = Column(Text, nullable=True, comment="db comment")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")

    __table_args__ = (
        UniqueConstraint("db_name", name="uk_db"),
        Index("idx_q_db_type", "db_type"),
        {"mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
    )


class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
    """db connect config dao"""

    def __init__(self):
        super().__init__(
            database=META_DATA_DATABASE,
            orm_base=Base,
            db_engine=engine,
            session=session,
        )

    def update(self, entity: ConnectConfigEntity):
        """Update db connect info."""
        session = self.get_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def delete(self, db_name: str):
        """Delete db connect info."""
        session = self.get_session()
        if db_name is None:
            raise Exception("db_name is None")
        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
        db_connect.delete()
        session.commit()
        session.close()

    def get_by_names(self, db_name: str) -> ConnectConfigEntity:
        """Get db connect info by name."""
        session = self.get_session()
        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
        result = db_connect.first()
        session.close()
        return result

    def add_url_db(
        self,
        db_name,
        db_type,
        db_host: str,
        db_port: int,
        db_user: str,
        db_pwd: str,
        comment: str = "",
    ):
        """
        Add db connect info.

        Args:
            db_name: db name
            db_type: db type
            db_host: db host
            db_port: db port
            db_user: db user
            db_pwd: db password
            comment: comment
        """
        try:
            session = self.get_session()
            from sqlalchemy import text

            insert_statement = text(
                """
                INSERT INTO connect_config (
                    db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
                ) VALUES (
                    :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
                )
                """
            )
            params = {
                "db_name": db_name,
                "db_type": db_type,
                "db_path": "",
                "db_host": db_host,
                "db_port": db_port,
                "db_user": db_user,
                "db_pwd": db_pwd,
                "comment": comment,
            }
            session.execute(insert_statement, params)
            session.commit()
            session.close()
        except Exception as e:
            print("add db connect info error: " + str(e))

    def update_db_info(
        self,
        db_name,
        db_type,
        db_path: str = "",
        db_host: str = "",
        db_port: int = 0,
        db_user: str = "",
        db_pwd: str = "",
        comment: str = "",
    ):
        """Update db connect info."""
        old_db_conf = self.get_db_config(db_name)
        if old_db_conf:
            try:
                session = self.get_session()
                if not db_path:
                    update_statement = text(
                        f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
                    )
                else:
                    update_statement = text(
                        f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
                    )
                session.execute(update_statement)
                session.commit()
                session.close()
            except Exception as e:
                print("edit db connect info error: " + str(e))
            return True
        raise ValueError(f"{db_name} does not have config info!")

    def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
        """Add file db connect info."""
        try:
            session = self.get_session()
            insert_statement = text(
                """
                INSERT INTO connect_config(
                    db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
                ) VALUES (
                    :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
                )
                """
            )
            params = {
                "db_name": db_name,
                "db_type": db_type,
                "db_path": db_path,
                "db_host": "",
                "db_port": 0,
                "db_user": "",
                "db_pwd": "",
                "comment": comment,
            }
            session.execute(insert_statement, params)
            session.commit()
            session.close()
        except Exception as e:
            print("add db connect info error: " + str(e))

    def get_db_config(self, db_name):
        """Get db config by name."""
        session = self.get_session()
        if db_name:
            select_statement = text(
                """
                SELECT
                    *
                FROM
                    connect_config
                WHERE
                    db_name = :db_name
                """
            )
            params = {"db_name": db_name}
            result = session.execute(select_statement, params)
        else:
            raise ValueError("Cannot get database by name: " + str(db_name))
        fields = [field[0] for field in result.cursor.description]
        row_dict = {}
        row_1 = list(result.cursor.fetchall()[0])
        for i, field in enumerate(fields):
            row_dict[field] = row_1[i]
        return row_dict

    def get_db_list(self):
        """Get db list."""
        session = self.get_session()
        result = session.execute(text("SELECT * FROM connect_config"))
        fields = [field[0] for field in result.cursor.description]
        data = []
        for row in result.cursor.fetchall():
            row_dict = {}
            for i, field in enumerate(fields):
                row_dict[field] = row[i]
            data.append(row_dict)
        return data

    def delete_db(self, db_name):
        """Delete db connect info."""
        session = self.get_session()
        delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
        params = {"db_name": db_name}
        session.execute(delete_statement, params)
        session.commit()
        session.close()
        return True
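
Note that add_url_db, add_file_db, get_db_config and delete_db bind their values as SQL parameters, while update_db_info interpolates values directly into the statement string. A minimal sketch of a parameterized alternative for the URL case; the helper name update_url_db is hypothetical and not part of the DAO:

from sqlalchemy import text


def update_url_db(session, db_name, db_type, db_host, db_port, db_user, db_pwd, comment=""):
    # Values are bound as parameters instead of being interpolated into the SQL string.
    update_statement = text(
        "UPDATE connect_config SET db_type=:db_type, db_host=:db_host, db_port=:db_port, "
        "db_user=:db_user, db_pwd=:db_pwd, comment=:comment WHERE db_name=:db_name"
    )
    session.execute(
        update_statement,
        {
            "db_type": db_type,
            "db_host": db_host,
            "db_port": db_port,
            "db_user": db_user,
            "db_pwd": db_pwd,
            "comment": comment,
            "db_name": db_name,
        },
    )
    session.commit()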

View File

@@ -1,147 +0,0 @@
import os

import duckdb

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
table_name = "connect_config"


class DuckdbConnectConfig:
    def __init__(self):
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_config_tables()

    def __init_config_tables(self):
        # Check whether the config table already exists
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()
        if not result:
            # Create the config table
            self.connect.execute(
                "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def add_url_db(
        self,
        db_name,
        db_type,
        db_host: str,
        db_port: int,
        db_user: str,
        db_pwd: str,
        comment: str = "",
    ):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add url db connect info error: " + str(e))

    def update_db_info(
        self,
        db_name,
        db_type,
        db_path: str = "",
        db_host: str = "",
        db_port: int = 0,
        db_user: str = "",
        db_pwd: str = "",
        comment: str = "",
    ):
        old_db_conf = self.get_db_config(db_name)
        if old_db_conf:
            try:
                cursor = self.connect.cursor()
                if not db_path:
                    cursor.execute(
                        f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
                    )
                else:
                    cursor.execute(
                        f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
                    )
                cursor.commit()
                self.connect.commit()
            except Exception as e:
                print("edit db connect info error: " + str(e))
            return True
        raise ValueError(f"{db_name} does not have config info!")

    def get_file_db_name(self, path):
        try:
            conn = duckdb.connect(path)
            result = conn.execute("SELECT current_database()").fetchone()[0]
            return result
        except Exception as e:
            raise ValueError("Unusable duckdb database path: " + path) from e

    def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, db_path, "", 0, "", "", comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add file db connect info error: " + str(e))

    def delete_db(self, db_name):
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
        cursor.commit()
        return True

    def get_db_config(self, db_name):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if db_name:
                cursor.execute(
                    "SELECT * FROM connect_config where db_name=? ", [db_name]
                )
            else:
                raise ValueError("Cannot get database by name: " + str(db_name))
            fields = [field[0] for field in cursor.description]
            row_dict = {}
            row_1 = list(cursor.fetchall()[0])
            for i, field in enumerate(fields):
                row_dict[field] = row_1[i]
            return row_dict
        return None

    def get_db_list(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute("SELECT * FROM connect_config ")
            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)
            return data
        return []

    def get_db_names(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute("SELECT db_name FROM connect_config ")
            data = []
            for row in cursor.fetchall():
                data.append(row[0])
            return data
        return []
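
This DuckDB-backed store is the legacy storage that ConnectManager (below) has since commented out in favor of ConnectConfigDao. A short usage sketch; the database name and file path are placeholders:

from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig

config_store = DuckdbConnectConfig()
config_store.add_file_db("my_duck", "duckdb", "/data/my_duck.db", comment="demo file db")
print(config_store.get_db_names())
print(config_store.get_db_config("my_duck"))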

View File

@@ -1,160 +0,0 @@
import threading
import asyncio

from pilot.configs.config import Config
from pilot.connections import ConnectConfigDao
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase
from pilot.connections.rdbms.conn_starrocks import StarRocksConnect
from pilot.connections.rdbms.conn_doris import DorisConnect
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig
from pilot.connections.conn_spark import SparkConnect
from pilot.summary.db_summary_client import DBSummaryClient

CFG = Config()


class ConnectManager:
    """db connect manager"""

    def get_all_subclasses(self, cls):
        subclasses = cls.__subclasses__()
        for subclass in subclasses:
            subclasses += self.get_all_subclasses(subclass)
        return subclasses

    def get_all_completed_types(self):
        chat_classes = self.get_all_subclasses(BaseConnect)
        support_types = []
        for cls in chat_classes:
            if cls.db_type:
                support_types.append(DBType.of_db_type(cls.db_type))
        return support_types

    def get_cls_by_dbtype(self, db_type):
        chat_classes = self.get_all_subclasses(BaseConnect)
        result = None
        for cls in chat_classes:
            if cls.db_type == db_type:
                result = cls
        if not result:
            raise ValueError("Unsupported Db Type: " + db_type)
        return result

    def __init__(self, system_app: SystemApp):
        """metadata database management initialization"""
        # self.storage = DuckdbConnectConfig()
        self.storage = ConnectConfigDao()
        self.db_summary_client = DBSummaryClient(system_app)

    def get_connect(self, db_name):
        db_config = self.storage.get_db_config(db_name)
        db_type = DBType.of_db_type(db_config.get("db_type"))
        connect_instance = self.get_cls_by_dbtype(db_type.value())
        if db_type.is_file_db():
            db_path = db_config.get("db_path")
            return connect_instance.from_file_path(db_path)
        else:
            db_host = db_config.get("db_host")
            db_port = db_config.get("db_port")
            db_user = db_config.get("db_user")
            db_pwd = db_config.get("db_pwd")
            return connect_instance.from_uri_db(
                host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name
            )

    def test_connect(self, db_info: DBConfig):
        try:
            db_type = DBType.of_db_type(db_info.db_type)
            connect_instance = self.get_cls_by_dbtype(db_type.value())
            if db_type.is_file_db():
                db_path = db_info.file_path
                return connect_instance.from_file_path(db_path)
            else:
                db_name = db_info.db_name
                db_host = db_info.db_host
                db_port = db_info.db_port
                db_user = db_info.db_user
                db_pwd = db_info.db_pwd
                return connect_instance.from_uri_db(
                    host=db_host,
                    port=db_port,
                    user=db_user,
                    pwd=db_pwd,
                    db_name=db_name,
                )
        except Exception as e:
            print(f"{db_info.db_name} test connection failure! {str(e)}")
            raise ValueError(f"{db_info.db_name} test connection failure! {str(e)}")

    def get_db_list(self):
        return self.storage.get_db_list()

    def get_db_names(self):
        # The DAO storage exposes get_db_list(); extract the configured names from it.
        return [conf.get("db_name") for conf in self.storage.get_db_list()]

    def delete_db(self, db_name: str):
        return self.storage.delete_db(db_name)

    def edit_db(self, db_info: DBConfig):
        return self.storage.update_db_info(
            db_info.db_name,
            db_info.db_type,
            db_info.file_path,
            db_info.db_host,
            db_info.db_port,
            db_info.db_user,
            db_info.db_pwd,
            db_info.comment,
        )

    async def async_db_summary_embedding(self, db_name, db_type):
        # Run the db summary embedding; this is the part that needs to execute asynchronously.
        self.db_summary_client.db_summary_embedding(db_name, db_type)

    def add_db(self, db_info: DBConfig):
        print(f"add_db: {db_info.__dict__}")
        try:
            db_type = DBType.of_db_type(db_info.db_type)
            if db_type.is_file_db():
                self.storage.add_file_db(
                    db_info.db_name, db_info.db_type, db_info.file_path
                )
            else:
                self.storage.add_url_db(
                    db_info.db_name,
                    db_info.db_type,
                    db_info.db_host,
                    db_info.db_port,
                    db_info.db_user,
                    db_info.db_pwd,
                    db_info.comment,
                )
            # Run the schema summary embedding asynchronously in the default executor
            executor = CFG.SYSTEM_APP.get_component(
                ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
            ).create()
            executor.submit(
                self.db_summary_client.db_summary_embedding,
                db_info.db_name,
                db_info.db_type,
            )
        except Exception as e:
            raise ValueError("Add db connect info error! " + str(e))
        return True
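
End to end, the manager resolves the connector class by db_type, persists the connection metadata, and schedules the schema summary embedding. A usage sketch under stated assumptions: the ConnectManager module path is assumed, SystemApp is assumed to be default-constructible, and all connection values are placeholders:

from pilot.component import SystemApp
from pilot.connections.db_conn_info import DBConfig
from pilot.connections.manages.connection_manager import ConnectManager  # module path assumed

system_app = SystemApp()  # assumes a default-constructible SystemApp
manager = ConnectManager(system_app)

db_info = DBConfig(
    db_type="sqlite",
    db_name="demo",
    file_path="/data/demo.db",
)
manager.test_connect(db_info)  # raises ValueError if the datasource is unreachable
manager.add_db(db_info)        # persists the config and schedules summary embedding
conn = manager.get_connect("demo")
print(manager.get_db_list())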

Some files were not shown because too many files have changed in this diff.