Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-19 00:14:40 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
@@ -1,12 +0,0 @@
# Old packages
# __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]

__all__ = ["embedding_engine"]


def __getattr__(name: str):
    import importlib

    if name in ["embedding_engine"]:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
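The `__getattr__` hook above follows PEP 562: the `embedding_engine` submodule is only imported when an attribute of that name is first requested. A minimal usage sketch; the absolute package path `pilot.embedding_engine` is inferred from the surrounding refactor:

.. code-block:: python

    # Nothing heavy is imported at package-import time.
    import pilot.embedding_engine as ee

    # Accessing the attribute triggers the module-level __getattr__, which
    # lazily imports and returns the "embedding_engine" submodule.
    engine_mod = ee.embedding_engine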
@@ -1,87 +0,0 @@
"""Agentic Workflow Expression Language (AWEL)

Note:

    AWEL is still an experimental feature; only the lowest-level API is exposed for now.
    The stability of this API cannot be guaranteed at present.

"""

from pilot.component import SystemApp

from .dag.base import DAGContext, DAG

from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
    JoinOperator,
    ReduceStreamOperator,
    MapOperator,
    BranchOperator,
    InputOperator,
    BranchFunc,
)

from .operator.stream_operator import (
    StreamifyAbsOperator,
    UnstreamifyAbsOperator,
    TransformStreamAbsOperator,
)

from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
    SimpleInputSource,
    SimpleCallDataInputSource,
    DefaultTaskContext,
    DefaultInputContext,
    SimpleTaskOutput,
    SimpleStreamTaskOutput,
    _is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
    "initialize_awel",
    "DAGContext",
    "DAG",
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
    "BranchFunc",
    "WorkflowRunner",
    "TaskState",
    "TaskOutput",
    "TaskContext",
    "InputContext",
    "InputSource",
    "DefaultWorkflowRunner",
    "SimpleInputSource",
    "SimpleCallDataInputSource",
    "DefaultTaskContext",
    "DefaultInputContext",
    "SimpleTaskOutput",
    "SimpleStreamTaskOutput",
    "StreamifyAbsOperator",
    "UnstreamifyAbsOperator",
    "TransformStreamAbsOperator",
    "HttpTrigger",
]


def initialize_awel(system_app: SystemApp, dag_filepath: str):
    from .dag.dag_manager import DAGManager
    from .dag.base import DAGVar
    from .trigger.trigger_manager import DefaultTriggerManager
    from .operator.base import initialize_runner

    DAGVar.set_current_system_app(system_app)

    system_app.register(DefaultTriggerManager)
    dag_manager = DAGManager(system_app, dag_filepath)
    system_app.register_instance(dag_manager)
    initialize_runner(DefaultWorkflowRunner())
    # Load all dags
    dag_manager.load_dags()
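To make the exported surface above concrete, here is a minimal sketch of how the pieces are intended to fit together: build operators inside a `DAG` context, wire them with `>>`, then execute from the end operator with `call()`. It assumes `initialize_awel()` has already run during application startup, and the `{"data": ...}` payload shape consumed by `SimpleCallDataInputSource` is an assumption of this sketch, not something this commit shows:

.. code-block:: python

    import asyncio

    from pilot.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

    # Assumes initialize_awel(system_app, dag_filepath) was already called, so
    # DAGVar holds a SystemApp and a default workflow runner is registered.


    class DoubleOperator(MapOperator[int, int]):
        async def map(self, input_value: int) -> int:
            return input_value * 2


    with DAG("example_dag"):
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        double_task = DoubleOperator()
        input_task >> double_task

    # Execute the workflow that ends at double_task; the call_data payload
    # shape is an assumption.
    result = asyncio.run(double_task.call(call_data={"data": 3}))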
@@ -1,7 +0,0 @@
from abc import ABC, abstractmethod


class Trigger(ABC):
    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
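A concrete trigger only has to implement this single coroutine. The real implementation in this commit is the HTTP variant (`trigger/http_trigger.py`); the manual trigger below is a hypothetical sketch that simply runs a wrapped operator once, using the `call()` method defined on `BaseOperator` elsewhere in this commit:

.. code-block:: python

    class ManualTrigger(Trigger):
        """A hypothetical trigger that runs a given end operator once."""

        def __init__(self, end_operator):
            self._end_operator = end_operator

        async def trigger(self) -> None:
            # Execute the workflow that ends at the wrapped operator.
            await self._end_operator.call()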
@@ -1,364 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor

from pilot.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext

logger = logging.getLogger(__name__)

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


def _is_async_context():
    try:
        loop = asyncio.get_running_loop()
        return asyncio.current_task(loop=loop) is not None
    except RuntimeError:
        return False


class DependencyMixin(ABC):
    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more upstream nodes for this node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more downstream nodes for this node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self << nodes

        Example:

        .. code-block:: python

            # means node.set_upstream(input_node)
            node << input_node

            # means node2.set_upstream([input_node])
            node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self >> nodes

        Example:

        .. code-block:: python

            # means node.set_downstream(next_node)
            node >> next_node

            # means node2.set_downstream([next_node])
            node2 >> [next_node]

        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] >> self"""
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] << self"""
        self.__rshift__(nodes)
        return self


class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
    _executor: Executor = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            if not hasattr(cls._thread_local, "current_dag_stack"):
                cls._thread_local.current_dag_stack = deque()
            cls._thread_local.current_dag_stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            if stack:
                stack.pop()
                cls._async_local.set(stack)
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                cls._thread_local.current_dag_stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            return stack[-1] if stack else None
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                return cls._thread_local.current_dag_stack[-1]
            return None

    @classmethod
    def get_current_system_app(cls) -> SystemApp:
        if not cls._system_app:
            raise RuntimeError("System APP not set for DAGVar")
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        if cls._system_app:
            logger.warning("System APP has already been set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Executor:
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        cls._executor = executor


class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
    """The resource group of current DAGNode"""

    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
        self._node_name: str = node_name

    @property
    def node_id(self) -> str:
        return self._node_id

    @property
    def system_app(self) -> SystemApp:
        return self._system_app

    def set_node_id(self, node_id: str) -> None:
        self._node_id = node_id

    def __hash__(self) -> int:
        if self.node_id:
            return hash(self.node_id)
        else:
            return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DAGNode):
            return False
        return self.node_id == other.node_id

    @property
    def node_name(self) -> str:
        return self._node_name

    @property
    def dag(self) -> "DAG":
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)
        return self

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)
        return self

    @property
    def upstream(self) -> List["DAGNode"]:
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance of 'DAGNode'"
            )
        nodes: Sequence[DAGNode] = nodes
        dags = set([node.dag for node in nodes if node.dag])
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("set dependency to current node must in a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "set dependency to current node just support in one DAG context"
            )
        dag = dags.pop()
        self._dag = dag

        dag._append_node(self)
        for node in nodes:
            if is_upstream and node not in self.upstream:
                node._dag = dag
                dag._append_node(node)

                self._upstream.append(node)
                node._downstream.append(self)
            elif node not in self._downstream:
                node._dag = dag
                dag._append_node(node)

                self._downstream.append(node)
                node._upstream.append(self)


class DAGContext:
    def __init__(self) -> None:
        self._curr_task_ctx = None
        self._share_data: Dict[str, Any] = {}

    @property
    def current_task_context(self) -> TaskContext:
        return self._curr_task_ctx

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        self._curr_task_ctx = _curr_task_ctx

    async def get_share_data(self, key: str) -> Any:
        return self._share_data.get(key)

    async def save_to_share_data(self, key: str, data: Any) -> None:
        self._share_data[key] = data


class DAG:
    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self._dag_id = dag_id
        self.node_map: Dict[str, DAGNode] = {}
        self._root_nodes: Set[DAGNode] = None
        self._leaf_nodes: Set[DAGNode] = None
        self._trigger_nodes: Set[DAGNode] = None

    def _append_node(self, node: DAGNode) -> None:
        self.node_map[node.node_id] = node
        # Clear the cached node sets so they are rebuilt on next access
        self._root_nodes = None
        self._leaf_nodes = None

    def _new_node_id(self) -> str:
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        return self._dag_id

    def _build(self) -> None:
        from ..operator.common_operator import TriggerOperator

        nodes = set()
        for _, node in self.node_map.items():
            nodes = nodes.union(_get_nodes(node))
        self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
        self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
        self._trigger_nodes = list(
            set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
        )

    @property
    def root_nodes(self) -> List[DAGNode]:
        if not self._root_nodes:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        if not self._leaf_nodes:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self):
        if not self._trigger_nodes:
            self._build()
        return self._trigger_nodes

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        DAGVar.exit_dag()


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
    nodes = set()
    if not node:
        return nodes
    nodes.add(node)
    stream_nodes = node.upstream if is_upstream else node.downstream
    for node in stream_nodes:
        nodes = nodes.union(_get_nodes(node, is_upstream))
    return nodes
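The pieces above combine into the usual usage pattern: `with DAG(...)` pushes the DAG onto a thread-local or context-local stack via `DAGVar`, every `DAGNode` created inside picks it up from `DAGVar.get_current_dag()`, and `>>` / `<<` record the edges through `set_dependency`. A small sketch using `MapOperator` from `operator/common_operator.py` as a concrete node type; it assumes `DAGVar.set_current_system_app(...)` has already been called (which `initialize_awel` does):

.. code-block:: python

    from pilot.awel import DAG, MapOperator


    async def add_one(x: int) -> int:
        return x + 1


    async def times_two(x: int) -> int:
        return x * 2


    with DAG("arith_dag") as dag:
        # Both operators attach themselves to "arith_dag" via DAGVar.get_current_dag().
        a = MapOperator(map_function=add_one)
        b = MapOperator(map_function=times_two)
        # Equivalent to a.set_downstream(b): a becomes a root node, b a leaf node.
        a >> b

    assert a in dag.root_nodes and b in dag.leaf_nodes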
@@ -1,42 +0,0 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG

logger = logging.getLogger(__name__)


class DAGManager(BaseComponent):
    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} already exists")
            # Record the DAG so duplicate IDs are detected on later loads
            self.dag_map[dag_id] = dag
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            logger.warning("No trigger manager found, DAG triggers will not be registered")
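A sketch of how the manager is wired, mirroring what `initialize_awel` (shown earlier in this commit) does; the `system_app` instance is assumed to exist already and the directory path is illustrative:

.. code-block:: python

    # system_app is assumed to be the application's existing SystemApp instance.
    dag_manager = DAGManager(system_app, "/app/pilot/awel/dags")  # illustrative path
    system_app.register_instance(dag_manager)

    # Later, typically after all components are registered:
    # imports the .py files, collects DAG objects and registers their triggers.
    dag_manager.load_dags()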
@@ -1,93 +0,0 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback

from .base import DAG

logger = logging.getLogger(__name__)


class DAGLoader(ABC):
    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags"""


class LocalFileDAGLoader(DAGLoader):
    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        if not os.path.exists(self._filepath):
            return []
        if os.path.isdir(self._filepath):
            return _process_directory(self._filepath)
        else:
            return _process_file(self._filepath)


def _process_directory(directory: str) -> List[DAG]:
    dags = []
    for file in os.listdir(directory):
        if file.endswith(".py"):
            filepath = os.path.join(directory, file)
            dags += _process_file(filepath)
    return dags


def _process_file(filepath) -> List[DAG]:
    mods = _load_modules_from_file(filepath)
    results = _process_modules(mods)
    return results


def _load_modules_from_file(filepath: str):
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

    if mod_name in sys.modules:
        del sys.modules[mod_name]

    def parse(mod_name, filepath):
        try:
            loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
            spec = importlib.util.spec_from_loader(mod_name, loader)
            new_module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = new_module
            loader.exec_module(new_module)
            return [new_module]
        except Exception as e:
            msg = traceback.format_exc()
            logger.error(f"Failed to import: {filepath}, error message: {msg}")
            # TODO save error message
            return []

    return parse(mod_name, filepath)


def _process_modules(mods) -> List[DAG]:
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            logger.info(f"Found dag {dag} from mod {mod} and module file {mod.__file__}")
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            logger.error(f"Failed to process dag file, error message: {msg}")
    return found_dags
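The loader treats every top-level `DAG` object defined in a `.py` file as a discoverable workflow. Standalone usage is straightforward; the directory path below is illustrative:

.. code-block:: python

    # A directory of .py files, or a single .py file.
    loader = LocalFileDAGLoader("/app/pilot/awel/dags")
    dags = loader.load_dags()
    for dag in dags:
        print(dag.dag_id, len(dag.trigger_nodes))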
@@ -1,51 +0,0 @@
import pytest
import threading
import asyncio
from ..dag.base import DAG, DAGVar


def test_dag_context_sync():
    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None


def test_dag_context_threading():
    def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    thread1 = threading.Thread(target=thread_function, args=(dag1,))
    thread2 = threading.Thread(target=thread_function, args=(dag2,))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    assert DAGVar.get_current_dag() is None


@pytest.mark.asyncio
async def test_dag_context_async():
    async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    await asyncio.gather(async_function(dag1), async_function(dag2))

    assert DAGVar.get_current_dag() is None
@@ -1,206 +0,0 @@
from abc import ABC, abstractmethod, ABCMeta

from types import FunctionType
from typing import (
    List,
    Generic,
    TypeVar,
    AsyncIterator,
    Union,
    Any,
    Dict,
    Optional,
    cast,
)
import functools
from inspect import signature
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import (
    ExecutorFactory,
    DefaultExecutorFactory,
    blocking_func_to_async,
    BlockingFunction,
)

from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
    TaskContext,
    TaskOutput,
    TaskState,
    OUT,
    T,
    InputContext,
    InputSource,
)

F = TypeVar("F", bound=FunctionType)

CALL_DATA = Union[Dict, Dict[str, Dict]]


class WorkflowRunner(ABC, Generic[T]):
    """Abstract base class representing a runner for executing workflows in a DAG.

    This class defines the interface for executing workflows within the DAG,
    handling the flow from one DAG node to another.
    """

    @abstractmethod
    async def execute_workflow(
        self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
    ) -> DAGContext:
        """Execute the workflow starting from a given operator.

        Args:
            node (BaseOperator): The starting node of the workflow to be executed.
            call_data (CALL_DATA): The data passed to the root operator nodes.

        Returns:
            DAGContext: The context after executing the workflow, containing the final state and data.
        """


default_runner: WorkflowRunner = None


class BaseOperatorMeta(ABCMeta):
    """Metaclass of BaseOperator."""

    @classmethod
    def _apply_defaults(cls, func: F) -> F:
        sig_cache = signature(func)

        @functools.wraps(func)
        def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
            dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
            task_id: Optional[str] = kwargs.get("task_id")
            system_app: Optional[SystemApp] = (
                kwargs.get("system_app") or DAGVar.get_current_system_app()
            )
            executor = kwargs.get("executor") or DAGVar.get_executor()
            if not executor:
                if system_app:
                    executor = system_app.get_component(
                        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
                    ).create()
                else:
                    executor = DefaultExecutorFactory().create()
                DAGVar.set_executor(executor)

            if not task_id and dag:
                task_id = dag._new_node_id()
            runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
            # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
            # for arg in sig_cache.parameters:
            #     if arg not in kwargs:
            #         kwargs[arg] = default_args[arg]
            if not kwargs.get("dag"):
                kwargs["dag"] = dag
            if not kwargs.get("task_id"):
                kwargs["task_id"] = task_id
            if not kwargs.get("runner"):
                kwargs["runner"] = runner
            if not kwargs.get("system_app"):
                kwargs["system_app"] = system_app
            if not kwargs.get("executor"):
                kwargs["executor"] = executor
            real_obj = func(self, *args, **kwargs)
            return real_obj

        return cast(T, apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
        return new_cls


class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
    """Abstract base class for operator nodes that can be executed within a workflow.

    This class extends DAGNode by adding execution capabilities.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        dag: Optional[DAG] = None,
        runner: WorkflowRunner = None,
        **kwargs,
    ) -> None:
        """Initializes a BaseOperator with an optional workflow runner.

        Args:
            runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
        """
        super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
        if not runner:
            from pilot.awel import DefaultWorkflowRunner

            runner = DefaultWorkflowRunner()

        self._runner: WorkflowRunner = runner
        self._dag_ctx: DAGContext = None

    @property
    def current_dag_context(self) -> DAGContext:
        return self._dag_ctx

    async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        if not self.node_id:
            raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
        self._dag_ctx = dag_ctx
        return await self._do_run(dag_ctx)

    @abstractmethod
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """
        Abstract method to run the task within the DAG node.

        Args:
            dag_ctx (DAGContext): The context of the DAG when this node is run.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """

    async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
        """Execute the node and return the output.

        This method is a high-level wrapper for executing the node.

        Args:
            call_data (CALL_DATA): The data passed to the root operator nodes.

        Returns:
            OUT: The output of the node after execution.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output

    async def call_stream(
        self, call_data: Optional[CALL_DATA] = None
    ) -> AsyncIterator[OUT]:
        """Execute the node and return the output as a stream.

        This method is used for nodes where the output is a stream.

        Args:
            call_data (CALL_DATA): The data passed to the root operator nodes.

        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output_stream

    async def blocking_func_to_async(
        self, func: BlockingFunction, *args, **kwargs
    ) -> Any:
        return await blocking_func_to_async(self._executor, func, *args, **kwargs)


def initialize_runner(runner: WorkflowRunner):
    global default_runner
    default_runner = runner
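Given the execution API above, the two entry points differ only in how the final output is consumed: `call()` returns the materialized value, while `call_stream()` returns an async iterator over the streaming output. A hedged sketch, assuming `final_task` is the last operator of an already-built DAG and the `call_data` payload is illustrative:

.. code-block:: python

    import asyncio


    async def run(final_task):
        # Non-streaming: the runner executes the whole upstream graph first.
        value = await final_task.call(call_data={"data": "hello"})

        # Streaming: same execution, but the output is consumed incrementally.
        stream = await final_task.call_stream(call_data={"data": "hello"})
        async for chunk in stream:
            print(chunk)
        return value


    # asyncio.run(run(final_task))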
@@ -1,246 +0,0 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
import logging

from ..dag.base import DAGContext
from ..task.base import (
    TaskContext,
    TaskOutput,
    IN,
    OUT,
    InputContext,
    InputSource,
)

from .base import BaseOperator


logger = logging.getLogger(__name__)


class JoinOperator(BaseOperator, Generic[OUT]):
    """Operator that joins inputs using a custom combine function.

    This node type is useful for combining the outputs of upstream nodes.
    """

    def __init__(self, combine_function, **kwargs):
        super().__init__(**kwargs)
        if not callable(combine_function):
            raise ValueError("combine_function must be callable")
        self.combine_function = combine_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the join operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
            self.combine_function
        )
        # The join result is stored in the first parent output
        join_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(join_output)
        return join_output


class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
    def __init__(self, reduce_function=None, **kwargs):
        """Initializes a ReduceStreamOperator with a reduce function.

        Args:
            reduce_function: A function that defines how to reduce the stream of inputs.

        Raises:
            ValueError: If the reduce_function is not callable.
        """
        super().__init__(**kwargs)
        if reduce_function and not callable(reduce_function):
            raise ValueError("reduce_function must be callable")
        self.reduce_function = reduce_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the reduce operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if not task_input.check_stream():
            raise ValueError("ReduceStreamOperator expects stream data")
        if not task_input.check_single_parent():
            raise ValueError("ReduceStreamOperator expects single parent")

        reduce_function = self.reduce_function or self.reduce

        input_ctx: InputContext = await task_input.reduce(reduce_function)
        # The reduce result is stored in the first parent output
        reduce_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(reduce_output)
        return reduce_output

    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
        raise NotImplementedError


class MapOperator(BaseOperator, Generic[IN, OUT]):
    """Map operator that applies a mapping function to its inputs.

    This operator transforms its input data using a provided mapping function and
    passes the transformed data downstream.
    """

    def __init__(self, map_function=None, **kwargs):
        """Initializes a MapOperator with a mapping function.

        Args:
            map_function: A function that defines how to map the input data.

        Raises:
            ValueError: If the map_function is not callable.
        """
        super().__init__(**kwargs)
        if map_function and not callable(map_function):
            raise ValueError("map_function must be callable")
        self.map_function = map_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the mapping operation on the DAG context's inputs.

        This method applies the mapping function to the input context and updates
        the DAG context with the new data.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.

        Raises:
            ValueError: If there is not a single parent or the map_function is not callable.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        if not curr_task_ctx.task_input.check_single_parent():
            num_parents = len(curr_task_ctx.task_input.parent_outputs)
            raise ValueError(
                f"task {curr_task_ctx.task_id} MapOperator expects a single parent, now number of parents: {num_parents}"
            )
        map_function = self.map_function or self.map

        input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
        # The map result is stored in the first parent output
        map_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(map_output)
        return map_output

    async def map(self, input_value: IN) -> OUT:
        raise NotImplementedError


BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]


class BranchOperator(BaseOperator, Generic[IN, OUT]):
    """Operator node that branches the workflow based on a provided function.

    This node filters its input data using a branching function and
    allows for conditional paths in the workflow.
    """

    def __init__(
        self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
    ):
        """
        Initializes a BranchOperator with branching functions.

        Args:
            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict mapping branch condition functions to downstream operators (or their node names).

        Raises:
            ValueError: If a branch function is not callable.
        """
        super().__init__(**kwargs)
        if branches:
            for branch_function, value in branches.items():
                if not callable(branch_function):
                    raise ValueError("branch_function must be callable")
                if isinstance(value, BaseOperator):
                    branches[branch_function] = value.node_name or value.node_id
        self._branches = branches

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the branching operation on the DAG context's inputs.

        This method applies the branching function to the input context to determine
        the path of execution in the workflow.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if task_input.check_stream():
            raise ValueError("BranchOperator expects no stream data")
        if not task_input.check_single_parent():
            raise ValueError("BranchOperator expects single parent")

        branches = self._branches
        if not branches:
            branches = await self.branchs()

        branch_func_tasks = []
        branch_nodes: List[str] = []
        for func, node_name in branches.items():
            branch_nodes.append(node_name)
            branch_func_tasks.append(
                curr_task_ctx.task_input.predicate_map(func, failed_value=None)
            )

        branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
        parent_output = task_input.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(parent_output)
        skip_node_names = []
        for i, ctx in enumerate(branch_input_ctxs):
            node_name = branch_nodes[i]
            branch_out = ctx.parent_outputs[0].task_output
            logger.info(
                f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
            )
            if ctx.parent_outputs[0].task_output.is_empty:
                logger.info(f"Skip node name {node_name}")
                skip_node_names.append(node_name)
        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
        return parent_output

    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
        raise NotImplementedError


class InputOperator(BaseOperator, Generic[OUT]):
    def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
        super().__init__(**kwargs)
        self._input_source = input_source

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_output = await self._input_source.read(curr_task_ctx)
        curr_task_ctx.set_task_output(task_output)
        return task_output


class TriggerOperator(InputOperator, Generic[OUT]):
    def __init__(self, **kwargs) -> None:
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
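A sketch of how a branch is typically wired: each predicate maps to the `task_name` of a downstream operator, and the runner later skips the branches whose predicate produced an empty result (see `local_runner.py` in this commit). The predicates, task names, and combine function below are illustrative, and AWEL is assumed to be initialized as in the earlier sketches:

.. code-block:: python

    from pilot.awel import DAG, BranchOperator, JoinOperator, MapOperator


    async def is_even(x: int) -> bool:
        return x % 2 == 0


    async def is_odd(x: int) -> bool:
        return x % 2 == 1


    async def handle_even(x: int) -> str:
        return f"even:{x}"


    async def handle_odd(x: int) -> str:
        return f"odd:{x}"


    def first_non_empty(even_result, odd_result):
        # Only one branch actually runs; keep whichever output is present.
        return even_result if even_result is not None else odd_result


    with DAG("branch_dag"):
        branch = BranchOperator(branches={is_even: "even_path", is_odd: "odd_path"})
        even_task = MapOperator(map_function=handle_even, task_name="even_path")
        odd_task = MapOperator(map_function=handle_odd, task_name="odd_path")
        join = JoinOperator(combine_function=first_non_empty)
        branch >> even_task >> join
        branch >> odd_task >> join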
@@ -1,90 +0,0 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator


class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
            self.streamify
        )
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
        """Convert a value of IN to an AsyncIterator[OUT]

        Args:
            input_value (IN): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyStreamOperator(StreamifyAbsOperator[int, int]):
                async def streamify(self, input_value: int) -> AsyncIterator[int]:
                    for i in range(input_value):
                        yield i
        """


class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.unstreamify(self.unstreamify)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
        """Convert a value of AsyncIterator[IN] to an OUT.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
                    value_cnt = 0
                    async for v in input_value:
                        value_cnt += 1
                    return value_cnt
        """


class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.transform_stream(self.transform_stream)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def transform_stream(
        self, input_value: AsyncIterator[IN]
    ) -> AsyncIterator[OUT]:
        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                async def transform_stream(
                    self, input_value: AsyncIterator[int]
                ) -> AsyncIterator[int]:
                    async for v in input_value:
                        yield v + 1
        """
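The three abstract operators above cover the stream boundaries: entering a stream, transforming it, and collapsing it back to a single value. A sketch of a chain that generates, shifts, and then sums a stream (AWEL assumed to be initialized, as in the earlier sketches):

.. code-block:: python

    from typing import AsyncIterator

    from pilot.awel import (
        DAG,
        StreamifyAbsOperator,
        TransformStreamAbsOperator,
        UnstreamifyAbsOperator,
    )


    class CountToOperator(StreamifyAbsOperator[int, int]):
        async def streamify(self, input_value: int) -> AsyncIterator[int]:
            for i in range(input_value):
                yield i


    class AddOneStreamOperator(TransformStreamAbsOperator[int, int]):
        async def transform_stream(
            self, input_value: AsyncIterator[int]
        ) -> AsyncIterator[int]:
            async for v in input_value:
                yield v + 1


    class SumStreamOperator(UnstreamifyAbsOperator[int, int]):
        async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
            total = 0
            async for v in input_value:
                total += v
            return total


    with DAG("stream_dag"):
        CountToOperator() >> AddOneStreamOperator() >> SumStreamOperator()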
@@ -1,8 +0,0 @@
from abc import ABC, abstractmethod


class ResourceGroup(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        """The name of current resource group"""
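`ResourceGroup` is only an interface in this commit; a trivial, hypothetical concrete implementation would be:

.. code-block:: python

    class LocalResourceGroup(ResourceGroup):
        """A hypothetical resource group that simply labels nodes as local."""

        @property
        def name(self) -> str:
            return "local"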
@@ -1,82 +0,0 @@
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG

from ..operator.base import BaseOperator, CALL_DATA

logger = logging.getLogger(__name__)


class DAGNodeInstance:
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    def __init__(self, dag: DAG) -> None:
        self._dag = dag


class JobManager:
    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data

    @staticmethod
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        return self._id2node_data.get(node_id)


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
    id2call_data = {}
    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
    if not call_data:
        return id2call_data
    if len(root_nodes) == 1:
        node = root_nodes[0]
        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
        id2call_data[node.node_id] = call_data
    else:
        for node in root_nodes:
            node_id = node.node_id
            logger.info(
                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
            )
            id2call_data[node_id] = call_data.get(node_id)
    return id2call_data


def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
        nodes.append(end_node)
    for node in end_node.upstream:
        nodes += _build_from_end_node(node)
    return nodes


def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
    return list(set(filter(lambda x: not x.upstream, nodes)))
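The `_save_call_data` helper shows how `call_data` is routed: with a single root node the whole dict is handed to that node, while with several root nodes the dict is expected to be keyed by root `node_id`. A sketch of the multi-root case, where `root_a`, `root_b`, and `end_task` are assumed to be operators of an already-built DAG and the payloads are illustrative:

.. code-block:: python

    call_data = {
        root_a.node_id: {"data": "payload for root_a"},
        root_b.node_id: {"data": "payload for root_b"},
    }
    job_manager = JobManager.build_from_end_node(end_task, call_data=call_data)
    assert job_manager.get_call_data_by_id(root_a.node_id) == {"data": "payload for root_a"}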
@@ -1,106 +0,0 @@
from typing import Dict, Optional, Set, List
import logging

from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

logger = logging.getLogger(__name__)


class DefaultWorkflowRunner(WorkflowRunner):
    async def execute_workflow(
        self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> DAGContext:
        # Create DAG context
        dag_ctx = DAGContext()
        job_manager = JobManager.build_from_end_node(node, call_data)
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
        )
        dag = node.dag
        # Save node output
        node_outputs: Dict[str, TaskContext] = {}
        skip_node_ids = set()
        await self._execute_node(
            job_manager, node, dag_ctx, node_outputs, skip_node_ids
        )

        return dag_ctx

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
    ):
        # Skip nodes that have already been run
        if node.node_id in node_outputs:
            return

        # Run all upstream nodes first
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                )

        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
        task_ctx.set_current_state(TaskState.RUNNING)

        if node.node_id in skip_node_ids:
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(None))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)

            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
        except Exception as e:
            logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
            task_ctx.set_current_state(TaskState.FAILED)
            raise e


def _skip_current_downstream_by_node_name(
    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        if child.node_name in skip_nodes:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
    if isinstance(node, JoinOperator):
        # Do not skip join nodes
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        _skip_downstream_by_id(child, skip_node_ids)
@@ -1,367 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Optional,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
IN = TypeVar("IN")
|
||||
OUT = TypeVar("OUT")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Enumeration representing the state of a task in the workflow.
|
||||
|
||||
This Enum defines various states a task can be in during its lifecycle in the DAG.
|
||||
"""
|
||||
|
||||
INIT = "init" # Initial state of the task, not yet started
|
||||
SKIP = "skip" # State indicating the task was skipped
|
||||
RUNNING = "running" # State indicating the task is currently running
|
||||
SUCCESS = "success" # State indicating the task completed successfully
|
||||
FAILED = "failed" # State indicating the task failed during execution
|
||||
|
||||
|
||||
class TaskOutput(ABC, Generic[T]):
|
||||
"""Abstract base class representing the output of a task.
|
||||
|
||||
This class encapsulates the output of a task and provides methods to access the output data.
|
||||
It can be subclassed to implement specific output behaviors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
"""Check if the output is a stream.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is a stream, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the output is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is empty, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> Optional[T]:
|
||||
"""Return the output of the task.
|
||||
|
||||
Returns:
|
||||
T: The output of the task. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_stream(self) -> Optional[AsyncIterator[T]]:
|
||||
"""Return the output of the task as an asynchronous stream.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
|
||||
"""Set the output data to current object.
|
||||
|
||||
Args:
|
||||
output_data (Union[T, AsyncIterator[T]]): Output data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_output(self) -> "TaskOutput[T]":
|
||||
"""Create new output object"""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func: A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the mapping function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reduce(self, reduce_func) -> "TaskOutput[T]":
|
||||
"""Apply a reducing function to the task's output.
|
||||
|
||||
Stream TaskOutput to Nonstream TaskOutput.
|
||||
|
||||
Args:
|
||||
reduce_func: A reducing function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
"""Check if current output meets a given condition.
|
||||
|
||||
Args:
|
||||
condition_func: A function to determine if the condition is met.
|
||||
Returns:
|
||||
bool: True if current output meet the condition, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TaskContext(ABC, Generic[T]):
|
||||
"""Abstract base class representing the context of a task within a DAG.
|
||||
|
||||
This class provides the interface for accessing task-related information
|
||||
and manipulating task output.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_id(self) -> str:
|
||||
"""Return the unique identifier of the task.
|
||||
|
||||
Returns:
|
||||
str: The unique identifier of the task.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_input(self) -> "InputContext":
|
||||
"""Return the InputContext of current task.
|
||||
|
||||
Returns:
|
||||
InputContext: The InputContext of current task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
"""Set the InputContext object to current task.
|
||||
|
||||
Args:
|
||||
input_ctx (InputContext): The InputContext of current task
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_output(self) -> TaskOutput[T]:
|
||||
"""Return the output object of the task.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The output object of the task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_task_output(self, task_output: TaskOutput[T]) -> None:
|
||||
"""Set the output object to current task."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_state(self) -> TaskState:
|
||||
"""Get the current state of the task.
|
||||
|
||||
Returns:
|
||||
TaskState: The current state of the task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Set current task state
|
||||
|
||||
Args:
|
||||
task_state (TaskState): The task state to be set.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_ctx(self) -> "TaskContext":
|
||||
"""Create new task context
|
||||
|
||||
Returns:
|
||||
TaskContext: A new instance of a TaskContext.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Get the metadata of current task
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata
|
||||
"""
|
||||
|
||||
def update_metadata(self, key: str, value: Any) -> None:
|
||||
"""Update metadata with key and value
|
||||
|
||||
Args:
|
||||
key (str): The key of metadata
|
||||
value (str): The value to be add to metadata
|
||||
"""
|
||||
self.metadata[key] = value
|
||||
|
||||
@property
|
||||
def call_data(self) -> Optional[Dict]:
|
||||
"""Get the call data for current data"""
|
||||
return self.metadata.get("call_data")
|
||||
|
||||
def set_call_data(self, call_data: Dict) -> None:
|
||||
"""Set call data for current task"""
|
||||
self.update_metadata("call_data", call_data)
|
||||
|
||||
|
||||
class InputContext(ABC):
|
||||
"""Abstract base class representing the context of inputs to a operator node.
|
||||
|
||||
This class defines methods to manipulate and access the inputs for a operator node.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
"""Get the outputs from the parent nodes.
|
||||
|
||||
Returns:
|
||||
List[TaskContext]: A list of contexts of the parent nodes' outputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
|
||||
"""Apply a mapping function to the inputs.
|
||||
|
||||
Args:
|
||||
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the mapped inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
|
||||
"""Apply a mapping function to all inputs.
|
||||
|
||||
Args:
|
||||
map_func (Callable[..., Any]): A function to be applied to all inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the mapped inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
|
||||
"""Apply a reducing function to the inputs.
|
||||
|
||||
Args:
|
||||
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the reduced inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
|
||||
"""Filter the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the filtered inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
    ) -> "InputContext":
        """Apply a predicate function to the inputs.

        Args:
            predicate_func (Callable[[Any], bool]): A function that returns True for inputs
                that satisfy the predicate.
            failed_value (Any): The value to set for an input when the predicate function
                returns False.

        Returns:
            InputContext: A new InputContext instance with the predicated inputs.
        """
|
||||
|
||||
def check_single_parent(self) -> bool:
|
||||
"""Check if there is only a single parent output.
|
||||
|
||||
Returns:
|
||||
bool: True if there is only one parent output, False otherwise.
|
||||
"""
|
||||
return len(self.parent_outputs) == 1
|
||||
|
||||
def check_stream(self, skip_empty: bool = False) -> bool:
|
||||
"""Check if all parent outputs are streams.
|
||||
|
||||
Args:
|
||||
skip_empty (bool): Skip empty output or not.
|
||||
|
||||
Returns:
|
||||
bool: True if all parent outputs are streams, False otherwise.
|
||||
"""
|
||||
for out in self.parent_outputs:
|
||||
if out.task_output.is_empty and skip_empty:
|
||||
continue
|
||||
if not (out.task_output and out.task_output.is_stream):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class InputSource(ABC, Generic[T]):
|
||||
"""Abstract base class representing the source of inputs to a DAG node."""
|
||||
|
||||
@abstractmethod
|
||||
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
|
||||
"""Read the data from current input source.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The output object read from current source
|
||||
"""
|
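To illustrate how these abstractions compose, here is a minimal, hypothetical sketch (not part of the original diff) that wires a concrete InputSource into a TaskContext using the implementations from task_impl.py below; the import path and the TaskState member name are assumptions.

import asyncio

# Assumed import location; substitute the actual package path of the AWEL task classes.
# from pilot.awel import SimpleInputSource, DefaultTaskContext, TaskState

async def _demo_read_input():
    source = SimpleInputSource("hello")
    # TaskState.INIT is assumed here; use whichever initial state the enum defines.
    ctx = DefaultTaskContext("task-1", TaskState.INIT, None)
    output = await source.read(ctx)   # -> SimpleTaskOutput wrapping "hello"
    ctx.set_task_output(output)
    return ctx.task_output.output     # "hello"

asyncio.run(_demo_read_input())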
@@ -1,339 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
Tuple,
|
||||
Dict,
|
||||
Union,
|
||||
)
|
||||
import asyncio
|
||||
import logging
|
||||
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
|
||||
# Init accumulator
|
||||
try:
|
||||
accumulator = await stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise ValueError("Stream is empty")
|
||||
is_async = asyncio.iscoroutinefunction(reduce_function)
|
||||
async for element in stream:
|
||||
if is_async:
|
||||
accumulator = await reduce_function(accumulator, element)
|
||||
else:
|
||||
accumulator = reduce_function(accumulator, element)
|
||||
return accumulator
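A short, illustrative usage of _reduce_stream (the async generator below is hypothetical):

import asyncio

async def _numbers():
    for i in range(5):
        yield i

# A synchronous reducer works; an async reducer would too, since
# _reduce_stream checks asyncio.iscoroutinefunction.
total = asyncio.run(_reduce_stream(_numbers(), lambda acc, x: acc + x))
assert total == 10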
|
||||
|
||||
|
||||
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: T) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
return self._data
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleTaskOutput(None)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
|
||||
async def _apply_func(self, func) -> Any:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
out = await func(self._data)
|
||||
else:
|
||||
out = func(self._data)
|
||||
return out
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
out = await self._apply_func(map_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
return await self._apply_func(condition_func)
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
out = await self._apply_func(transform_func)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: AsyncIterator[T]) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
|
||||
@property
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
return self._data
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleStreamTaskOutput(None)
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
is_async = asyncio.iscoroutinefunction(map_func)
|
||||
|
||||
async def new_iter() -> AsyncIterator[T]:
|
||||
async for out in self._data:
|
||||
if is_async:
|
||||
out = await map_func(out)
|
||||
else:
|
||||
out = map_func(out)
|
||||
yield out
|
||||
|
||||
return SimpleStreamTaskOutput(new_iter())
|
||||
|
||||
async def reduce(self, reduce_func) -> TaskOutput[T]:
|
||||
out = await _reduce_stream(self._data, reduce_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> TaskOutput[T]:
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
def _is_async_iterator(obj):
|
||||
return (
|
||||
hasattr(obj, "__anext__")
|
||||
and callable(getattr(obj, "__anext__", None))
|
||||
and hasattr(obj, "__aiter__")
|
||||
and callable(getattr(obj, "__aiter__", None))
|
||||
)
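For example (illustrative check, not from the original file):

async def _agen():
    yield 1

assert _is_async_iterator(_agen())        # async generator objects qualify
assert not _is_async_iterator(iter([1]))  # plain (sync) iterators do not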
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data with task context"""
|
||||
|
||||
    async def read(self, task_ctx: TaskContext) -> TaskOutput:
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output = SimpleStreamTaskOutput(data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
return output
|
||||
|
||||
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
def __init__(self, data: Any) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
return self._data
|
||||
|
||||
|
||||
class SimpleCallDataInputSource(BaseInputSource):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
call_data = task_ctx.call_data
|
||||
data = call_data.get("data") if call_data else None
|
||||
if not (call_data and data):
|
||||
raise ValueError("No call data for current SimpleCallDataInputSource")
|
||||
return data
|
||||
|
||||
|
||||
class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
def __init__(
|
||||
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._task_id = task_id
|
||||
self._task_state = task_state
|
||||
self._output = task_output
|
||||
self._task_input = None
|
||||
self._metadata = {}
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
return self._task_id
|
||||
|
||||
@property
|
||||
def task_input(self) -> InputContext:
|
||||
return self._task_input
|
||||
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
self._task_input = input_ctx
|
||||
|
||||
@property
|
||||
def task_output(self) -> TaskOutput:
|
||||
return self._output
|
||||
|
||||
def set_task_output(self, task_output: TaskOutput) -> None:
|
||||
self._output = task_output
|
||||
|
||||
@property
|
||||
def current_state(self) -> TaskState:
|
||||
return self._task_state
|
||||
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
self._task_state = task_state
|
||||
|
||||
def new_ctx(self) -> TaskContext:
|
||||
new_output = self._output.new_output()
|
||||
return DefaultTaskContext(self._task_id, self._task_state, new_output)
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
return self._metadata
|
||||
|
||||
|
||||
class DefaultInputContext(InputContext):
|
||||
def __init__(self, outputs: List[TaskContext]) -> None:
|
||||
super().__init__()
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
return self._outputs
|
||||
|
||||
async def _apply_func(
|
||||
self, func: Callable[[Any], Any], apply_type: str = "map"
|
||||
) -> Tuple[List[TaskContext], List[TaskOutput]]:
|
||||
new_outputs: List[TaskContext] = []
|
||||
map_tasks = []
|
||||
for out in self._outputs:
|
||||
new_outputs.append(out.new_ctx())
|
||||
result = None
|
||||
if apply_type == "map":
|
||||
result = out.task_output.map(func)
|
||||
elif apply_type == "reduce":
|
||||
result = out.task_output.reduce(func)
|
||||
elif apply_type == "check_condition":
|
||||
result = out.task_output.check_condition(func)
|
||||
else:
|
||||
                raise ValueError(f"Unsupported apply type: {apply_type}")
|
||||
map_tasks.append(result)
|
||||
results = await asyncio.gather(*map_tasks)
|
||||
return new_outputs, results
|
||||
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
|
||||
new_outputs, results = await self._apply_func(map_func)
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||
if not self._outputs:
|
||||
return DefaultInputContext([])
|
||||
        # Some parents may be empty, so find the first non-empty parent output
        not_empty_idx = 0
        for i, p in enumerate(self._outputs):
            if p.task_output.is_empty:
                continue
            not_empty_idx = i
            break
        is_stream = self._outputs[not_empty_idx].task_output.is_stream
        if is_stream:
            if not self.check_stream(skip_empty=True):
                raise ValueError(
                    "The outputs of all tasks must be streams to apply map_all"
                )
|
||||
outputs = []
|
||||
for out in self._outputs:
|
||||
if out.task_output.is_stream:
|
||||
outputs.append(out.task_output.output_stream)
|
||||
else:
|
||||
outputs.append(out.task_output.output)
|
||||
if asyncio.iscoroutinefunction(map_func):
|
||||
map_res = await map_func(*outputs)
|
||||
else:
|
||||
map_res = map_func(*outputs)
|
||||
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
|
||||
single_output.task_output.set_output(map_res)
|
||||
        logger.debug(
            f"Current map_all map_res: {map_res}, is stream: {single_output.task_output.is_stream}"
        )
|
||||
return DefaultInputContext([single_output])
|
||||
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||
if not self.check_stream():
|
||||
            raise ValueError(
                "The outputs of all tasks must be streams to apply a reduce function"
            )
|
||||
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
|
||||
new_outputs, results = await self._apply_func(
|
||||
filter_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
if results[i]:
|
||||
result_outputs.append(task_ctx)
|
||||
return DefaultInputContext(result_outputs)
|
||||
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
new_outputs, results = await self._apply_func(
|
||||
predicate_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
if results[i]:
|
||||
task_ctx.task_output.set_output(True)
|
||||
result_outputs.append(task_ctx)
|
||||
else:
|
||||
task_ctx.task_output.set_output(failed_value)
|
||||
result_outputs.append(task_ctx)
|
||||
return DefaultInputContext(result_outputs)
|
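To make the DefaultInputContext operations concrete, a minimal sketch with hypothetical values:

import asyncio

async def _demo_input_context():
    parents = [
        DefaultTaskContext(f"task-{i}", TaskState.SUCCESS, SimpleTaskOutput(i))
        for i in (1, 2)
    ]
    input_ctx = DefaultInputContext(parents)
    doubled = await input_ctx.map(lambda x: x * 2)        # per-parent outputs: 2 and 4
    merged = await input_ctx.map_all(lambda a, b: a + b)  # single merged output: 3
    return doubled, merged

asyncio.run(_demo_input_context())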
@@ -1,102 +0,0 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from typing import AsyncIterator, List
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from .. import (
|
||||
WorkflowRunner,
|
||||
InputOperator,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
DefaultWorkflowRunner,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return DefaultWorkflowRunner()
|
||||
|
||||
|
||||
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
|
||||
iters = []
|
||||
for _ in range(num_nodes):
|
||||
|
||||
async def stream_iter():
|
||||
for i in range(10):
|
||||
yield i
|
||||
|
||||
stream_iter = stream_iter()
|
||||
assert _is_async_iterator(stream_iter)
|
||||
iters.append(stream_iter)
|
||||
return iters
|
||||
|
||||
|
||||
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
|
||||
iters = []
|
||||
for single_stream in output_streams:
|
||||
|
||||
async def stream_iter():
|
||||
for i in single_stream:
|
||||
yield i
|
||||
|
||||
stream_iter = stream_iter()
|
||||
assert _is_async_iterator(stream_iter)
|
||||
iters.append(stream_iter)
|
||||
return iters
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_input_node(**kwargs):
|
||||
num_nodes = kwargs.get("num_nodes")
|
||||
is_stream = kwargs.get("is_stream", False)
|
||||
if is_stream:
|
||||
outputs = kwargs.get("output_streams")
|
||||
if outputs:
|
||||
if num_nodes and num_nodes != len(outputs):
|
||||
raise ValueError(
|
||||
f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
|
||||
)
|
||||
outputs = _create_stream_from(outputs)
|
||||
else:
|
||||
num_nodes = num_nodes or 1
|
||||
outputs = _create_stream(num_nodes)
|
||||
else:
|
||||
outputs = kwargs.get("outputs", ["Hello."])
|
||||
nodes = []
|
||||
for output in outputs:
|
||||
print(f"output: {output}")
|
||||
input_source = SimpleInputSource(output)
|
||||
input_node = InputOperator(input_source)
|
||||
nodes.append(input_node)
|
||||
yield nodes
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def input_node(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes[0]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def stream_input_node(request):
|
||||
param = getattr(request, "param", {})
|
||||
param["is_stream"] = True
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes[0]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def input_nodes(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def stream_input_nodes(request):
|
||||
param = getattr(request, "param", {})
|
||||
param["is_stream"] = True
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes
|
@@ -1,51 +0,0 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
input_node,
|
||||
input_nodes,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
def _register_dag_to_fastapi_app(dag):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
|
||||
    with DAG("test_http_operator") as dag:
|
||||
pass
|
||||
# http_req_task = HttpRequestOperator(endpoint="/api/completions")
|
||||
# db_task = DBQueryOperator(table_name="user_info")
|
||||
# prompt_task = PromptTemplateOperator(
|
||||
# system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
|
||||
# )
|
||||
# llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
|
||||
# output_parser_task = CommonOutputParserOperator()
|
||||
# http_res_task = HttpResponseOperator()
|
||||
# (
|
||||
# http_req_task
|
||||
# >> db_task
|
||||
# >> prompt_task
|
||||
# >> llm_task
|
||||
# >> output_parser_task
|
||||
# >> http_res_task
|
||||
# )
|
||||
|
||||
_register_dag_to_fastapi_app(dag)
|
@@ -1,141 +0,0 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
input_node,
|
||||
input_nodes,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_node(runner: WorkflowRunner):
|
||||
input_node = InputOperator(SimpleInputSource("hello"))
|
||||
res: DAGContext[str] = await runner.execute_workflow(input_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert res.current_task_context.task_output.output == "hello"
|
||||
|
||||
    async def new_stream_iter(n: int):
        for i in range(n):
            yield i

    num_iter = 10
    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
    res: DAGContext[str] = await runner.execute_workflow(stream_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    assert _is_async_iterator(output_stream)
    i = 0
    async for x in output_stream:
        assert x == i
        i += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
|
||||
with DAG("test_map") as dag:
|
||||
map_node = MapOperator(lambda x: x * 2)
|
||||
stream_input_node >> map_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(map_node)
|
||||
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    i = 0
    async for x in output_stream:
        assert x == i * 2
        i += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"stream_input_node, expect_sum",
|
||||
[
|
||||
({"output_streams": [[0, 1, 2, 3]]}, 6),
|
||||
({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
|
||||
],
|
||||
indirect=["stream_input_node"],
|
||||
)
|
||||
async def test_reduce_node(
|
||||
runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
|
||||
):
|
||||
with DAG("test_reduce_node") as dag:
|
||||
reduce_node = ReduceStreamOperator(lambda x, y: x + y)
|
||||
stream_input_node >> reduce_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(reduce_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert not res.current_task_context.task_output.is_stream
|
||||
assert res.current_task_context.task_output.output == expect_sum
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"input_nodes",
|
||||
[
|
||||
({"outputs": [0, 1, 2]}),
|
||||
],
|
||||
indirect=["input_nodes"],
|
||||
)
|
||||
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
|
||||
def join_func(p1, p2, p3) -> int:
|
||||
return p1 + p2 + p3
|
||||
|
||||
with DAG("test_join_node") as dag:
|
||||
join_node = JoinOperator(join_func)
|
||||
for input_node in input_nodes:
|
||||
input_node >> join_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(join_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert not res.current_task_context.task_output.is_stream
|
||||
assert res.current_task_context.task_output.output == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"input_node, is_odd",
|
||||
[
|
||||
({"outputs": [0]}, False),
|
||||
({"outputs": [1]}, True),
|
||||
],
|
||||
indirect=["input_node"],
|
||||
)
|
||||
async def test_branch_node(
|
||||
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
|
||||
):
|
||||
def join_func(o1, o2) -> int:
|
||||
print(f"join func result, o1: {o1}, o2: {o2}")
|
||||
return o1 or o2
|
||||
|
||||
    with DAG("test_branch_node") as dag:
|
||||
odd_node = MapOperator(
|
||||
lambda x: 999, task_id="odd_node", task_name="odd_node_name"
|
||||
)
|
||||
even_node = MapOperator(
|
||||
lambda x: 888, task_id="even_node", task_name="even_node_name"
|
||||
)
|
||||
join_node = JoinOperator(join_func)
|
||||
branch_node = BranchOperator(
|
||||
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
|
||||
)
|
||||
branch_node >> odd_node >> join_node
|
||||
branch_node >> even_node >> join_node
|
||||
|
||||
input_node >> branch_node
|
||||
|
||||
res: DAGContext[int] = await runner.execute_workflow(join_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
expect_res = 999 if is_odd else 888
|
||||
assert res.current_task_context.task_output.output == expect_res
|
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..operator.common_operator import TriggerOperator
|
||||
|
||||
|
||||
class Trigger(TriggerOperator, ABC):
|
||||
@abstractmethod
|
||||
async def trigger(self) -> None:
|
||||
"""Trigger the workflow or a specific operation in the workflow."""
|
@@ -1,137 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from .base import Trigger
|
||||
from ..dag.base import DAG
|
||||
from ..operator.base import BaseOperator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
RequestBody = Union[Request, Type[BaseModel], str]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpTrigger(Trigger):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
methods: Optional[Union[str, List[str]]] = "GET",
|
||||
request_body: Optional[RequestBody] = None,
|
||||
streaming_response: Optional[bool] = False,
|
||||
response_model: Optional[Type] = None,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
status_code: Optional[int] = 200,
|
||||
router_tags: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if not endpoint.startswith("/"):
|
||||
endpoint = "/" + endpoint
|
||||
self._endpoint = endpoint
|
||||
self._methods = methods
|
||||
self._req_body = request_body
|
||||
self._streaming_response = streaming_response
|
||||
self._response_model = response_model
|
||||
self._status_code = status_code
|
||||
self._router_tags = router_tags
|
||||
self._response_headers = response_headers
|
||||
self._response_media_type = response_media_type
|
||||
self._end_node: BaseOperator = None
|
||||
|
||||
async def trigger(self) -> None:
|
||||
pass
|
||||
|
||||
def mount_to_router(self, router: "APIRouter") -> None:
|
||||
from fastapi import Depends
|
||||
|
||||
methods = self._methods if isinstance(self._methods, list) else [self._methods]
|
||||
|
||||
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
||||
async def _request_body_dependency(request: Request):
|
||||
return await _parse_request_body(request, self._req_body)
|
||||
|
||||
async def route_function(body=Depends(_request_body_dependency)):
|
||||
return await _trigger_dag(
|
||||
body,
|
||||
self.dag,
|
||||
self._streaming_response,
|
||||
self._response_headers,
|
||||
self._response_media_type,
|
||||
)
|
||||
|
||||
route_function.__name__ = name
|
||||
return route_function
|
||||
|
||||
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
|
||||
request_model = (
|
||||
self._req_body
|
||||
if isinstance(self._req_body, type)
|
||||
and issubclass(self._req_body, BaseModel)
|
||||
else None
|
||||
)
|
||||
dynamic_route_function = create_route_function(function_name, request_model)
|
||||
logger.info(
|
||||
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
|
||||
)
|
||||
|
||||
router.api_route(
|
||||
self._endpoint,
|
||||
methods=methods,
|
||||
response_model=self._response_model,
|
||||
status_code=self._status_code,
|
||||
tags=self._router_tags,
|
||||
)(dynamic_route_function)
|
||||
|
||||
|
||||
async def _parse_request_body(
|
||||
request: Request, request_body_cls: Optional[Type[BaseModel]]
|
||||
):
|
||||
if not request_body_cls:
|
||||
return None
|
||||
if request.method == "POST":
|
||||
json_data = await request.json()
|
||||
return request_body_cls(**json_data)
|
||||
elif request.method == "GET":
|
||||
return request_body_cls(**request.query_params)
|
||||
else:
|
||||
return request
|
||||
|
||||
|
||||
async def _trigger_dag(
|
||||
body: Any,
|
||||
dag: DAG,
|
||||
streaming_response: Optional[bool] = False,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
) -> Any:
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
end_node = dag.leaf_nodes
|
||||
if len(end_node) != 1:
|
||||
        raise ValueError("HttpTrigger only supports a single leaf node in the DAG")
|
||||
end_node = end_node[0]
|
||||
if not streaming_response:
|
||||
return await end_node.call(call_data={"data": body})
|
||||
else:
|
||||
headers = response_headers
|
||||
media_type = response_media_type if response_media_type else "text/event-stream"
|
||||
if not headers:
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
return StreamingResponse(
|
||||
end_node.call_stream(call_data={"data": body}),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
)
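A rough sketch of how an HttpTrigger is typically wired into a DAG (the request model and operator wiring here are illustrative assumptions, not taken from this commit):

from pydantic import BaseModel

# DAG, HttpTrigger and MapOperator as defined in this commit are assumed importable.
class TriggerReqBody(BaseModel):  # hypothetical request model
    name: str

with DAG("example_http_trigger_dag") as dag:
    trigger = HttpTrigger(
        "/examples/hello",
        methods="POST",
        request_body=TriggerReqBody,
    )
    map_node = MapOperator(lambda req: f"Hello, {req.name}")
    trigger >> map_node
# Registering the DAG with the trigger manager mounts POST /examples/hello,
# and each request body is fed to the leaf operator as call data.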
|
@@ -1,74 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TYPE_CHECKING, Optional
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter
|
||||
|
||||
from pilot.component import SystemApp, BaseComponent, ComponentType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerManager(ABC):
|
||||
@abstractmethod
|
||||
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger with the current manager."""
|
||||
|
||||
|
||||
class HttpTriggerManager(TriggerManager):
|
||||
def __init__(
|
||||
self,
|
||||
router: Optional["APIRouter"] = None,
|
||||
router_prefix: Optional[str] = "/api/v1/awel/trigger",
|
||||
) -> None:
|
||||
if not router:
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
self._router_prefix = router_prefix
|
||||
self._router = router
|
||||
self._trigger_map = {}
|
||||
|
||||
def register_trigger(self, trigger: Any) -> None:
|
||||
from .http_trigger import HttpTrigger
|
||||
|
||||
if not isinstance(trigger, HttpTrigger):
|
||||
            raise ValueError(f"Current trigger {trigger} is not an instance of HttpTrigger")
|
||||
trigger: HttpTrigger = trigger
|
||||
trigger_id = trigger.node_id
|
||||
if trigger_id not in self._trigger_map:
|
||||
trigger.mount_to_router(self._router)
|
||||
self._trigger_map[trigger_id] = trigger
|
||||
|
||||
def _init_app(self, system_app: SystemApp):
|
||||
logger.info(
|
||||
f"Include router {self._router} to prefix path {self._router_prefix}"
|
||||
)
|
||||
system_app.app.include_router(
|
||||
self._router, prefix=self._router_prefix, tags=["AWEL"]
|
||||
)
|
||||
|
||||
|
||||
class DefaultTriggerManager(TriggerManager, BaseComponent):
|
||||
name = ComponentType.AWEL_TRIGGER_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
self.system_app = system_app
|
||||
self.http_trigger = HttpTriggerManager()
|
||||
super().__init__(None)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self.system_app = system_app
|
||||
|
||||
def register_trigger(self, trigger: Any) -> None:
|
||||
from .http_trigger import HttpTrigger
|
||||
|
||||
if isinstance(trigger, HttpTrigger):
|
||||
logger.info(f"Register trigger {trigger}")
|
||||
self.http_trigger.register_trigger(trigger)
|
||||
else:
|
||||
            raise ValueError(f"Unsupported trigger: {trigger}")
|
||||
|
||||
def after_register(self) -> None:
|
||||
self.http_trigger._init_app(self.system_app)
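Roughly, usage looks like the following sketch (assumes a SystemApp wrapping a FastAPI app and an existing HttpTrigger instance):

trigger_manager = DefaultTriggerManager(system_app)
trigger_manager.register_trigger(my_http_trigger)  # mounts the route on the internal APIRouter
trigger_manager.after_register()                   # includes the router on the FastAPI app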
|
@@ -1,11 +0,0 @@
|
||||
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
|
||||
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
|
||||
from .commands.command import execute_command, get_command
|
||||
from .commands.generator import PluginPromptGenerator
|
||||
from .commands.disply_type.show_chart_gen import static_message_img_path
|
||||
|
||||
from .common.schema import Status, PluginStorageType
|
||||
|
||||
from .commands.command_mange import ApiCall
|
||||
from .commands.command import execute_command
|
@@ -1,61 +0,0 @@
|
||||
"""Commands for converting audio to text."""
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command(
|
||||
"read_audio_from_file",
|
||||
"Convert Audio to text",
|
||||
'"filename": "<filename>"',
|
||||
CFG.huggingface_audio_to_text_model,
|
||||
"Configure huggingface_audio_to_text_model.",
|
||||
)
|
||||
def read_audio_from_file(filename: str) -> str:
|
||||
"""
|
||||
Convert audio to text.
|
||||
|
||||
Args:
|
||||
filename (str): The path to the audio file
|
||||
|
||||
Returns:
|
||||
str: The text from the audio
|
||||
"""
|
||||
with open(filename, "rb") as audio_file:
|
||||
audio = audio_file.read()
|
||||
return read_audio(audio)
|
||||
|
||||
|
||||
def read_audio(audio: bytes) -> str:
|
||||
"""
|
||||
Convert audio to text.
|
||||
|
||||
Args:
|
||||
audio (bytes): The audio to convert
|
||||
|
||||
Returns:
|
||||
str: The text from the audio
|
||||
"""
|
||||
model = CFG.huggingface_audio_to_text_model
|
||||
api_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
api_token = CFG.huggingface_api_token
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
|
||||
if api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
data=audio,
|
||||
)
|
||||
|
||||
text = json.loads(response.content.decode("utf-8"))["text"]
|
||||
return f"The audio says: {text}"
|
@@ -1,124 +0,0 @@
|
||||
""" Image Generation Module for AutoGPT."""
|
||||
import io
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
|
||||
def generate_image(prompt: str, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
|
||||
# HuggingFace
|
||||
if CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif CFG.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
|
||||
def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = (
|
||||
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
|
||||
)
|
||||
if CFG.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {CFG.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image.save(filename)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
filename: str,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if CFG.sd_webui_auth:
|
||||
username, password = CFG.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(filename)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
@@ -1,153 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from .exception_not_commands import NotCommands
|
||||
from .generator import PluginPromptGenerator
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
# todo
|
||||
command_args["directory"] = ""
|
||||
else:
|
||||
for pathlike in ["filename", "directory", "clone_path"]:
|
||||
if pathlike in command_args:
|
||||
# todo
|
||||
command_args[pathlike] = ""
|
||||
return command_args
|
||||
|
||||
|
||||
def execute_ai_response_json(
|
||||
prompt: PluginPromptGenerator,
|
||||
ai_response,
|
||||
user_input: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
command_registry:
|
||||
ai_response:
|
||||
prompt:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
cfg = Config()
|
||||
|
||||
command_name, arguments = get_command(ai_response)
|
||||
|
||||
if cfg.speak_mode:
|
||||
say_text(f"I want to execute {command_name}")
|
||||
|
||||
arguments = _resolve_pathlike_command_args(arguments)
|
||||
# Execute command
|
||||
if command_name is not None and command_name.lower().startswith("error"):
|
||||
result = f"Command {command_name} threw the following error: {arguments}"
|
||||
elif command_name == "human_feedback":
|
||||
result = f"Human feedback: {user_input}"
|
||||
else:
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, arguments = plugin.pre_command(command_name, arguments)
|
||||
command_result = execute_command(
|
||||
command_name,
|
||||
arguments,
|
||||
prompt,
|
||||
)
|
||||
result = f"{command_result}"
|
||||
return result
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments,
|
||||
plugin_generator: PluginPromptGenerator,
|
||||
):
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
|
||||
cmd = plugin_generator.command_registry.commands.get(command_name)
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
try:
|
||||
return cmd(**arguments)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error: {str(e)}")
|
||||
# return f"Error: {str(e)}"
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
# non-file is given, return instructions "Input should be a python
|
||||
# filepath, write your code to file and try again
|
||||
else:
|
||||
for command in plugin_generator.commands:
|
||||
if (
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
):
|
||||
try:
|
||||
                    # Remove arguments that are not defined for this command
|
||||
diff_ags = list(
|
||||
set(arguments.keys()).difference(set(command["args"].keys()))
|
||||
)
|
||||
for arg_name in diff_ags:
|
||||
del arguments[arg_name]
|
||||
print(str(arguments))
|
||||
return command["function"](**arguments)
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
        raise NotCommands("Command not available: " + command_name)
|
||||
|
||||
|
||||
def get_command(response_json: Dict):
|
||||
"""Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
response_json (json): The response from the AI
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
    try:
        if not isinstance(response_json, dict):
            return "Error:", f"'response_json' object is not a dictionary: {response_json}"

        if "command" not in response_json:
            return "Error:", "Missing 'command' object in JSON"
|
||||
|
||||
command = response_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
return "Error:", "'command' object is not a dictionary"
|
||||
|
||||
if "name" not in command:
|
||||
return "Error:", "Missing 'name' field in 'command' object"
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
except json.decoder.JSONDecodeError:
|
||||
return "Error:", "Invalid JSON"
|
||||
# All other errors, return "Error: + error message"
|
||||
except Exception as e:
|
||||
return "Error:", str(e)
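For example, a well-formed response is parsed as follows (illustrative data only):

ai_response = {
    "command": {"name": "generate_image", "args": {"prompt": "a red bicycle"}}
}
command_name, arguments = get_command(ai_response)
# command_name == "generate_image"
# arguments == {"prompt": "a red bicycle"}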
|
@@ -1,493 +0,0 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
import pandas as pd
|
||||
|
||||
from pilot.common.json_utils import serialize
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Optional, List
|
||||
from pydantic import BaseModel
|
||||
from pilot.base_modules.agent.common.schema import Status, ApiTagType
|
||||
from pilot.base_modules.agent.commands.command import execute_command
|
||||
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
|
||||
from pilot.common.string_utils import extract_content_open_ending, extract_content
|
||||
|
||||
# Unique identifier for auto-gpt commands
|
||||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||
|
||||
|
||||
class Command:
|
||||
"""A class representing a command.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (str): A brief description of what the command does.
|
||||
signature (str): The signature of the function that the command executes. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if not self.enabled:
|
||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
return self.method(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}: {self.description}, args: {self.signature}"
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""
|
||||
The CommandRegistry class is a manager for a collection of Command objects.
|
||||
It allows the registration, modification, and retrieval of Command objects,
|
||||
as well as the scanning and loading of command plugins from a specified
|
||||
directory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
def _reload_module(self, module: Any) -> Any:
|
||||
return importlib.reload(module)
|
||||
|
||||
def register(self, cmd: Command) -> None:
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
def unregister(self, command_name: str):
|
||||
if command_name in self.commands:
|
||||
del self.commands[command_name]
|
||||
else:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reloads all loaded command plugins."""
|
||||
for cmd_name in self.commands:
|
||||
cmd = self.commands[cmd_name]
|
||||
module = self._import_module(cmd.__module__)
|
||||
reloaded_module = self._reload_module(module)
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
    def is_valid_command(self, name: str) -> bool:
        return name in self.commands
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
|
||||
def call(self, command_name: str, **kwargs) -> Any:
|
||||
if command_name not in self.commands:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
command = self.commands[command_name]
|
||||
return command(**kwargs)
|
||||
|
||||
def command_prompt(self) -> str:
|
||||
"""
|
||||
Returns a string representation of all registered `Command` objects for use in a prompt
|
||||
"""
|
||||
commands_list = [
|
||||
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
|
||||
]
|
||||
return "\n".join(commands_list)
|
||||
|
||||
def import_commands(self, module_name: str) -> None:
|
||||
"""
|
||||
Imports the specified Python module containing command plugins.
|
||||
|
||||
This method imports the associated module and registers any functions or
|
||||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||
as `Command` objects. The registered `Command` objects are then added to the
|
||||
`commands` dictionary of the `CommandRegistry` object.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import for command plugins.
|
||||
"""
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
# Register decorated functions
|
||||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
):
|
||||
self.register(attr.command)
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
cmd_instance = attr()
|
||||
self.register(cmd_instance)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""The command decorator is used to create Command objects from ordinary functions."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Command:
|
||||
cmd = Command(
|
||||
name=name,
|
||||
description=description,
|
||||
method=func,
|
||||
signature=signature,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper.command = cmd
|
||||
|
||||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
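A small sketch of how the decorator and CommandRegistry work together (the echo command itself is hypothetical):

@command("echo_text", "Echo the given text", '"text": "<text>"')
def echo_text(text: str) -> str:
    return text

registry = CommandRegistry()
registry.register(echo_text.command)           # the decorator attaches a Command object
assert registry.is_valid_command("echo_text")
print(registry.call("echo_text", text="hi"))   # -> "hi"
print(registry.command_prompt())               # numbered command list for prompting an LLM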
|
||||
|
||||
|
||||
class PluginStatus(BaseModel):
|
||||
name: str
|
||||
location: List[int]
|
||||
args: dict
|
||||
status: Status = Status.TODO.value
|
||||
logo_url: str = None
|
||||
api_result: str = None
|
||||
err_msg: str = None
|
||||
    start_time: float = datetime.now().timestamp() * 1000
|
||||
end_time: int = None
|
||||
|
||||
df: Any = None
|
||||
|
||||
|
||||
class ApiCall:
|
||||
agent_prefix = "<api-call>"
|
||||
agent_end = "</api-call>"
|
||||
name_prefix = "<name>"
|
||||
name_end = "</name>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_generator: Any = None,
|
||||
display_registry: Any = None,
|
||||
backend_rendering: bool = False,
|
||||
):
|
||||
# self.name: str = ""
|
||||
# self.status: Status = Status.TODO.value
|
||||
# self.logo_url: str = None
|
||||
# self.args = {}
|
||||
# self.api_result: str = None
|
||||
# self.err_msg: str = None
|
||||
|
||||
self.plugin_status_map = {}
|
||||
|
||||
self.plugin_generator = plugin_generator
|
||||
self.display_registry = display_registry
|
||||
self.start_time = datetime.now().timestamp() * 1000
|
||||
        self.backend_rendering: bool = backend_rendering
|
||||
|
||||
def __repr__(self):
|
||||
        return f"ApiCall(plugin_status_map={self.plugin_status_map})"
|
||||
|
||||
def __is_need_wait_plugin_call(self, api_call_context):
|
||||
start_agent_count = api_call_context.count(self.agent_prefix)
|
||||
end_agent_count = api_call_context.count(self.agent_end)
|
||||
|
||||
if start_agent_count > 0:
|
||||
return True
|
||||
else:
|
||||
            # Check whether the tail of the context is a partial prefix of the agent tag
            check_len = len(self.agent_prefix)
            last_text = api_call_context[-check_len:]
            for i in range(check_len):
                text_tmp = last_text[-i:]
                prefix_tmp = self.agent_prefix[:i]
                if text_tmp == prefix_tmp:
                    return True
            return False
|
||||
|
||||
def check_last_plugin_call_ready(self, all_context):
|
||||
start_agent_count = all_context.count(self.agent_prefix)
|
||||
end_agent_count = all_context.count(self.agent_end)
|
||||
|
||||
if start_agent_count > 0 and start_agent_count == end_agent_count:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
|
||||
error_md_tags = [
|
||||
"```",
|
||||
"```python",
|
||||
"```xml",
|
||||
"```json",
|
||||
"```markdown",
|
||||
"```sql",
|
||||
]
|
||||
        if not include_end:
|
||||
md_tag_end = ""
|
||||
else:
|
||||
md_tag_end = "```"
|
||||
for tag in error_md_tags:
|
||||
all_context = all_context.replace(
|
||||
tag + api_context + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + "\n" + api_context + "\n" + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + " " + api_context + " " + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(tag + api_context, api_context)
|
||||
return all_context
|
||||
|
||||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||||
call_context_map = extract_content_open_ending(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in call_context_map.items():
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is not None:
|
||||
if display_mode:
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
if Status.FAILED.value == api_status.status:
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:red">Error:</span>{api_status.err_msg}\n'
|
||||
+ self.to_view_antv_vis(api_status),
|
||||
)
|
||||
else:
|
||||
all_context = all_context.replace(
|
||||
api_context, self.to_view_antv_vis(api_status)
|
||||
)
|
||||
|
||||
else:
|
||||
all_context = self.__deal_error_md_tags(
|
||||
all_context, api_context, False
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context, self.to_view_text(api_status)
|
||||
)
|
||||
|
||||
else:
|
||||
# not ready api call view change
|
||||
now_time = datetime.now().timestamp() * 1000
|
||||
cost = (now_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
|
||||
)
|
||||
|
||||
return all_context
|
||||
|
||||
def update_from_context(self, all_context):
|
||||
api_context_map = extract_content(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in api_context_map.items():
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
api_call_element = ET.fromstring(api_context)
|
||||
api_name = api_call_element.find("name").text
|
||||
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
|
||||
api_name = api_name.replace("[", "").replace("]", "")
|
||||
api_args = {}
|
||||
args_elements = api_call_element.find("args")
|
||||
for child_element in args_elements.iter():
|
||||
api_args[child_element.tag] = child_element.text
|
||||
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is None:
|
||||
api_status = PluginStatus(
|
||||
name=api_name, location=[api_index], args=api_args
|
||||
)
|
||||
self.plugin_status_map[api_context] = api_status
|
||||
else:
|
||||
api_status.location.append(api_index)
|
||||
|
||||
def __to_view_param_str(self, api_status):
|
||||
param = {}
|
||||
if api_status.name:
|
||||
param["name"] = api_status.name
|
||||
param["status"] = api_status.status
|
||||
if api_status.logo_url:
|
||||
param["logo"] = api_status.logo_url
|
||||
|
||||
if api_status.err_msg:
|
||||
param["err_msg"] = api_status.err_msg
|
||||
|
||||
if api_status.api_result:
|
||||
param["result"] = api_status.api_result
|
||||
|
||||
return json.dumps(param, default=serialize, ensure_ascii=False)
|
||||
|
||||
def to_view_text(self, api_status: PluginStatus):
|
||||
api_call_element = ET.Element("dbgpt-view")
|
||||
api_call_element.text = self.__to_view_param_str(api_status)
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
|
||||
def to_view_antv_vis(self, api_status: PluginStatus):
|
||||
if self.backend_rendering:
|
||||
html_table = api_status.df.to_html(
|
||||
index=False, escape=False, sparsify=False
|
||||
)
|
||||
table_str = "".join(html_table.split())
|
||||
table_str = table_str.replace("\n", " ")
|
||||
html = f""" \n<div><b>[SQL]{api_status.args["sql"]}</b></div><div class="w-full overflow-auto">{table_str}</div>\n """
|
||||
return html
|
||||
else:
|
||||
api_call_element = ET.Element("chart-view")
|
||||
api_call_element.attrib["content"] = self.__to_antv_vis_param(api_status)
|
||||
api_call_element.text = "\n"
|
||||
# api_call_element.set("content", self.__to_antv_vis_param(api_status))
|
||||
# api_call_element.text = self.__to_antv_vis_param(api_status)
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
|
||||
# return f'<chart-view content="{self.__to_antv_vis_param(api_status)}">'
|
||||
|
||||
def __to_antv_vis_param(self, api_status: PluginStatus):
|
||||
param = {}
|
||||
if api_status.name:
|
||||
param["type"] = api_status.name
|
||||
if api_status.args:
|
||||
param["sql"] = api_status.args["sql"]
|
||||
# if api_status.err_msg:
|
||||
# param["err_msg"] = api_status.err_msg
|
||||
|
||||
if api_status.api_result:
|
||||
param["data"] = api_status.api_result
|
||||
else:
|
||||
param["data"] = []
|
||||
return json.dumps(param, ensure_ascii=False)
|
||||
|
||||
def run(self, llm_text):
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.check_last_plugin_call_ready(llm_text):
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
                        logging.info(f"Executing plugin: {value.name}, {value.args}")
|
||||
try:
|
||||
value.api_result = execute_command(
|
||||
value.name, value.args, self.plugin_generator
|
||||
)
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text)
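The run* methods expect the model output to embed tool invocations in the <api-call> tag protocol parsed by update_from_context; the fragment below is illustrative (the command name and plugin generator are assumptions):

llm_text = (
    "Here is the chart you asked for: "
    "<api-call><name>response_line_chart</name>"
    "<args><sql>SELECT city, COUNT(*) AS cnt FROM users GROUP BY city</sql></args>"
    "</api-call>"
)
api_call = ApiCall(plugin_generator=my_plugin_generator)
rendered = api_call.run(llm_text)  # executes the call and replaces the tag with a view block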
|
||||
|
||||
def run_display_sql(self, llm_text, sql_run_func):
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.check_last_plugin_call_ready(llm_text):
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
                        logging.info(f"Executing SQL display: {value.name}, {value.args}")
|
||||
try:
|
||||
sql = value.args["sql"]
|
||||
if sql:
|
||||
param = {
|
||||
"df": sql_run_func(sql),
|
||||
}
|
||||
value.df = param["df"]
|
||||
if self.display_registry.is_valid_command(value.name):
|
||||
value.api_result = self.display_registry.call(
|
||||
value.name, **param
|
||||
)
|
||||
else:
|
||||
value.api_result = self.display_registry.call(
|
||||
"response_table", **param
|
||||
)
|
||||
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text, True)
|
||||
|
||||
def display_sql_llmvis(self, llm_text, sql_run_func):
|
||||
"""
|
||||
Render charts using the Antv standard protocol
|
||||
Args:
|
||||
llm_text: LLM response text
|
||||
sql_run_func: sql run function
|
||||
|
||||
Returns:
|
||||
ChartView protocol text
|
||||
"""
|
||||
try:
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.check_last_plugin_call_ready(llm_text):
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logging.info(f"sql展示执行:{value.name},{value.args}")
|
||||
try:
|
||||
sql = value.args["sql"]
|
||||
if sql is not None and len(sql) > 0:
|
||||
data_df = sql_run_func(sql)
|
||||
value.df = data_df
|
||||
value.api_result = json.loads(
|
||||
data_df.to_json(
|
||||
orient="records",
|
||||
date_format="iso",
|
||||
date_unit="s",
|
||||
)
|
||||
)
|
||||
value.status = Status.COMPLETED.value
|
||||
else:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = "No executable sql!"
|
||||
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
except Exception as e:
|
||||
logging.error("Api parsing exception", e)
|
||||
raise ValueError("Api parsing exception," + str(e))
|
||||
|
||||
return self.api_view_context(llm_text, True)
|
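For reference, a minimal sketch of the `<chart-view>` payload that `__to_antv_vis_param` produces; the `FakeStatus` stand-in below is hypothetical and only mirrors the fields the method reads.

# --- illustrative example, not part of the deleted file ---
import json
import xml.etree.ElementTree as ET

class FakeStatus:  # hypothetical stand-in for PluginStatus
    name = "response_line_chart"
    args = {"sql": "SELECT city, SUM(amount) FROM orders GROUP BY city"}
    api_result = [{"city": "Beijing", "amount": 100}, {"city": "Hangzhou", "amount": 80}]

param = json.dumps(
    {"type": FakeStatus.name, "sql": FakeStatus.args["sql"], "data": FakeStatus.api_result},
    ensure_ascii=False,
)
element = ET.Element("chart-view")
element.attrib["content"] = param
element.text = "\n"
# Prints a <chart-view> element whose "content" attribute carries the JSON payload.
print(ET.tostring(element, encoding="utf-8").decode("utf-8"))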
@@ -1,314 +0,0 @@
from pandas import DataFrame

from pilot.base_modules.agent.commands.command_mange import command
import pandas as pd
import uuid
import os
import matplotlib
import seaborn as sns

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.common.string_utils import is_scientific_notation

import logging

logger = logging.getLogger(__name__)


static_message_img_path = os.path.join(os.getcwd(), "message/img")


def data_pre_classification(df: DataFrame):
    ## Data pre-classification
    columns = df.columns.tolist()

    number_columns = []
    non_numeric_colums = []

    # Collect the number of distinct values for each column
    non_numeric_colums_value_map = {}
    numeric_colums_value_map = {}
    for column_name in columns:
        if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
            number_columns.append(column_name)
            unique_values = df[column_name].unique()
            numeric_colums_value_map.update({column_name: len(unique_values)})
        else:
            non_numeric_colums.append(column_name)
            unique_values = df[column_name].unique()
            non_numeric_colums_value_map.update({column_name: len(unique_values)})

    sorted_numeric_colums_value_map = dict(
        sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
    )
    numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())

    sorted_colums_value_map = dict(
        sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
    )
    non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())

    # Analyze x-coordinate
    if len(non_numeric_colums_sort_list) > 0:
        x_cloumn = non_numeric_colums_sort_list[-1]
        non_numeric_colums_sort_list.remove(x_cloumn)
    else:
        x_cloumn = number_columns[0]
        numeric_colums_sort_list.remove(x_cloumn)

    # Analyze y-coordinate
    if len(numeric_colums_sort_list) > 0:
        y_column = numeric_colums_sort_list[0]
        numeric_colums_sort_list.remove(y_column)
    else:
        raise ValueError("Not enough numeric columns for chart!")

    return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list


def zh_font_set():
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts


def format_axis(value, pos):
    # Check whether the value is in scientific notation
    if is_scientific_notation(value):
        # Reformat it as a plain (non-scientific) number if needed
        return "{:.2f}".format(value)
    return value


@command(
    "response_line_chart",
    "Line chart display, used to display comparative trend analysis data",
    '"df":"<data frame>"',
)
def response_line_chart(df: DataFrame) -> str:
    logger.info(f"response_line_chart")
    if df.size <= 0:
        raise ValueError("No Data!")
    try:
        # set font
        # zh_font_set()
        font_names = [
            "Heiti TC",
            "Songti SC",
            "STHeiti Light",
            "Microsoft YaHei",
            "SimSun",
            "SimHei",
            "KaiTi",
        ]
        fm = FontManager()
        mat_fonts = set(f.name for f in fm.ttflist)
        can_use_fonts = []
        for font_name in font_names:
            if font_name in mat_fonts:
                can_use_fonts.append(font_name)
        if len(can_use_fonts) > 0:
            plt.rcParams["font.sans-serif"] = can_use_fonts

        rc = {"font.sans-serif": can_use_fonts}
        plt.rcParams["axes.unicode_minus"] = False  # Fix rendering of the minus sign

        sns.set(font=can_use_fonts[0], font_scale=0.8)  # Fix Chinese text rendering in Seaborn
        sns.set_palette("Set3")  # Set the color theme
        sns.set_style("dark")
        sns.color_palette("hls", 10)
        sns.hls_palette(8, l=0.5, s=0.7)
        sns.set(context="notebook", style="ticks", rc=rc)

        fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
        x, y, non_num_columns, num_colmns = data_pre_classification(df)
        # ## Multi-series line chart
        if len(num_colmns) > 0:
            num_colmns.append(y)
            df_melted = pd.melt(
                df,
                id_vars=x,
                value_vars=num_colmns,
                var_name="line",
                value_name="Value",
            )
            sns.lineplot(
                data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
            )
        else:
            sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")

        ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
        # ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))

        chart_name = "line_" + str(uuid.uuid1()) + ".png"
        chart_path = static_message_img_path + "/" + chart_name
        plt.savefig(chart_path, dpi=100, transparent=True)

        html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
        return html_img
    except Exception as e:
        logging.error("Draw Line Chart Failed!" + str(e), e)
        raise ValueError("Draw Line Chart Failed!" + str(e))


@command(
    "response_bar_chart",
    "Histogram, suitable for comparative analysis of multiple target values",
    '"df":"<data frame>"',
)
def response_bar_chart(df: DataFrame) -> str:
    logger.info(f"response_bar_chart")
    if df.size <= 0:
        raise ValueError("No Data!")

    # set font
    # zh_font_set()
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts

    rc = {"font.sans-serif": can_use_fonts}
    plt.rcParams["axes.unicode_minus"] = False  # Fix rendering of the minus sign
    sns.set(font=can_use_fonts[0], font_scale=0.8)  # Fix Chinese text rendering in Seaborn
    sns.set_palette("Set3")  # Set the color theme
    sns.set_style("dark")
    sns.color_palette("hls", 10)
    sns.hls_palette(8, l=0.5, s=0.7)
    sns.set(context="notebook", style="ticks", rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)

    hue = None
    x, y, non_num_columns, num_colmns = data_pre_classification(df)
    if len(non_num_columns) >= 1:
        hue = non_num_columns[0]

    if len(num_colmns) >= 1:
        if hue:
            if len(num_colmns) >= 2:
                can_use_columns = num_colmns[:2]
            else:
                can_use_columns = num_colmns
            sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
            for sub_y_column in can_use_columns:
                sns.barplot(
                    data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
                )
        else:
            if len(num_colmns) > 5:
                can_use_columns = num_colmns[:5]
            else:
                can_use_columns = num_colmns
            can_use_columns.append(y)

            df_melted = pd.melt(
                df,
                id_vars=x,
                value_vars=can_use_columns,
                var_name="line",
                value_name="Value",
            )
            sns.barplot(
                data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
            )
    else:
        sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)

    # Format the y-axis ticks as plain numbers
    ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
    # ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))

    chart_name = "bar_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, dpi=100, transparent=True)
    html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img


@command(
    "response_pie_chart",
    "Pie chart, suitable for scenarios such as proportion and distribution statistics",
    '"df":"<data frame>"',
)
def response_pie_chart(df: DataFrame) -> str:
    logger.info(f"response_pie_chart")
    columns = df.columns.tolist()
    if df.size <= 0:
        raise ValueError("No Data!")
    # set font
    # zh_font_set()
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts
    plt.rcParams["axes.unicode_minus"] = False  # Fix rendering of the minus sign

    sns.set_palette("Set3")  # Set the color theme

    # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    ax = df.plot(
        kind="pie",
        y=columns[1],
        ax=ax,
        labels=df[columns[0]].values,
        startangle=90,
        autopct="%1.1f%%",
    )

    plt.axis("equal")  # Keep the pie chart perfectly circular
    # plt.title(columns[0])

    chart_name = "pie_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)

    html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

    return html_img
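As a quick illustration of how `data_pre_classification` picks the axes for the chart commands above, here is a small, self-contained sketch; the column names are made up and the function is assumed to be importable from this module.

# --- illustrative example, not part of the deleted file ---
import pandas as pd

df = pd.DataFrame(
    {
        "city": ["Beijing", "Hangzhou", "Beijing", "Hangzhou"],
        "quarter": ["Q1", "Q1", "Q2", "Q2"],
        "amount": [120, 80, 150, 95],
    }
)
x, y, non_num_rest, num_rest = data_pre_classification(df)
# The non-numeric column with the most distinct values becomes the x-axis,
# and the numeric column with the fewest distinct values becomes the y-axis.
print(x, y, non_num_rest, num_rest)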
@@ -1,21 +0,0 @@
from pandas import DataFrame

from pilot.base_modules.agent.commands.command_mange import command

import logging

logger = logging.getLogger(__name__)


@command(
    "response_table",
    "Table display, suitable for display with many display columns or non-numeric columns",
    '"df":"<data frame>"',
)
def response_table(df: DataFrame) -> str:
    logger.info(f"response_table")
    html_table = df.to_html(index=False, escape=False, sparsify=False)
    table_str = "".join(html_table.split())
    table_str = table_str.replace("\n", " ")
    html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
    return html
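A short sketch of what `response_table` returns for a two-row frame; the output shown in the comment is abbreviated.

# --- illustrative example, not part of the deleted file ---
import pandas as pd

df = pd.DataFrame({"name": ["plugin_a", "plugin_b"], "installed": [3, 1]})
html = response_table(df)
# html is a whitespace-stripped, single-line string of the form:
# ' \n<div class="w-full overflow-auto"><table ...>...</table></div>\n '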
@@ -1,37 +0,0 @@
from pandas import DataFrame

from pilot.base_modules.agent.commands.command_mange import command

import logging

logger = logging.getLogger(__name__)


@command(
    "response_data_text",
    "Text display, the default display method, suitable for single-line or simple content display",
    '"df":"<data frame>"',
)
def response_data_text(df: DataFrame) -> str:
    logger.info(f"response_data_text")
    data = df.values

    row_size = data.shape[0]
    value_str = ""
    text_info = ""
    if row_size > 1:
        html_table = df.to_html(index=False, escape=False, sparsify=False)
        table_str = "".join(html_table.split())
        html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
        text_info = html.replace("\n", " ")
    elif row_size == 1:
        row = data[0]
        for value in row:
            if value_str:
                value_str = value_str + f", ** {value} **"
            else:
                value_str = f" ** {value} **"
        text_info = f" {value_str}"
    else:
        text_info = f"##### _No usable data found_"
    return text_info
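And a sketch of the single-row path in `response_data_text`, which bolds each value in Markdown:

# --- illustrative example, not part of the deleted file ---
import pandas as pd

df = pd.DataFrame({"total_orders": [42], "total_amount": [1234.5]})
print(response_data_text(df))
# "  ** 42 **, ** 1234.5 **"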
@@ -1,4 +0,0 @@
class NotCommands(Exception):
    def __init__(self, message):
        super().__init__(message)
        self.message = message
@@ -1,135 +0,0 @@
|
||||
""" A module for generating custom prompt strings."""
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
class PluginPromptGenerator:
|
||||
"""
|
||||
A class for generating custom prompt strings based on constraints, commands,
|
||||
resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the PromptGenerator object with empty lists of constraints,
|
||||
commands, resources, and performance evaluations.
|
||||
"""
|
||||
self.constraints = []
|
||||
self.commands = []
|
||||
self.resources = []
|
||||
self.performance_evaluation = []
|
||||
self.goals = []
|
||||
self.command_registry = None
|
||||
self.response_format = {
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user",
|
||||
},
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Args:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
command_label: str,
|
||||
command_name: str,
|
||||
args=None,
|
||||
function: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a command to the commands list with a label, name, and optional arguments.
|
||||
|
||||
Args:
|
||||
command_label (str): The label of the command.
|
||||
command_name (str): The name of the command.
|
||||
args (dict, optional): A dictionary containing argument names and their
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
if args is None:
|
||||
args = {}
|
||||
|
||||
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
|
||||
|
||||
command = {
|
||||
"label": command_label,
|
||||
"name": command_name,
|
||||
"args": command_args,
|
||||
"function": function,
|
||||
}
|
||||
|
||||
self.commands.append(command)
|
||||
|
||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate a formatted string representation of a command.
|
||||
|
||||
Args:
|
||||
command (dict): A dictionary containing command information.
|
||||
|
||||
Returns:
|
||||
str: The formatted command string.
|
||||
"""
|
||||
args_string = ", ".join(
|
||||
f'"{key}": "{value}"' for key, value in command["args"].items()
|
||||
)
|
||||
return f'"{command["name"]}": {command["label"]} , args: {args_string}'
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Args:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||
"""
|
||||
Add a performance evaluation item to the performance_evaluation list.
|
||||
|
||||
Args:
|
||||
evaluation (str): The evaluation item to be added.
|
||||
"""
|
||||
self.performance_evaluation.append(evaluation)
|
||||
|
||||
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
|
||||
"""
|
||||
Generate a numbered list from given items based on the item_type.
|
||||
|
||||
Args:
|
||||
items (list): A list of items to be numbered.
|
||||
item_type (str, optional): The type of items in the list.
|
||||
Defaults to 'list'.
|
||||
|
||||
Returns:
|
||||
str: The formatted numbered list.
|
||||
"""
|
||||
if item_type == "command":
|
||||
command_strings = []
|
||||
if self.command_registry:
|
||||
command_strings += [
|
||||
str(item)
|
||||
for item in self.command_registry.commands.values()
|
||||
if item.enabled
|
||||
]
|
||||
# terminate command is added manually
|
||||
command_strings += [self._generate_command_string(item) for item in items]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
else:
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||
|
||||
def generate_commands_string(self) -> str:
|
||||
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
|
@@ -1,18 +0,0 @@
from enum import Enum


class PluginStorageType(Enum):
    Git = "git"
    Oss = "oss"


class Status(Enum):
    TODO = "todo"
    RUNNING = "running"
    FAILED = "failed"
    COMPLETED = "completed"


class ApiTagType(Enum):
    API_VIEW = "dbgpt_view"
    API_CALL = "dbgpt_call"
@@ -1,163 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
UploadFile,
|
||||
File,
|
||||
)
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
)
|
||||
|
||||
from .model import (
|
||||
PluginHubParam,
|
||||
PagenationFilter,
|
||||
PagenationResult,
|
||||
PluginHubFilter,
|
||||
MyPluginFilter,
|
||||
)
|
||||
from .hub.agent_hub import AgentHub
|
||||
from .db.plugin_hub_db import PluginHubEntity
|
||||
from .plugins_util import scan_plugins
|
||||
from .commands.generator import PluginPromptGenerator
|
||||
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModuleAgent(BaseComponent, ABC):
|
||||
name = ComponentType.AGENT_HUB
|
||||
|
||||
def __init__(self):
|
||||
# load plugins
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
|
||||
|
||||
def refresh_plugins(self):
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
|
||||
def load_select_plugin(
|
||||
self, generator: PluginPromptGenerator, select_plugins: List[str]
|
||||
) -> PluginPromptGenerator:
|
||||
logger.info(f"load_select_plugin:{select_plugins}")
|
||||
# load select plugin
|
||||
for plugin in self.plugins:
|
||||
if plugin._name in select_plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
generator = plugin.post_prompt(generator)
|
||||
return generator
|
||||
|
||||
|
||||
module_agent = ModuleAgent()
|
||||
|
||||
|
||||
@router.post("/v1/agent/hub/update", response_model=Result[str])
|
||||
async def agent_hub_update(update_param: PluginHubParam = Body()):
|
||||
logger.info(f"agent_hub_update:{update_param.__dict__}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
branch = (
|
||||
update_param.branch
|
||||
if update_param.branch is not None and len(update_param.branch) > 0
|
||||
else "main"
|
||||
)
|
||||
authorization = (
|
||||
update_param.authorization
|
||||
if update_param.branch is not None and len(update_param.branch) > 0
|
||||
else None
|
||||
)
|
||||
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Agent Hub Update Error!", e)
|
||||
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
|
||||
|
||||
|
||||
@router.post("/v1/agent/query", response_model=Result[str])
|
||||
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
logger.info(f"get_agent_list:{filter.__dict__}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
filter_enetity: PluginHubEntity = PluginHubEntity()
|
||||
if filter.filter:
|
||||
attrs = vars(filter.filter) # 获取原始对象的属性字典
|
||||
for attr, value in attrs.items():
|
||||
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
|
||||
|
||||
datas, total_pages, total_count = agent_hub.hub_dao.list(
|
||||
filter_enetity, filter.page_index, filter.page_size
|
||||
)
|
||||
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
|
||||
result.page_index = filter.page_index
|
||||
result.page_size = filter.page_size
|
||||
result.total_page = total_pages
|
||||
result.total_row_count = total_count
|
||||
result.datas = datas
|
||||
# print(json.dumps(result.to_dic()))
|
||||
return Result.succ(result.to_dic())
|
||||
|
||||
|
||||
@router.post("/v1/agent/my", response_model=Result[str])
|
||||
async def my_agents(user: str = None):
|
||||
logger.info(f"my_agents:{user}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agents = agent_hub.get_my_plugin(user)
|
||||
agent_dicts = []
|
||||
for agent in agents:
|
||||
agent_dicts.append(agent.__dict__)
|
||||
|
||||
return Result.succ(agent_dicts)
|
||||
|
||||
|
||||
@router.post("/v1/agent/install", response_model=Result[str])
|
||||
async def agent_install(plugin_name: str, user: str = None):
|
||||
logger.info(f"agent_install:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.install_plugin(plugin_name, user)
|
||||
|
||||
module_agent.refresh_plugins()
|
||||
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Install Error!", e)
|
||||
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")
|
||||
|
||||
|
||||
@router.post("/v1/agent/uninstall", response_model=Result[str])
|
||||
async def agent_uninstall(plugin_name: str, user: str = None):
|
||||
logger.info(f"agent_uninstall:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.uninstall_plugin(plugin_name, user)
|
||||
|
||||
module_agent.refresh_plugins()
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Uninstall Error!", e)
|
||||
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")
|
||||
|
||||
|
||||
@router.post("/v1/personal/agent/upload", response_model=Result[str])
|
||||
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
|
||||
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
await agent_hub.upload_my_plugin(doc_file, user)
|
||||
module_agent.refresh_plugins()
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Upload Personal Plugin Error!", e)
|
||||
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")
|
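For context, a hedged sketch of calling the install endpoint above with `requests`; the base URL assumes a locally running DB-GPT webserver and is an assumption, not part of this diff.

# --- illustrative example, not part of the deleted file ---
import requests

# Assumption: the webserver listens on http://127.0.0.1:5000 and mounts this router under /api.
resp = requests.post(
    "http://127.0.0.1:5000/api/v1/agent/install",
    params={"plugin_name": "DB-GPT-Plugins", "user": "default"},
)
# Expect a Result payload; on failure the handler returns error code "E0021".
print(resp.json())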
@@ -1,156 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
class MyPluginEntity(Base):
|
||||
__tablename__ = "my_plugin"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True, comment="autoincrement id")
|
||||
tenant = Column(String(255), nullable=True, comment="user's tenant")
|
||||
user_code = Column(String(255), nullable=False, comment="user code")
|
||||
user_name = Column(String(255), nullable=True, comment="user name")
|
||||
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
|
||||
file_name = Column(String(255), nullable=False, comment="plugin package file name")
|
||||
type = Column(String(255), comment="plugin type")
|
||||
version = Column(String(255), comment="plugin version")
|
||||
use_count = Column(
|
||||
Integer, nullable=True, default=0, comment="plugin total use count"
|
||||
)
|
||||
succ_count = Column(
|
||||
Integer, nullable=True, default=0, comment="plugin total success count"
|
||||
)
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(
|
||||
DateTime, default=datetime.utcnow, comment="plugin install time"
|
||||
)
|
||||
UniqueConstraint("user_code", "name", name="uk_name")
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def add(self, engity: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
my_plugin = MyPluginEntity(
|
||||
tenant=engity.tenant,
|
||||
user_code=engity.user_code,
|
||||
user_name=engity.user_name,
|
||||
name=engity.name,
|
||||
type=engity.type,
|
||||
version=engity.version,
|
||||
use_count=engity.use_count or 0,
|
||||
succ_count=engity.succ_count or 0,
|
||||
sys_code=engity.sys_code,
|
||||
gmt_created=datetime.now(),
|
||||
)
|
||||
session.add(my_plugin)
|
||||
session.commit()
|
||||
id = my_plugin.id
|
||||
session.close()
|
||||
return id
|
||||
|
||||
def update(self, entity: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
return updated.id
|
||||
|
||||
def get_by_user(self, user: str) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
|
||||
result = my_plugins.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
|
||||
result = my_plugins.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
all_count = my_plugins.count()
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
|
||||
if query.tenant is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
|
||||
if query.type is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
|
||||
if query.user_code is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||
if query.user_name is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||
if query.sys_code is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
|
||||
|
||||
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
|
||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
|
||||
result = my_plugins.all()
|
||||
session.close()
|
||||
total_pages = all_count // page_size
|
||||
if all_count % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return result, total_pages, all_count
|
||||
|
||||
def count(self, query: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(func.count(MyPluginEntity.id))
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
|
||||
if query.tenant is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
|
||||
if query.user_code is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||
if query.user_name is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||
if query.sys_code is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
|
||||
count = my_plugins.scalar()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
def delete(self, plugin_id: int):
|
||||
session = self.get_session()
|
||||
if plugin_id is None:
|
||||
raise Exception("plugin_id is None")
|
||||
query = MyPluginEntity(id=plugin_id)
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
my_plugins.delete()
|
||||
session.commit()
|
||||
session.close()
|
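The paging arithmetic used by `MyPluginDao.list` (and by `PluginHubDao.list` below) is just a ceiling division; a standalone sketch:

# --- illustrative example, not part of the deleted file ---
def page_count(all_count: int, page_size: int) -> int:
    # Equivalent to the inline logic: floor division, plus one when there is a remainder.
    total_pages = all_count // page_size
    if all_count % page_size != 0:
        total_pages += 1
    return total_pages

assert page_count(41, 20) == 3
assert page_count(40, 20) == 2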
@@ -1,159 +0,0 @@
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, DDL
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from pilot.base_modules.meta_data.meta_data import Base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
|
||||
# TODO We should consider that the production environment does not have permission to execute the DDL
|
||||
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
|
||||
|
||||
|
||||
class PluginHubEntity(Base):
|
||||
__tablename__ = "plugin_hub"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(
|
||||
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
|
||||
)
|
||||
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
|
||||
description = Column(String(255), nullable=False, comment="plugin description")
|
||||
author = Column(String(255), nullable=True, comment="plugin author")
|
||||
email = Column(String(255), nullable=True, comment="plugin author email")
|
||||
type = Column(String(255), comment="plugin type")
|
||||
version = Column(String(255), comment="plugin version")
|
||||
storage_channel = Column(String(255), comment="plugin storage channel")
|
||||
storage_url = Column(String(255), comment="plugin download url")
|
||||
download_param = Column(String(255), comment="plugin download param")
|
||||
gmt_created = Column(
|
||||
DateTime, default=datetime.utcnow, comment="plugin upload time"
|
||||
)
|
||||
installed = Column(Integer, default=False, comment="plugin already installed count")
|
||||
|
||||
UniqueConstraint("name", name="uk_name")
|
||||
Index("idx_q_type", "type")
|
||||
|
||||
|
||||
class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def add(self, engity: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
timezone = pytz.timezone("Asia/Shanghai")
|
||||
plugin_hub = PluginHubEntity(
|
||||
name=engity.name,
|
||||
author=engity.author,
|
||||
email=engity.email,
|
||||
type=engity.type,
|
||||
version=engity.version,
|
||||
storage_channel=engity.storage_channel,
|
||||
storage_url=engity.storage_url,
|
||||
gmt_created=timezone.localize(datetime.now()),
|
||||
)
|
||||
session.add(plugin_hub)
|
||||
session.commit()
|
||||
id = plugin_hub.id
|
||||
session.close()
|
||||
return id
|
||||
|
||||
def update(self, entity: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
return updated.id
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def list(
|
||||
self, query: PluginHubEntity, page=1, page_size=20
|
||||
) -> list[PluginHubEntity]:
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
all_count = plugin_hubs.count()
|
||||
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
|
||||
if query.author is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
|
||||
if query.storage_channel is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.storage_channel == query.storage_channel
|
||||
)
|
||||
|
||||
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
|
||||
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
|
||||
result = plugin_hubs.all()
|
||||
session.close()
|
||||
|
||||
total_pages = all_count // page_size
|
||||
if all_count % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return result, total_pages, all_count
|
||||
|
||||
def get_by_storage_url(self, storage_url):
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
|
||||
result = plugin_hubs.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
|
||||
result = plugin_hubs.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def count(self, query: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(func.count(PluginHubEntity.id))
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
|
||||
if query.author is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
|
||||
if query.storage_channel is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.storage_channel == query.storage_channel
|
||||
)
|
||||
count = plugin_hubs.scalar()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
def delete(self, plugin_id: int):
|
||||
session = self.get_session()
|
||||
if plugin_id is None:
|
||||
raise Exception("plugin_id is None")
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
if plugin_id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
|
||||
plugin_hubs.delete()
|
||||
session.commit()
|
||||
session.close()
|
@@ -1,217 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
from typing import Any
|
||||
import tempfile
|
||||
|
||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
from ..common.schema import PluginStorageType
|
||||
from ..plugins_util import scan_plugins, update_from_git
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Default_User = "default"
|
||||
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
|
||||
TEMP_PLUGIN_PATH = ""
|
||||
|
||||
|
||||
class AgentHub:
|
||||
def __init__(self, plugin_dir) -> None:
|
||||
self.hub_dao = PluginHubDao()
|
||||
self.my_plugin_dao = MyPluginDao()
|
||||
os.makedirs(plugin_dir, exist_ok=True)
|
||||
self.plugin_dir = plugin_dir
|
||||
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
|
||||
|
||||
def install_plugin(self, plugin_name: str, user_name: str = None):
|
||||
logger.info(f"install_plugin {plugin_name}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
if plugin_entity:
|
||||
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
||||
try:
|
||||
branch_name = None
|
||||
authorization = None
|
||||
if plugin_entity.download_param:
|
||||
download_param = json.loads(plugin_entity.download_param)
|
||||
branch_name = download_param.get("branch_name")
|
||||
authorization = download_param.get("authorization")
|
||||
file_name = self.__download_from_git(
|
||||
plugin_entity.storage_url, branch_name, authorization
|
||||
)
|
||||
|
||||
# add to my plugins and edit hub status
|
||||
plugin_entity.installed = plugin_entity.installed + 1
|
||||
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
|
||||
user_name, plugin_name
|
||||
)
|
||||
if my_plugin_entity is None:
|
||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||
my_plugin_entity.file_name = file_name
|
||||
if user_name:
|
||||
# TODO use user
|
||||
my_plugin_entity.user_code = user_name
|
||||
my_plugin_entity.user_name = user_name
|
||||
my_plugin_entity.tenant = ""
|
||||
else:
|
||||
my_plugin_entity.user_code = Default_User
|
||||
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
if my_plugin_entity.id is None:
|
||||
session.add(my_plugin_entity)
|
||||
else:
|
||||
session.merge(my_plugin_entity)
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
logger.error("install merge roll back!" + str(e))
|
||||
session.rollback()
|
||||
except Exception as e:
|
||||
logger.error("install pluguin exception!", e)
|
||||
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||
|
||||
def uninstall_plugin(self, plugin_name, user):
|
||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
|
||||
if plugin_entity is not None:
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(
|
||||
MyPluginEntity.name == plugin_name
|
||||
)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
if plugin_entity is not None:
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
|
||||
if plugin_entity is not None:
|
||||
# delete package file if not use
|
||||
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||
have_installed = False
|
||||
for plugin_info in plugin_infos:
|
||||
if plugin_info.installed > 0:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = (
|
||||
plugin_entity.storage_url.replace(".git", "")
|
||||
.strip("/")
|
||||
.split("/")[-1]
|
||||
)
|
||||
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
else:
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
|
||||
def __download_from_git(self, github_repo, branch_name, authorization):
|
||||
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
|
||||
|
||||
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
||||
my_plugin_entity = MyPluginEntity()
|
||||
my_plugin_entity.name = hub_plugin.name
|
||||
my_plugin_entity.type = hub_plugin.type
|
||||
my_plugin_entity.version = hub_plugin.version
|
||||
return my_plugin_entity
|
||||
|
||||
def refresh_hub_from_git(
|
||||
self,
|
||||
github_repo: str = None,
|
||||
branch_name: str = "main",
|
||||
authorization: str = None,
|
||||
):
|
||||
logger.info("refresh_hub_by_git start!")
|
||||
update_from_git(
|
||||
self.temp_hub_file_path, github_repo, branch_name, authorization
|
||||
)
|
||||
git_plugins = scan_plugins(self.temp_hub_file_path)
|
||||
try:
|
||||
for git_plugin in git_plugins:
|
||||
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
|
||||
if old_hub_info:
|
||||
plugin_hub_info = old_hub_info
|
||||
else:
|
||||
plugin_hub_info = PluginHubEntity()
|
||||
plugin_hub_info.type = ""
|
||||
plugin_hub_info.storage_channel = PluginStorageType.Git.value
|
||||
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
|
||||
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
|
||||
plugin_hub_info.email = getattr(git_plugin, "_email", "")
|
||||
download_param = {}
|
||||
if branch_name:
|
||||
download_param["branch_name"] = branch_name
|
||||
if authorization and len(authorization) > 0:
|
||||
download_param["authorization"] = authorization
|
||||
plugin_hub_info.download_param = json.dumps(download_param)
|
||||
plugin_hub_info.installed = 0
|
||||
|
||||
plugin_hub_info.name = git_plugin._name
|
||||
plugin_hub_info.version = git_plugin._version
|
||||
plugin_hub_info.description = git_plugin._description
|
||||
self.hub_dao.update(plugin_hub_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
|
||||
|
||||
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
|
||||
# We can not move temp file in windows system when we open file in context of `with`
|
||||
file_path = os.path.join(self.plugin_dir, doc_file.filename)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(self.plugin_dir, doc_file.filename),
|
||||
)
|
||||
|
||||
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
|
||||
|
||||
if user is None or len(user) <= 0:
|
||||
user = Default_User
|
||||
|
||||
for my_plugin in my_plugins:
|
||||
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(
|
||||
user, my_plugin._name
|
||||
)
|
||||
if my_plugin_entiy is None:
|
||||
my_plugin_entiy = MyPluginEntity()
|
||||
my_plugin_entiy.name = my_plugin._name
|
||||
my_plugin_entiy.version = my_plugin._version
|
||||
my_plugin_entiy.type = "Personal"
|
||||
my_plugin_entiy.user_code = user
|
||||
my_plugin_entiy.user_name = user
|
||||
my_plugin_entiy.tenant = ""
|
||||
my_plugin_entiy.file_name = doc_file.filename
|
||||
self.my_plugin_dao.update(my_plugin_entiy)
|
||||
|
||||
def reload_my_plugins(self):
|
||||
logger.info(f"load_plugins start!")
|
||||
return scan_plugins(self.plugin_dir)
|
||||
|
||||
def get_my_plugin(self, user: str):
|
||||
logger.info(f"get_my_plugin:{user}")
|
||||
if not user:
|
||||
user = Default_User
|
||||
return self.my_plugin_dao.get_by_user(user)
|
@@ -1,69 +0,0 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any

T = TypeVar("T")


class PagenationFilter(BaseModel, Generic[T]):
    page_index: int = 1
    page_size: int = 20
    filter: T = None


class PagenationResult(BaseModel, Generic[T]):
    page_index: int = 1
    page_size: int = 20
    total_page: int = 0
    total_row_count: int = 0
    datas: List[T] = []

    def to_dic(self):
        data_dicts = []
        for item in self.datas:
            data_dicts.append(item.__dict__)
        return {
            "page_index": self.page_index,
            "page_size": self.page_size,
            "total_page": self.total_page,
            "total_row_count": self.total_row_count,
            "datas": data_dicts,
        }


@dataclass
class PluginHubFilter(BaseModel):
    name: str
    description: str
    author: str
    email: str
    type: str
    version: str
    storage_channel: str
    storage_url: str


@dataclass
class MyPluginFilter(BaseModel):
    tenant: str
    user_code: str
    user_name: str
    name: str
    file_name: str
    type: str
    version: str


class PluginHubParam(BaseModel):
    channel: Optional[str] = Field("git", description="Plugin storage channel")
    url: Optional[str] = Field(
        "https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
        description="Plugin storage url",
    )
    branch: Optional[str] = Field(
        "main", description="github download branch", nullable=True
    )
    authorization: Optional[str] = Field(
        None, description="github download authorization", nullable=True
    )
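A quick sketch of the dictionary shape `PagenationResult.to_dic` produces; the `_Row` class is a hypothetical stand-in for an ORM entity, and direct attribute assignment is assumed to be accepted by the pydantic version in use.

# --- illustrative example, not part of the deleted file ---
class _Row:  # hypothetical stand-in for an ORM entity
    def __init__(self, name, version):
        self.name = name
        self.version = version

result = PagenationResult()
result.page_index, result.page_size = 1, 20
result.total_page, result.total_row_count = 1, 2
result.datas = [_Row("plugin_a", "0.1.0"), _Row("plugin_b", "0.2.0")]
print(result.to_dic())
# {'page_index': 1, 'page_size': 20, 'total_page': 1, 'total_row_count': 2,
#  'datas': [{'name': 'plugin_a', 'version': '0.1.0'}, {'name': 'plugin_b', 'version': '0.2.0'}]}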
@@ -1,260 +0,0 @@
|
||||
"""加载组件"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import zipfile
|
||||
import fnmatch
|
||||
import requests
|
||||
import git
|
||||
import threading
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
Loader zip plugin file. Native support Auto_gpt_plugin
|
||||
|
||||
Args:
|
||||
zip_path (str): Path to the zipfile.
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of module names found or empty list if none were found.
|
||||
"""
|
||||
result = []
|
||||
with zipfile.ZipFile(zip_path, "r") as zfile:
|
||||
for name in zfile.namelist():
|
||||
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
|
||||
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
|
||||
result.append(name)
|
||||
if len(result) == 0:
|
||||
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||
return result
|
||||
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
"""
|
||||
Write a dictionary to a JSON file.
|
||||
Args:
|
||||
data (dict): Dictionary to write.
|
||||
file_path (str): Path to the file.
|
||||
"""
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=4)
|
||||
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
"""
|
||||
Create a directory if it does not exist.
|
||||
Args:
|
||||
directory_path (str): Path to the directory.
|
||||
Returns:
|
||||
bool: True if the directory was created, else False.
|
||||
"""
|
||||
if not os.path.exists(directory_path):
|
||||
try:
|
||||
os.makedirs(directory_path)
|
||||
logger.debug(f"Created directory: {directory_path}")
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warn(f"Error creating directory {directory_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Directory {directory_path} already exists")
|
||||
return True
|
||||
|
||||
|
||||
def load_native_plugins(cfg: Config):
|
||||
if not cfg.plugins_auto_load:
|
||||
print("not auto load_native_plugins")
|
||||
return
|
||||
|
||||
def load_from_git(cfg: Config):
|
||||
print("async load_native_plugins")
|
||||
branch_name = cfg.plugins_git_branch
|
||||
native_plugin_repo = "DB-GPT-Plugins"
|
||||
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||
try:
|
||||
session = requests.Session()
|
||||
response = session.get(
|
||||
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||
headers={"Authorization": "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"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
files = glob.glob(
|
||||
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
print("save file")
|
||||
cfg.set_plugins(scan_plugins(cfg.debug_mode))
|
||||
else:
|
||||
print("get file failed,response code:", response.status_code)
|
||||
except Exception as e:
|
||||
print("load plugin from git exception!" + str(e))
|
||||
|
||||
t = threading.Thread(target=load_from_git, args=(cfg,))
|
||||
t.start()
|
||||
|
||||
|
||||
def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
logger.info(f"__scan_plugin_file:{file_path},{debug}")
|
||||
loaded_plugins = []
|
||||
if moduleList := inspect_zip_for_modules(str(file_path), debug):
|
||||
for module in moduleList:
|
||||
plugin = Path(file_path)
|
||||
module = Path(module)
|
||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def scan_plugins(
|
||||
plugins_file_path: str, file_name: str = "", debug: bool = False
|
||||
) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
cfg (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
|
||||
loaded_plugins = []
|
||||
# Generic plugins
|
||||
plugins_path = Path(plugins_file_path)
|
||||
if file_name:
|
||||
plugin_path = Path(plugins_path, file_name)
|
||||
loaded_plugins = __scan_plugin_file(plugin_path)
|
||||
else:
|
||||
for plugin_path in plugins_path.glob("*.zip"):
|
||||
loaded_plugins.extend(__scan_plugin_file(plugin_path))
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
for plugin in loaded_plugins:
|
||||
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
"""Check if the plugin is in the allowlist or denylist.
|
||||
|
||||
Args:
|
||||
plugin_name (str): Name of the plugin.
|
||||
cfg (Config): Config object.
|
||||
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
|
||||
if plugin_name in cfg.plugins_denylist:
|
||||
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
|
||||
return False
|
||||
if plugin_name in cfg.plugins_allowlist:
|
||||
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
|
||||
return True
|
||||
ack = input(
|
||||
f"WARNING: Plugin {plugin_name} found. But not in the"
|
||||
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
|
||||
)
|
||||
return ack.lower() == cfg.authorise_key
|
||||
|
||||
|
||||
def update_from_git(
|
||||
download_path: str,
|
||||
github_repo: str = "",
|
||||
branch_name: str = "main",
|
||||
authorization: str = None,
|
||||
):
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
if github_repo:
|
||||
if github_repo.index("github.com") <= 0:
|
||||
raise ValueError("Not a correct Github repository address!" + github_repo)
|
||||
github_repo = github_repo.replace(".git", "")
|
||||
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
|
||||
plugin_repo_name = github_repo.strip("/").split("/")[-1]
|
||||
else:
|
||||
url = (
|
||||
"https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
|
||||
)
|
||||
plugin_repo_name = "DB-GPT-Plugins"
|
||||
try:
|
||||
session = requests.Session()
|
||||
headers = {}
|
||||
if authorization and len(authorization) > 0:
|
||||
headers = {"Authorization": authorization}
|
||||
response = session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(download_path)
|
||||
files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = (
|
||||
f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
|
||||
)
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
return plugin_repo_name
|
||||
else:
|
||||
logger.error("update plugins faild,response code:", response.status_code)
|
||||
raise ValueError("download plugin faild!" + response.status_code)
|
||||
except Exception as e:
|
||||
logger.error("update plugins from git exception!" + str(e))
|
||||
raise ValueError("download plugin exception!", e)
|
||||
|
||||
|
||||
def __fetch_from_git(local_path, git_url):
|
||||
logger.info("fetch plugins from git to local path:{}", local_path)
|
||||
os.makedirs(local_path, exist_ok=True)
|
||||
repo = git.Repo(local_path)
|
||||
if repo.is_repo():
|
||||
repo.remotes.origin.pull()
|
||||
else:
|
||||
git.Repo.clone_from(git_url, local_path)
|
||||
|
||||
# if repo.head.is_valid():
|
||||
# clone succ, fetch plugins info
|
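The download URL that `update_from_git` builds follows the GitHub branch-archive convention; a hedged sketch of the same construction in isolation:

# --- illustrative example, not part of the deleted file ---
def build_archive_url(github_repo: str, branch_name: str = "main") -> tuple[str, str]:
    # Mirrors the logic above: strip a trailing ".git", then point at the branch zip archive.
    repo = github_repo.replace(".git", "")
    url = repo + "/archive/refs/heads/" + branch_name + ".zip"
    repo_name = repo.strip("/").split("/")[-1]
    return url, repo_name

url, name = build_archive_url("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", "main")
# url  -> https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip
# name -> DB-GPT-Plugins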
@@ -1,3 +0,0 @@
flask_sqlalchemy==3.0.5
flask==2.3.2
gitpython==3.1.36
@@ -1,6 +0,0 @@
class ModuleMangeApi:
    def module_name(self):
        pass

    def register(self):
        pass
@@ -1,25 +0,0 @@
from typing import TypeVar, Generic, List, Any
from sqlalchemy.orm import sessionmaker

T = TypeVar("T")


class BaseDao(Generic[T]):
    def __init__(
        self,
        orm_base=None,
        database: str = None,
        db_engine: Any = None,
        session: Any = None,
    ) -> None:
        """BaseDao. If the current database is a file database and
        create_not_exist_table=True, a table that does not exist is created
        automatically."""
        self._orm_base = orm_base
        self._database = database

        self._db_engine = db_engine
        self._session = session

    def get_session(self):
        Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
        session = Session()
        return session
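A minimal sketch of how a concrete DAO might build on BaseDao, assuming a SQLAlchemy engine is already available; the MyEntityDao name and the raw query are illustrative only:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///example.db")  # illustrative engine for the sketch


class MyEntityDao(BaseDao):
    def count_rows(self, table_name: str) -> int:
        session = self.get_session()
        try:
            # Kept to a trivial raw query on purpose; real DAOs would query ORM models.
            return session.execute(text(f"SELECT COUNT(*) FROM {table_name}")).scalar()
        finally:
            session.close()


dao = MyEntityDao(db_engine=engine)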
@@ -1,93 +0,0 @@
import os
import sqlite3
import logging

from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from pilot.configs.config import Config
from urllib.parse import quote_plus as urlquote


logger = logging.getLogger(__name__)
# DB-GPT meta_data database config, currently supports MySQL and SQLite
CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")

os.makedirs(default_db_path, exist_ok=True)

# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)


if CFG.LOCAL_DB_TYPE == "mysql":
    engine_temp = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
    )
    # Check for the MySQL database and create it automatically if it does not exist
    try:
        # Try to connect
        with engine_temp.connect() as conn:
            # TODO We should consider that the production environment does not have permission to execute the DDL
            conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
        print(f"Already connected to '{db_name}'")

    except OperationalError as e:
        # If the connection failed, the dbgpt database could not be created
        logger.error(f"Failed to connect to {db_name}!")

    engine = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
    )
else:
    engine = create_engine(f"sqlite:///{db_path}")


Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()

Base = declarative_base()

# Base.metadata.create_all()

alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)

alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))

os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)

alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")

alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session


def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
    """Initialize and upgrade database metadata

    Args:
        disable_alembic_upgrade (bool): Whether to disable alembic initialization and upgrade of database metadata
    """
    if disable_alembic_upgrade:
        logger.info(
            "disable_alembic_upgrade is true, skipping database metadata initialization and upgrade with alembic"
        )
        return

    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        heads = command.heads(alembic_cfg)
        print("heads:" + str(heads))

        command.revision(alembic_cfg, "dbgpt ddl update", True)
        command.upgrade(alembic_cfg, "head")
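A sketch of how this module is typically driven at startup, assuming the caller decides via configuration whether alembic migrations should run; the flag value shown is illustrative:

# Hypothetical startup hook: run migrations unless explicitly disabled.
disable_alembic_upgrade = False  # would normally come from application config
ddl_init_and_upgrade(disable_alembic_upgrade)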
@@ -1 +0,0 @@
pilot/cache/__init__.py (vendored, 10 lines)
@@ -1,10 +0,0 @@
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from pilot.cache.manager import CacheManager, initialize_cache

__all__ = [
    "LLMCacheKey",
    "LLMCacheValue",
    "LLMCacheClient",
    "CacheManager",
    "initialize_cache",
]
pilot/cache/base.py (vendored, 161 lines)
@@ -1,161 +0,0 @@
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, TypeVar, Generic, Optional, Type, Dict
from dataclasses import dataclass
from enum import Enum

T = TypeVar("T", bound="Serializable")

K = TypeVar("K")
V = TypeVar("V")


class Serializable(ABC):
    @abstractmethod
    def serialize(self) -> bytes:
        """Convert the object into bytes for storage or transmission.

        Returns:
            bytes: The byte array after serialization
        """

    @abstractmethod
    def to_dict(self) -> Dict:
        """Convert the object's state to a dictionary."""

    # @staticmethod
    # @abstractclassmethod
    # def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
    #     """Deserialize a dictionary to a Serializable object."""


class RetrievalPolicy(str, Enum):
    EXACT_MATCH = "exact_match"
    SIMILARITY_MATCH = "similarity_match"


class CachePolicy(str, Enum):
    LRU = "lru"
    FIFO = "fifo"


@dataclass
class CacheConfig:
    retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
    cache_policy: Optional[CachePolicy] = CachePolicy.LRU


class CacheKey(Serializable, ABC, Generic[K]):
    """The key of the cache. Must be hashable and comparable.

    Supported cache keys:
    - The LLM cache key: includes the user prompt and the parameters passed to the LLM.
    - The embedding model cache key: includes the texts to embed and the parameters passed to the embedding model.
    """

    @abstractmethod
    def __hash__(self) -> int:
        """Return the hash value of the key."""

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Check equality with another key."""

    @abstractmethod
    def get_hash_bytes(self) -> bytes:
        """Return the byte array of the hash value."""

    @abstractmethod
    def get_value(self) -> K:
        """Get the underlying value of the cache key.

        Returns:
            K: The real object of the current cache key
        """


class CacheValue(Serializable, ABC, Generic[V]):
    """Cache value abstract class."""

    @abstractmethod
    def get_value(self) -> V:
        """Get the underlying real value."""


class Serializer(ABC):
    """The serializer abstract class for serializing cache keys and values."""

    @abstractmethod
    def serialize(self, obj: Serializable) -> bytes:
        """Serialize a cache object.

        Args:
            obj (Serializable): The object to serialize
        """

    @abstractmethod
    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
        """Deserialize data back into a cache object of the specified type.

        Args:
            data (bytes): The byte array to deserialize
            cls (Type[Serializable]): The type of the current object

        Returns:
            Serializable: The deserialized object
        """


class CacheClient(ABC, Generic[K, V]):
    """The cache client interface."""

    @abstractmethod
    async def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[CacheValue[V]]:
        """Retrieve a value from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to look up in the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[CacheValue[V]]: The value retrieved for the key. If the key does not exist, return None.
        """

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    @abstractmethod
    async def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        """Check if a key exists in the cache.

        Args:
            key (CacheKey[K]): The key to check in the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            bool: True if the key is in the cache, otherwise False
        """

    @abstractmethod
    def new_key(self, **kwargs) -> CacheKey[K]:
        """Create a cache key with params"""

    @abstractmethod
    def new_value(self, **kwargs) -> CacheValue[V]:
        """Create a cache value with params"""
pilot/cache/embedding_cache.py (vendored, 0 lines)
pilot/cache/llm_cache.py (vendored, 148 lines)
@@ -1,148 +0,0 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import json
import hashlib

from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
from pilot.cache.manager import CacheManager
from pilot.cache.storage.base import CacheStorage
from pilot.model.base import ModelType, ModelOutput


@dataclass
class LLMCacheKeyData:
    prompt: str
    model_name: str
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    model_type: Optional[str] = ModelType.HF


CacheOutputType = Union[ModelOutput, List[ModelOutput]]


@dataclass
class LLMCacheValueData:
    output: CacheOutputType
    user: Optional[str] = None
    _is_list: Optional[bool] = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't create a new LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            output_list = []
            for out in output:
                if isinstance(out, dict):
                    out = ModelOutput(**out)
                output_list.append(out)
            output = output_list
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        output = self.output
        is_list = False
        if isinstance(output, list):
            output_list = []
            is_list = True
            for out in output:
                output_list.append(out.to_dict())
            output = output_list
        else:
            output = output.to_dict()
        return {"output": output, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        return self._is_list

    def __str__(self) -> str:
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output (last two items): {self.output[-2:]}"


class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        serialize_bytes = self.serialize()
        return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, LLMCacheKey):
            return False
        return self.config == other.config

    def get_hash_bytes(self) -> bytes:
        serialize_bytes = self.serialize()
        return hashlib.sha256(serialize_bytes).digest()

    def to_dict(self) -> Dict:
        return asdict(self.config)

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheKeyData:
        return self.config


class LLMCacheValue(CacheValue[LLMCacheValueData]):
    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        return self.value.to_dict()

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheValueData:
        return self.value

    def __str__(self) -> str:
        return f"value: {str(self.value)}"


class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    def __init__(self, cache_manager: CacheManager) -> None:
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> Optional[LLMCacheValue]:
        return await self._cache_manager.get(key, LLMCacheValue, cache_config)

    async def set(
        self,
        key: LLMCacheKey,
        value: LLMCacheValue,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return await self.get(key, cache_config) is not None

    def new_key(self, **kwargs) -> LLMCacheKey:
        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)

    def new_value(self, **kwargs) -> LLMCacheValue:
        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
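A minimal usage sketch, assuming a CacheManager instance is already registered (for example via initialize_cache in manager.py below) and that the caller is inside an async context; the model name and output handling are illustrative only:

async def cache_model_output(cache_manager, prompt: str, output) -> None:
    # Illustrative only: keys are built from prompt/model parameters,
    # values wrap one ModelOutput (or a list of them).
    client = LLMCacheClient(cache_manager)
    key = client.new_key(prompt=prompt, model_name="vicuna-13b")  # assumed model name
    if not await client.exists(key):
        value = client.new_value(output=output)
        await client.set(key, value)


async def read_cached_output(cache_manager, prompt: str):
    client = LLMCacheClient(cache_manager)
    key = client.new_key(prompt=prompt, model_name="vicuna-13b")
    cached = await client.get(key)
    return cached.get_value().output if cached else None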
pilot/cache/manager.py (vendored, 126 lines)
@@ -1,126 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor
from pilot.cache.storage.base import CacheStorage, StorageItem
from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    Serializer,
    Serializable,
)
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async

logger = logging.getLogger(__name__)


class CacheManager(BaseComponent, ABC):
    name = ComponentType.MODEL_CACHE_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        """Set cache"""

    @abstractmethod
    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> CacheValue[V]:
        """Get cache with key"""

    @property
    @abstractmethod
    def serializer(self) -> Serializer:
        """Get cache serializer"""


class LocalCacheManager(CacheManager):
    def __init__(
        self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
    ) -> None:
        super().__init__(system_app)
        self._serializer = serializer
        self._storage = storage

    @property
    def executor(self) -> Executor:
        """Return the executor used to submit tasks"""
        self._executor = self.system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        return self._executor

    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        if self._storage.support_async():
            await self._storage.aset(key, value, cache_config)
        else:
            await blocking_func_to_async(
                self.executor, self._storage.set, key, value, cache_config
            )

    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> CacheValue[V]:
        if self._storage.support_async():
            item_bytes = await self._storage.aget(key, cache_config)
        else:
            item_bytes = await blocking_func_to_async(
                self.executor, self._storage.get, key, cache_config
            )
        if not item_bytes:
            return None
        return self._serializer.deserialize(item_bytes.value_data, cls)

    @property
    def serializer(self) -> Serializer:
        return self._serializer


def initialize_cache(
    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
    from pilot.cache.protocal.json_protocal import JsonSerializer
    from pilot.cache.storage.base import MemoryCacheStorage

    cache_storage = None
    if storage_type == "disk":
        try:
            from pilot.cache.storage.disk.disk_storage import DiskCacheStorage

            cache_storage = DiskCacheStorage(
                persist_dir, mem_table_buffer_mb=max_memory_mb
            )
        except ImportError as e:
            logger.warning(
                f"Can't import DiskCacheStorage, falling back to MemoryCacheStorage, import error message: {str(e)}"
            )
            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    else:
        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    system_app.register(
        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
    )
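A sketch of wiring the cache into the application, assuming a SystemApp instance is available and that SystemApp.get_component follows the (component name, expected type) convention used elsewhere in pilot.component; the storage type and sizes are illustrative:

from pilot.component import ComponentType

# Hypothetical bootstrap: register a LocalCacheManager backed by memory storage.
# Passing storage_type="disk" would try the RocksDB-based DiskCacheStorage instead.
initialize_cache(
    system_app,
    storage_type="memory",
    max_memory_mb=256,
    persist_dir="./model_cache",
)
cache_manager = system_app.get_component(
    ComponentType.MODEL_CACHE_MANAGER, CacheManager
)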
pilot/cache/protocal/__init__.py (vendored, 0 lines)
pilot/cache/protocal/json_protocal.py (vendored, 44 lines)
@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Type
import json

from pilot.cache.base import Serializable, Serializer

JSON_ENCODING = "utf-8"


class JsonSerializable(Serializable, ABC):
    @abstractmethod
    def to_dict(self) -> Dict:
        """Return the dict of the current serializable object"""

    def serialize(self) -> bytes:
        """Convert the object into bytes for storage or transmission."""
        return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)


class JsonSerializer(Serializer):
    """A JSON-based serializer for cache keys and values."""

    def serialize(self, obj: Serializable) -> bytes:
        """Serialize a cache object.

        Args:
            obj (Serializable): The object to serialize
        """
        return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)

    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
        """Deserialize data back into a cache object of the specified type.

        Args:
            data (bytes): The byte array to deserialize
            cls (Type[Serializable]): The type of the current object

        Returns:
            Serializable: The deserialized object
        """
        # Convert bytes back to JSON and then to the specified class
        json_data = json.loads(data.decode(JSON_ENCODING))
        # Assume that the cls has an __init__ that accepts a dictionary
        return cls(**json_data)
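A small round-trip sketch with a hypothetical JSON-serializable payload class, just to show the serialize/deserialize contract; ExamplePayload is illustrative only:

class ExamplePayload(JsonSerializable):
    """Illustrative payload; __init__ must accept the keys emitted by to_dict()."""

    def __init__(self, prompt: str, model_name: str):
        self.prompt = prompt
        self.model_name = model_name

    def to_dict(self):
        return {"prompt": self.prompt, "model_name": self.model_name}


serializer = JsonSerializer()
data = serializer.serialize(ExamplePayload("hello", "test-model"))
restored = serializer.deserialize(data, ExamplePayload)
assert restored.prompt == "hello"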
pilot/cache/storage/__init__.py (vendored, 0 lines)
pilot/cache/storage/base.py (vendored, 252 lines)
@@ -1,252 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import msgpack
import logging

from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheClient,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from pilot.utils.memory_utils import _get_object_bytes

logger = logging.getLogger(__name__)


@dataclass
class StorageItem:
    """
    A class representing a storage item.

    This class encapsulates data related to a storage item, such as its length,
    the hash of the key, and the data for both the key and value.

    Parameters:
        length (int): The byte length of the storage item.
        key_hash (bytes): The hash value of the storage item's key.
        key_data (bytes): The data of the storage item's key, represented in bytes.
        value_data (bytes): The data of the storage item's value, also in bytes.
    """

    length: int  # The byte length of the storage item
    key_hash: bytes  # The hash value of the storage item's key
    key_data: bytes  # The data of the storage item's key
    value_data: bytes  # The data of the storage item's value

    @staticmethod
    def build_from(
        key_hash: bytes, key_data: bytes, value_data: bytes
    ) -> "StorageItem":
        length = (
            32
            + _get_object_bytes(key_hash)
            + _get_object_bytes(key_data)
            + _get_object_bytes(value_data)
        )
        return StorageItem(
            length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
        )

    @staticmethod
    def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
        key_hash = key.get_hash_bytes()
        key_data = key.serialize()
        value_data = value.serialize()
        return StorageItem.build_from(key_hash, key_data, value_data)

    def serialize(self) -> bytes:
        """Serialize the StorageItem into a byte stream using MessagePack.

        This method packs the object data into a dictionary, marking the
        key_data and value_data fields as raw binary data to avoid re-serialization.

        Returns:
            bytes: The serialized bytes.
        """
        obj = {
            "length": self.length,
            "key_hash": msgpack.ExtType(1, self.key_hash),
            "key_data": msgpack.ExtType(2, self.key_data),
            "value_data": msgpack.ExtType(3, self.value_data),
        }
        return msgpack.packb(obj)

    @staticmethod
    def deserialize(data: bytes) -> "StorageItem":
        """Deserialize bytes back into a StorageItem using MessagePack.

        This extracts the fields from the MessagePack dict back into
        a StorageItem object.

        Args:
            data (bytes): Serialized bytes

        Returns:
            StorageItem: Deserialized StorageItem object.
        """
        obj = msgpack.unpackb(data)
        key_hash = obj["key_hash"].data
        key_data = obj["key_data"].data
        value_data = obj["value_data"].data

        return StorageItem(
            length=obj["length"],
            key_hash=key_hash,
            key_data=key_data,
            value_data=value_data,
        )


class CacheStorage(ABC):
    @abstractmethod
    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        """Check whether the CacheConfig is legal.

        Args:
            cache_config (Optional[CacheConfig]): Cache config.
            raise_error (Optional[bool]): Whether to raise an error if the config is illegal.

        Returns:
            bool: True if the config is legal.

        Raises:
            ValueError: If raise_error is True and the config is illegal.
        """

    def support_async(self) -> bool:
        return False

    @abstractmethod
    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to look up in the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved for the key. If the key does not exist, return None.
        """

    async def aget(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to look up in the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved for the key. If the key does not exist, return None.
        """
        raise NotImplementedError

    @abstractmethod
    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    async def aset(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """
        raise NotImplementedError


class MemoryCacheStorage(CacheStorage):
    def __init__(self, max_memory_mb: int = 256):
        self.cache = OrderedDict()
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory_usage = 0

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)
        # Exact match retrieval
        key_hash = hash(key)
        item: StorageItem = self.cache.get(key_hash)
        logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")

        if not item:
            return None
        # Move the item to the end of the OrderedDict to signify recent use.
        self.cache.move_to_end(key_hash)
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        key_hash = hash(key)
        item = StorageItem.build_from_kv(key, value)
        # Calculate the memory size of the new entry
        new_entry_size = _get_object_bytes(item)
        # Evict entries if necessary
        while self.current_memory_usage + new_entry_size > self.max_memory:
            self._apply_cache_policy(cache_config)

        # Store the item in the cache.
        self.cache[key_hash] = item
        self.current_memory_usage += new_entry_size
        logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")

    def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return self.get(key, cache_config) is not None

    def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
        # Remove the oldest/newest item based on the cache policy.
        if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
            self.cache.popitem(last=False)
        else:  # Default is LRU
            self.cache.popitem(last=True)
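A minimal sketch of the in-memory storage contract using stand-in key/value objects (real callers pass LLMCacheKey/LLMCacheValue instances); the mock classes below are illustrative only:

class _MockKey:
    def __hash__(self):
        return 42

    def get_hash_bytes(self):
        return b"mock-key-hash"

    def serialize(self):
        return b"mock-key-bytes"


class _MockValue:
    def serialize(self):
        return b"mock-value-bytes"


storage = MemoryCacheStorage(max_memory_mb=16)
storage.set(_MockKey(), _MockValue())
item = storage.get(_MockKey())  # same hash, so this is an exact-match hit
assert item is not None and item.value_data == b"mock-value-bytes"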
pilot/cache/storage/disk/__init__.py (vendored, 0 lines)
pilot/cache/storage/disk/disk_storage.py (vendored, 93 lines)
@@ -1,93 +0,0 @@
from typing import Optional
import logging
from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from pilot.cache.storage.base import StorageItem, CacheStorage
from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions


logger = logging.getLogger(__name__)


def db_options(
    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
    opt = Options()
    # create table
    opt.create_if_missing(True)
    # config to more jobs, default 2
    opt.set_max_background_jobs(background_threads)
    # configure mem-table to a large value
    opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
    # opt.set_write_buffer_size(1024)
    # opt.set_level_zero_file_num_compaction_trigger(4)
    # configure l0 and l1 size, let them have the same size (1 GB)
    # opt.set_max_bytes_for_level_base(0x40000000)
    # 256 MB file size
    # opt.set_target_file_size_base(0x10000000)
    # use a smaller compaction multiplier
    # opt.set_max_bytes_for_level_multiplier(4.0)
    # use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
    # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
    # set to plain-table
    # opt.set_plain_table_factory(PlainTableFactoryOptions())
    return opt


class DiskCacheStorage(CacheStorage):
    def __init__(
        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
    ) -> None:
        super().__init__()
        self.db: Rdict = Rdict(
            persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
        )

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)

        # Exact match retrieval
        key_hash = key.get_hash_bytes()
        item_bytes = self.db.get(key_hash)
        if not item_bytes:
            return None
        item = StorageItem.deserialize(item_bytes)
        logger.debug(f"Read file cache, key: {key}, storage item: {item}")
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        item = StorageItem.build_from_kv(key, value)
        key_hash = item.key_hash
        self.db[key_hash] = item.serialize()
        logger.debug(f"Save file cache, key: {key}, value: {value}")
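A sketch of the disk-backed variant, assuming the optional rocksdict dependency is installed; it follows the same contract as MemoryCacheStorage but persists entries under persist_dir:

# Illustrative only: reuse the mock key/value pattern from the memory-storage
# sketch above; real callers pass LLMCacheKey/LLMCacheValue instances.
disk_storage = DiskCacheStorage("./model_cache", mem_table_buffer_mb=64)
disk_storage.set(_MockKey(), _MockValue())
cached_item = disk_storage.get(_MockKey())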
pilot/cache/storage/tests/__init__.py (vendored, 0 lines)
pilot/cache/storage/tests/test_storage.py (vendored, 53 lines)
@@ -1,53 +0,0 @@
import pytest
from ..base import StorageItem
from pilot.utils.memory_utils import _get_object_bytes


def test_build_from():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    assert item.key_hash == key_hash
    assert item.key_data == key_data
    assert item.value_data == value_data
    assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
        key_data
    ) + _get_object_bytes(value_data)


def test_build_from_kv():
    class MockCacheKey:
        def get_hash_bytes(self):
            return b"key_hash"

        def serialize(self):
            return b"key_data"

    class MockCacheValue:
        def serialize(self):
            return b"value_data"

    key = MockCacheKey()
    value = MockCacheValue()
    item = StorageItem.build_from_kv(key, value)

    assert item.key_hash == key.get_hash_bytes()
    assert item.key_data == key.serialize()
    assert item.value_data == value.serialize()


def test_serialize_deserialize():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    serialized = item.serialize()
    deserialized = StorageItem.deserialize(serialized)

    assert deserialized.key_hash == item.key_hash
    assert deserialized.key_data == item.key_data
    assert deserialized.value_data == item.value_data
    assert deserialized.length == item.length
@@ -1,47 +0,0 @@
import asyncio
from typing import Coroutine, List, Any

from starlette.responses import StreamingResponse

from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory

chat_factory = ChatFactory()


async def llm_chat_response_nostream(chat_scene: str, **chat_param):
    """llm_chat_response_nostream"""
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    res = await chat.get_llm_response()
    return res


async def llm_chat_response(chat_scene: str, **chat_param):
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return chat.stream_call()


async def run_async_tasks(
    tasks: List[Coroutine],
    concurrency_limit: int = None,
) -> List[Any]:
    """Run a list of async tasks."""
    tasks_to_execute: List[Any] = tasks

    async def _gather() -> List[Any]:
        if concurrency_limit:
            semaphore = asyncio.Semaphore(concurrency_limit)

            async def _execute_task(task):
                async with semaphore:
                    return await task

            # Execute tasks with the semaphore limit
            return await asyncio.gather(
                *[_execute_task(task) for task in tasks_to_execute]
            )
        else:
            return await asyncio.gather(*tasks_to_execute)

    # outputs: List[Any] = asyncio.run(_gather())
    return await _gather()
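A small usage sketch for run_async_tasks with a throwaway coroutine; the concurrency limit caps how many coroutines run at once:

import asyncio


async def _fetch(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for real I/O
    return i * 2


async def main():
    results = await run_async_tasks([_fetch(i) for i in range(10)], concurrency_limit=3)
    print(results)  # [0, 2, 4, ...]


asyncio.run(main())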
@@ -1,34 +0,0 @@
from collections import OrderedDict
from collections import deque


class FixedSizeDict(OrderedDict):
    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


class FixedSizeList:
    def __init__(self, max_size):
        self.max_size = max_size
        self.list = deque(maxlen=max_size)

    def append(self, value):
        self.list.append(value)

    def __getitem__(self, index):
        return self.list[index]

    def __setitem__(self, index, value):
        self.list[index] = value

    def __len__(self):
        return len(self.list)

    def __str__(self):
        return str(list(self.list))
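A quick sketch of the two bounded containers; once max_size is reached, FixedSizeDict drops its oldest entry and FixedSizeList silently discards the oldest appended item:

d = FixedSizeDict(max_size=2)
d["a"] = 1
d["b"] = 2
d["c"] = 3             # evicts "a"
print(list(d.keys()))  # ['b', 'c']

lst = FixedSizeList(max_size=2)
for v in (1, 2, 3):
    lst.append(v)      # deque(maxlen=2) keeps only the last two
print(lst)             # [2, 3]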
@@ -1,61 +0,0 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """A subclass of Formatter that checks for extra keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Check to see if extra parameters are passed."""
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Check that no positional arguments are provided."""
        if len(args) > 0:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)

    def validate_input_variables(
        self, format_string: str, input_variables: List[str]
    ) -> None:
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        super().format(format_string, **dummy_inputs)


class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Do not check unused args"""
        pass


formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()


class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        elif hasattr(obj, "__dict__"):
            return obj.__dict__
        else:
            return json.JSONEncoder.default(self, obj)
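A short sketch contrasting the two module-level formatter instances; StrictFormatter rejects unused keyword arguments while NoStrictFormatter ignores them:

template = "Hello {name}"

print(formatter.format(template, name="DB-GPT"))                      # OK
print(no_strict_formatter.format(template, name="DB-GPT", extra=1))   # extra key ignored

try:
    formatter.format(template, name="DB-GPT", extra=1)
except KeyError as e:
    print("unused keys rejected:", e)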
@@ -1,448 +0,0 @@
|
||||
"""General utils functions."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial, wraps
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
||||
class GlobalsHelper:
|
||||
"""Helper to retrieve globals.
|
||||
|
||||
Helpful for global caching of certain variables that can be expensive to load.
|
||||
(e.g. tokenization)
|
||||
|
||||
"""
|
||||
|
||||
_tokenizer: Optional[Callable[[str], List]] = None
|
||||
_stopwords: Optional[List[str]] = None
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> Callable[[str], List]:
|
||||
"""Get tokenizer."""
|
||||
if self._tokenizer is None:
|
||||
tiktoken_import_err = (
|
||||
"`tiktoken` package not found, please run `pip install tiktoken`"
|
||||
)
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(tiktoken_import_err)
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
self._tokenizer = cast(Callable[[str], List], enc.encode)
|
||||
self._tokenizer = partial(self._tokenizer, allowed_special="all")
|
||||
return self._tokenizer # type: ignore
|
||||
|
||||
@property
|
||||
def stopwords(self) -> List[str]:
|
||||
"""Get stopwords."""
|
||||
if self._stopwords is None:
|
||||
try:
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`nltk` package not found, please run `pip install nltk`"
|
||||
)
|
||||
|
||||
from llama_index.utils import get_cache_dir
|
||||
|
||||
cache_dir = get_cache_dir()
|
||||
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
|
||||
|
||||
# update nltk path for nltk so that it finds the data
|
||||
if nltk_data_dir not in nltk.data.path:
|
||||
nltk.data.path.append(nltk_data_dir)
|
||||
|
||||
try:
|
||||
nltk.data.find("corpora/stopwords")
|
||||
except LookupError:
|
||||
nltk.download("stopwords", download_dir=nltk_data_dir)
|
||||
self._stopwords = stopwords.words("english")
|
||||
return self._stopwords
|
||||
|
||||
|
||||
globals_helper = GlobalsHelper()
|
||||
|
||||
|
||||
def get_new_id(d: Set) -> str:
|
||||
"""Get a new ID."""
|
||||
while True:
|
||||
new_id = str(uuid.uuid4())
|
||||
if new_id not in d:
|
||||
break
|
||||
return new_id
|
||||
|
||||
|
||||
def get_new_int_id(d: Set) -> int:
|
||||
"""Get a new integer ID."""
|
||||
while True:
|
||||
new_id = random.randint(0, sys.maxsize)
|
||||
if new_id not in d:
|
||||
break
|
||||
return new_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
|
||||
"""Temporary setter.
|
||||
|
||||
Utility class for setting a temporary value for an attribute on a class.
|
||||
Taken from: https://tinyurl.com/2p89xymh
|
||||
|
||||
"""
|
||||
prev_values = {k: getattr(obj, k) for k in kwargs}
|
||||
for k, v in kwargs.items():
|
||||
setattr(obj, k, v)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for k, v in prev_values.items():
|
||||
setattr(obj, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorToRetry:
|
||||
"""Exception types that should be retried.
|
||||
|
||||
Args:
|
||||
exception_cls (Type[Exception]): Class of exception.
|
||||
check_fn (Optional[Callable[[Any]], bool]]):
|
||||
A function that takes an exception instance as input and returns
|
||||
whether to retry.
|
||||
|
||||
"""
|
||||
|
||||
exception_cls: Type[Exception]
|
||||
check_fn: Optional[Callable[[Any], bool]] = None
|
||||
|
||||
|
||||
def retry_on_exceptions_with_backoff(
|
||||
lambda_fn: Callable,
|
||||
errors_to_retry: List[ErrorToRetry],
|
||||
max_tries: int = 10,
|
||||
min_backoff_secs: float = 0.5,
|
||||
max_backoff_secs: float = 60.0,
|
||||
) -> Any:
|
||||
"""Execute lambda function with retries and exponential backoff.
|
||||
|
||||
Args:
|
||||
lambda_fn (Callable): Function to be called and output we want.
|
||||
errors_to_retry (List[ErrorToRetry]): List of errors to retry.
|
||||
At least one needs to be provided.
|
||||
max_tries (int): Maximum number of tries, including the first. Defaults to 10.
|
||||
min_backoff_secs (float): Minimum amount of backoff time between attempts.
|
||||
Defaults to 0.5.
|
||||
max_backoff_secs (float): Maximum amount of backoff time between attempts.
|
||||
Defaults to 60.
|
||||
|
||||
"""
|
||||
if not errors_to_retry:
|
||||
raise ValueError("At least one error to retry needs to be provided")
|
||||
|
||||
error_checks = {
|
||||
error_to_retry.exception_cls: error_to_retry.check_fn
|
||||
for error_to_retry in errors_to_retry
|
||||
}
|
||||
exception_class_tuples = tuple(error_checks.keys())
|
||||
|
||||
backoff_secs = min_backoff_secs
|
||||
tries = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
return lambda_fn()
|
||||
except exception_class_tuples as e:
|
||||
traceback.print_exc()
|
||||
tries += 1
|
||||
if tries >= max_tries:
|
||||
raise
|
||||
check_fn = error_checks.get(e.__class__)
|
||||
if check_fn and not check_fn(e):
|
||||
raise
|
||||
time.sleep(backoff_secs)
|
||||
backoff_secs = min(backoff_secs * 2, max_backoff_secs)
|
||||
|
||||
|
||||
def truncate_text(text: str, max_length: int) -> str:
|
||||
"""Truncate text to a maximum length."""
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
|
||||
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
|
||||
"""Iterate over an iterable in batches.
|
||||
|
||||
>>> list(iter_batch([1,2,3,4,5], 3))
|
||||
[[1, 2, 3], [4, 5]]
|
||||
"""
|
||||
source_iter = iter(iterable)
|
||||
while source_iter:
|
||||
b = list(islice(source_iter, size))
|
||||
if len(b) == 0:
|
||||
break
|
||||
yield b
|
||||
|
||||
|
||||
def concat_dirs(dirname: str, basename: str) -> str:
|
||||
"""
|
||||
Append basename to dirname, avoiding backslashes when running on windows.
|
||||
|
||||
os.path.join(dirname, basename) will add a backslash before dirname if
|
||||
basename does not end with a slash, so we make sure it does.
|
||||
"""
|
||||
dirname += "/" if dirname[-1] != "/" else ""
|
||||
return os.path.join(dirname, basename)
|
||||
|
||||
|
||||
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
|
||||
"""
|
||||
Optionally get a tqdm iterable. Ensures tqdm.auto is used.
|
||||
"""
|
||||
_iterator = items
|
||||
if show_progress:
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
return tqdm(items, desc=desc)
|
||||
except ImportError:
|
||||
pass
|
||||
return _iterator
|
||||
|
||||
|
||||
def count_tokens(text: str) -> int:
|
||||
tokens = globals_helper.tokenizer(text)
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
|
||||
"""
|
||||
Args:
|
||||
model_name(str): the model name of the tokenizer.
|
||||
For instance, fxmarty/tiny-llama-fast-tokenizer.
|
||||
"""
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`transformers` package not found, please run `pip install transformers`"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
return tokenizer.tokenize
|
||||
|
||||
|
||||
def get_cache_dir() -> str:
|
||||
"""Locate a platform-appropriate cache directory for llama_index,
|
||||
and create it if it doesn't yet exist.
|
||||
"""
|
||||
# User override
|
||||
if "LLAMA_INDEX_CACHE_DIR" in os.environ:
|
||||
path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])
|
||||
|
||||
# Linux, Unix, AIX, etc.
|
||||
elif os.name == "posix" and sys.platform != "darwin":
|
||||
path = Path("/tmp/llama_index")
|
||||
|
||||
# Mac OS
|
||||
elif sys.platform == "darwin":
|
||||
path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
|
||||
|
||||
# Windows (hopefully)
|
||||
else:
|
||||
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
|
||||
"~\\AppData\\Local"
|
||||
)
|
||||
path = Path(local, "llama_index")
|
||||
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(
|
||||
path, exist_ok=True
|
||||
) # prevents https://github.com/jerryjliu/llama_index/issues/7362
|
||||
return str(path)
|
||||
|
||||
|
||||
def add_sync_version(func: Any) -> Any:
|
||||
"""Decorator for adding sync version of an async function. The sync version
|
||||
is added as a function attribute to the original function, func.
|
||||
|
||||
Args:
|
||||
func(Any): the async function for which a sync variant will be built.
|
||||
"""
|
||||
assert asyncio.iscoroutinefunction(func)
|
||||
|
||||
@wraps(func)
|
||||
def _wrapper(*args: Any, **kwds: Any) -> Any:
|
||||
return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))
|
||||
|
||||
func.sync = _wrapper
|
||||
return func
|
||||
|
||||
|
||||
# Sample text from llama_index's readme
|
||||
SAMPLE_TEXT = """
|
||||
Context
|
||||
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
|
||||
They are pre-trained on large amounts of publicly available data.
|
||||
How do we best augment LLMs with our own private data?
|
||||
We need a comprehensive toolkit to help perform this data augmentation for LLMs.
|
||||
|
||||
Proposed Solution
|
||||
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
|
||||
you build LLM apps. It provides the following tools:
|
||||
|
||||
Offers data connectors to ingest your existing data sources and data formats
|
||||
(APIs, PDFs, docs, SQL, etc.)
|
||||
Provides ways to structure your data (indices, graphs) so that this data can be
|
||||
easily used with LLMs.
|
||||
Provides an advanced retrieval/query interface over your data:
|
||||
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
|
||||
Allows easy integrations with your outer application framework
|
||||
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
|
||||
LlamaIndex provides tools for both beginner users and advanced users.
|
||||
Our high-level API allows beginner users to use LlamaIndex to ingest and
|
||||
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
|
||||
customize and extend any module (data connectors, indices, retrievers, query engines,
|
||||
reranking modules), to fit their needs.
|
||||
"""
|
||||
|
||||
_LLAMA_INDEX_COLORS = {
|
||||
"llama_pink": "38;2;237;90;200",
|
||||
"llama_blue": "38;2;90;149;237",
|
||||
"llama_turquoise": "38;2;11;159;203",
|
||||
"llama_lavender": "38;2;155;135;227",
|
||||
}
|
||||
|
||||
_ANSI_COLORS = {
|
||||
"red": "31",
|
||||
"green": "32",
|
||||
"yellow": "33",
|
||||
"blue": "34",
|
||||
"magenta": "35",
|
||||
"cyan": "36",
|
||||
"pink": "38;5;200",
|
||||
}
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], use_llama_index_colors: bool = True
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Get a mapping of items to colors.
|
||||
|
||||
Args:
|
||||
items (List[str]): List of items to be mapped to colors.
|
||||
use_llama_index_colors (bool, optional): Flag to indicate
|
||||
whether to use LlamaIndex colors or ANSI colors.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Mapping of items to colors.
|
||||
"""
|
||||
if use_llama_index_colors:
|
||||
color_palette = _LLAMA_INDEX_COLORS
|
||||
else:
|
||||
color_palette = _ANSI_COLORS
|
||||
|
||||
colors = list(color_palette.keys())
|
||||
return {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
|
||||
|
||||
def _get_colored_text(text: str, color: str) -> str:
|
||||
"""
|
||||
Get the colored version of the input text.
|
||||
|
||||
Args:
|
||||
text (str): Input text.
|
||||
color (str): Color to be applied to the text.
|
||||
|
||||
Returns:
|
||||
str: Colored version of the input text.
|
||||
"""
|
||||
all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
|
||||
|
||||
if color not in all_colors:
|
||||
return f"\033[1;3m{text}\033[0m" # just bolded and italicized
|
||||
|
||||
color = all_colors[color]
|
||||
|
||||
return f"\033[1;3;{color}m{text}\033[0m"
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||
"""
|
||||
Print the text with the specified color.
|
||||
|
||||
Args:
|
||||
text (str): Text to be printed.
|
||||
color (str, optional): Color to be applied to the text. Supported colors are:
|
||||
llama_pink, llama_blue, llama_turquoise, llama_lavender,
|
||||
red, green, yellow, blue, magenta, cyan, pink.
|
||||
end (str, optional): String appended after the last character of the text.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
text_to_print = _get_colored_text(text, color) if color is not None else text
|
||||
print(text_to_print, end=end)
|
||||
|
||||
|
||||
def infer_torch_device() -> str:
|
||||
"""Infer the input to torch.device."""
|
||||
try:
|
||||
has_cuda = torch.cuda.is_available()
|
||||
except NameError:
|
||||
import torch
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
|
||||
def unit_generator(x: Any) -> Generator[Any, None, None]:
|
||||
"""A function that returns a generator of a single element.
|
||||
|
||||
Args:
|
||||
x (Any): the element to build yield
|
||||
|
||||
Yields:
|
||||
Any: the single element
|
||||
"""
|
||||
yield x
|
||||
|
||||
|
||||
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
|
||||
"""A function that returns a generator of a single element.
|
||||
|
||||
Args:
|
||||
x (Any): the element to build yield
|
||||
|
||||
Yields:
|
||||
Any: the single element
|
||||
"""
|
||||
yield x
|
@@ -1,7 +0,0 @@
import json
from datetime import date


def serialize(obj):
    if isinstance(obj, date):
        return obj.isoformat()
@@ -1,43 +0,0 @@
from pydantic import Field, BaseModel

DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256


class LLMMetadata(BaseModel):
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can take as input and output for one response."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name used for logging, testing, and sanity checking. For some"
            " models this can be automatically discerned. For other models, like"
            " locally loaded models, this must be manually specified."
        ),
    )
@@ -1,58 +0,0 @@
import markdown2


def datas_to_table_html(data):
    import pandas as pd

    df = pd.DataFrame(data[1:], columns=data[0])
    table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
    html_table = df.to_html(index=False, escape=False)

    html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"

    return html.replace("\n", " ")


def generate_markdown_table(data):
    """Generate a Markdown table.

    data: a two-dimensional list containing the header row and the table content.
    """
    # Number of columns in the table
    num_cols = len(data[0])

    # Build the header row
    header = "| "
    for i in range(num_cols):
        header += data[0][i] + " | "

    # Build the separator row
    separator = "| "
    for i in range(num_cols):
        separator += "--- | "

    # Build the table body
    content = ""
    for row in data[1:]:
        content += "| "
        for i in range(num_cols):
            content += str(row[i]) + " | "
        content += "\n"

    # Combine the header, separator and body
    table = header + "\n" + separator + "\n" + content

    return table


def generate_htm_table(data):
    markdown_text = generate_markdown_table(data)
    html_table = markdown2.markdown(markdown_text, extras=["tables"])
    return html_table


if __name__ == "__main__":
    # mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
    # print(generate_htm_table(mk_text))

    table_style = """<style>\n    table {\n    border-collapse: collapse;\n    width: 100%;\n    }\n    th, td {\n    border: 1px solid #ddd;\n    padding: 8px;\n    text-align: center;\n    line-height: 150px; \n    }\n    th {\n    background-color: #f2f2f2;\n    color: #333;\n    font-weight: bold;\n    }\n    tr:nth-child(even) {\n    background-color: #f9f9f9;\n    }\n    tr:hover {\n    background-color: #f2f2f2;\n    }\n    </style>"""

    print(table_style.replace("\n", " "))
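For context, a minimal usage sketch of the table helpers removed above; the sample rows are hypothetical and not part of the diff:

# Illustrative sketch only; the data below is made up.
rows = [
    ["user_name", "city"],      # header row
    ["zhangsan", "Shanghai"],
    ["hanmeimei", "Chengdu"],
]
# Render the same data either directly as an HTML <table> or via Markdown conversion.
print(datas_to_table_html(rows))
print(generate_htm_table(rows))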
@@ -1,6 +0,0 @@
import os


def has_path(filename):
    directory = os.path.dirname(filename)
    return bool(directory)
@@ -1,6 +0,0 @@
def csv_colunm_foramt(val):
    if str(val).find("$") >= 0:
        return float(val.replace("$", "").replace(",", ""))
    if str(val).find("¥") >= 0:
        return float(val.replace("¥", "").replace(",", ""))
    return val
@@ -1,239 +0,0 @@
"""General prompt helper that can help deal with LLM context window token limitations.

At its core, it calculates the available context size by starting with the context
window size of an LLM and reserving token space for the prompt template and the output.

It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""

import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence

from pydantic import Field, PrivateAttr, BaseModel

from pilot.common.global_helper import globals_helper
from pilot.common.llm_metadata import LLMMetadata
from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1

DEFAULT_CONTEXT_WINDOW = 3000  # tokens
DEFAULT_NUM_OUTPUTS = 256  # tokens

logger = logging.getLogger(__name__)


class PromptHelper(BaseModel):
    """Prompt helper.

    General prompt helper that can help deal with LLM context window token limitations.

    At its core, it calculates the available context size by starting with the context
    window size of an LLM and reserving token space for the prompt template and the
    output.

    It provides utility for "repacking" text chunks (retrieved from index) to maximally
    make use of the available context window (and thereby reducing the number of LLM
    calls needed), or truncating them so that they fit in a single LLM call.

    Args:
        context_window (int): Context window for the LLM.
        num_output (int): Number of outputs for the LLM.
        chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size.
        chunk_size_limit (Optional[int]): Maximum chunk size to use.
        tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
        separator (str): Separator for text splitter.

    """

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum context size that will get sent to the LLM.",
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The amount of token-space to leave in input for generation.",
    )
    chunk_overlap_ratio: float = Field(
        default=DEFAULT_CHUNK_OVERLAP_RATIO,
        description="The percentage token amount that each chunk should overlap.",
    )
    chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
    separator: str = Field(
        default=" ", description="The separator when chunking tokens."
    )

    _tokenizer: Callable[[str], List] = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        num_output: int = DEFAULT_NUM_OUTPUTS,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> None:
        """Init params."""
        if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
            raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")

        # TODO: make configurable
        self._tokenizer = tokenizer or globals_helper.tokenizer

        super().__init__(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            separator=separator,
        )

    @classmethod
    def from_llm_metadata(
        cls,
        llm_metadata: LLMMetadata,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> "PromptHelper":
        """Create from llm predictor.

        This will autofill values like context_window and num_output.

        """
        context_window = llm_metadata.context_window
        if llm_metadata.num_output == -1:
            num_output = DEFAULT_NUM_OUTPUTS
        else:
            num_output = llm_metadata.num_output

        return cls(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            tokenizer=tokenizer,
            separator=separator,
        )

    @classmethod
    def class_name(cls) -> str:
        return "PromptHelper"

    def _get_available_context_size(self, template: str) -> int:
        """Get available context size.

        This is calculated as:
            available context window = total context window
                - input (partially filled prompt)
                - output (room reserved for response)

        Notes:
        - Available context size is further clamped to be non-negative.
        """
        empty_prompt_txt = get_empty_prompt_txt(template)
        num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
        context_size_tokens = (
            self.context_window - num_empty_prompt_tokens - self.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _get_available_chunk_size(
        self, prompt_template: str, num_chunks: int = 1, padding: int = 5
    ) -> int:
        """Get available chunk size.

        This is calculated as:
            available chunk size = available context window // number_chunks
                - padding

        Notes:
        - By default, we use padding of 5 (to save space for formatting needs).
        - Available chunk size is further clamped to chunk_size_limit if specified.
        """
        available_context_size = self._get_available_context_size(prompt_template)
        result = available_context_size // num_chunks - padding
        if self.chunk_size_limit is not None:
            result = min(result, self.chunk_size_limit)
        return result

    def get_text_splitter_given_prompt(
        self,
        prompt_template: str,
        num_chunks: int = 1,
        padding: int = DEFAULT_PADDING,
    ) -> TokenTextSplitter:
        """Get a text splitter configured to maximally pack the available context window,
        taking into account the given prompt and the desired number of chunks.
        """
        chunk_size = self._get_available_chunk_size(
            prompt_template, num_chunks, padding=padding
        )
        if chunk_size <= 0:
            raise ValueError(f"Chunk size {chunk_size} is not positive.")
        chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
        return TokenTextSplitter(
            separator=self.separator,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=self._tokenizer,
        )

    def repack(
        self,
        prompt_template: str,
        text_chunks: Sequence[str],
        padding: int = DEFAULT_PADDING,
    ) -> List[str]:
        """Repack text chunks to fit the available context window.

        This will combine text chunks into consolidated chunks
        that more fully "pack" the prompt template given the context_window.

        """
        text_splitter = self.get_text_splitter_given_prompt(
            prompt_template, padding=padding
        )
        combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
        return text_splitter.split_text(combined_str)


def get_empty_prompt_txt(template: str) -> str:
    """Get empty prompt text.

    Substitute empty strings in parts of the prompt that have
    not yet been filled out. Skip variables that have already
    been partially formatted. This is used to compute the initial tokens.

    """
    # partial_kargs = prompt.kwargs

    partial_kargs = {}
    template_vars = get_template_vars(template)
    empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
    all_kwargs = {**partial_kargs, **empty_kwargs}
    prompt = template.format(**all_kwargs)
    return prompt


def get_template_vars(template_str: str) -> List[str]:
    """Get template variables from a template string."""
    variables = []
    formatter = Formatter()

    for _, variable_name, _, _ in formatter.parse(template_str):
        if variable_name:
            variables.append(variable_name)

    return variables
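A minimal sketch of how this removed PromptHelper was typically driven, assuming the default tokenizer from globals_helper is available in the environment; the template and chunk strings below are hypothetical:

# Illustrative sketch only; template and chunks are made up.
helper = PromptHelper(context_window=3000, num_output=256)
template = "Answer using the context below.\n{context}\nQuestion: {question}"
chunks = ["retrieved chunk one ...", "retrieved chunk two ...", "retrieved chunk three ..."]
packed = helper.repack(prompt_template=template, text_chunks=chunks)
# Each element of `packed` is a consolidated chunk sized to fit one LLM call.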
@@ -1,56 +0,0 @@
from enum import auto, Enum
import os


class SeparatorStyle(Enum):
    SINGLE = "###"
    TWO = "</s>"
    THREE = auto()
    FOUR = auto()


class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"


class DbInfo:
    def __init__(self, name, is_file_db: bool = False):
        self.name = name
        self.is_file_db = is_file_db


class DBType(Enum):
    Mysql = DbInfo("mysql")
    OCeanBase = DbInfo("oceanbase")
    DuckDb = DbInfo("duckdb", True)
    SQLite = DbInfo("sqlite", True)
    Oracle = DbInfo("oracle")
    MSSQL = DbInfo("mssql")
    Postgresql = DbInfo("postgresql")
    Clickhouse = DbInfo("clickhouse")
    StarRocks = DbInfo("starrocks")
    Spark = DbInfo("spark", True)
    Doris = DbInfo("doris")

    def value(self):
        return self._value_.name

    def is_file_db(self):
        return self._value_.is_file_db

    @staticmethod
    def of_db_type(db_type: str):
        for item in DBType:
            if item.value() == db_type:
                return item
        return None

    @staticmethod
    def parse_file_db_name_from_path(db_type: str, local_db_path: str):
        """Parse out the database name of the embedded database from the file path"""
        base_name = os.path.basename(local_db_path)
        db_name = os.path.splitext(base_name)[0]
        if "." in db_name:
            db_name = os.path.splitext(db_name)[0]
        return db_type + "_" + db_name
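A small sketch of how the DBType helpers resolve a type string and derive a name for file-based databases; the path below is a hypothetical example:

# Illustrative sketch; the path is hypothetical.
db_type = DBType.of_db_type("sqlite")       # -> DBType.SQLite
print(db_type.is_file_db())                 # -> True
print(DBType.parse_file_db_name_from_path("sqlite", "/tmp/default_sqlite.db"))
# -> "sqlite_default_sqlite"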
@@ -1,491 +0,0 @@
from __future__ import annotations
import sqlparse
import regex as re
import warnings
from typing import Any, Iterable, List, Optional
import sqlalchemy
from sqlalchemy import (
    MetaData,
    Table,
    create_engine,
    inspect,
    select,
    text,
)
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session


def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    return (
        f'Name: {index["name"]}, Unique: {index["unique"]},'
        f' Columns: {str(index["column_names"])}'
    )


class Database:
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
    ):
        """Create engine from database URI."""
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)
        session_factory = sessionmaker(bind=engine)
        Session = scoped_session(session_factory)

        self._db_sessions = Session

        self._all_tables = set()
        self.view_support = False
        self._usable_tables = set()
        self._include_tables = set()
        self._ignore_tables = set()
        self._custom_table_info = set()
        self._indexes_in_table_info = set()
        self._usable_tables = set()
        self._usable_tables = set()
        self._sample_rows_in_table_info = set()
        self._indexes_in_table_info = indexes_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> Database:
        """Construct a SQLAlchemy engine from URI."""
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            return self._include_tables
        return self._all_tables - self._ignore_tables

    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        warnings.warn(
            "This method is deprecated - please use `get_usable_table_names`."
        )
        return self.get_usable_table_names()

    def get_session_db(self, connect):
        sql = text(f"select DATABASE()")
        cursor = connect.execute(sql)
        result = cursor.fetchone()[0]
        return result

    def get_session(self, db_name: str):
        session = self._db_sessions()

        self._metadata = MetaData()
        # sql = f"use {db_name}"
        sql = text(f"use `{db_name}`")
        session.execute(sql)

        # Reflect the table metadata for the selected database
        self._metadata.reflect(bind=self._engine, schema=db_name)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            self._inspector.get_table_names(schema=db_name)
            + (
                self._inspector.get_view_names(schema=db_name)
                if self.view_support
                else []
            )
        )

        return session

    def get_current_db_name(self, session) -> str:
        return session.execute(text("SELECT DATABASE()")).scalar()

    def table_simple_info(self, session):
        _sql = f"""
            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
        """
        cursor = session.execute(text(_sql))
        results = cursor.fetchall()
        return results

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in set(all_table_names)
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        final_str = "\n\n".join(tables)
        return final_str

    def _get_sample_rows(self, table: Table) -> str:
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result: CursorResult = connection.execute(command)
                # shorten values in the sample rows
                sample_rows = list(
                    map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
                )

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _get_table_indexes(self, table: Table) -> str:
        indexes = self._inspector.get_indexes(table.name)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables."""
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            """Format the error message"""
            return f"Error: {e}"

    def __write(self, session, write_sql):
        print(f"Write[{write_sql}]")
        db_cache = self.get_session_db(session)
        result = session.execute(text(write_sql))
        session.commit()
        # TODO Subsequent optimization of dynamically specified database submission loss target problem
        session.execute(text(f"use `{db_cache}`"))
        print(f"SQL[{write_sql}], result:{result.rowcount}")
        return result.rowcount

    def __query(self, session, query, fetch: str = "all"):
        """
        only for query
        Args:
            session:
            query:
            fetch:

        Returns:

        """
        print(f"Query[{query}]")
        if not query:
            return []
        cursor = session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
            elif fetch == "one":
                result = cursor.fetchone()[0]  # type: ignore
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")
            field_names = tuple(i[0:] for i in cursor.keys())

            result = list(result)
            result.insert(0, field_names)
            return result

    def query_ex(self, session, query, fetch: str = "all"):
        """
        only for query
        Args:
            session:
            query:
            fetch:
        Returns:
        """
        print(f"Query[{query}]")
        if not query:
            return []
        cursor = session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
            elif fetch == "one":
                result = cursor.fetchone()[0]  # type: ignore
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")
            field_names = list(i[0:] for i in cursor.keys())

            result = list(result)
            return field_names, result

    def run(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results."""
        print("SQL:" + command)
        if not command:
            return []
        parsed, ttype, sql_type = self.__sql_parse(command)
        if ttype == sqlparse.tokens.DML:
            if sql_type == "SELECT":
                return self.__query(session, command, fetch)
            else:
                self.__write(session, command)
                select_sql = self.convert_sql_write_to_select(command)
                print(f"write result query:{select_sql}")
                return self.__query(session, select_sql)

        else:
            print(f"DDL execution determines whether to enable through configuration ")
            cursor = session.execute(text(command))
            session.commit()
            if cursor.returns_rows:
                result = cursor.fetchall()
                field_names = tuple(i[0:] for i in cursor.keys())
                result = list(result)
                result.insert(0, field_names)
                print("DDL Result:" + str(result))
                if not result:
                    return self.__query(session, "SHOW COLUMNS FROM test")
                return result
            else:
                return self.__query(session, "SHOW COLUMNS FROM test")

    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(session, command, fetch)
        except SQLAlchemyError as e:
            """Format the error message"""
            return f"Error: {e}"

    def get_database_list(self):
        session = self._db_sessions()
        cursor = session.execute(text(" show databases;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0]
            not in [
                "information_schema",
                "performance_schema",
                "sys",
                "mysql",
                "knowledge_management",
            ]
        ]

    def convert_sql_write_to_select(self, write_sql):
        """
        SQL classification processing
        author:xiangh8
        Args:
            sql:

        Returns:

        """
        # Lowercase the SQL command and split it on whitespace
        parts = write_sql.lower().split()
        # Get the command type (insert, delete, update)
        cmd_type = parts[0]

        # Handle the statement according to its command type
        if cmd_type == "insert":
            match = re.match(
                r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
            )
            if match:
                table_name, columns, values = match.groups()
                # Split the column list and value list into individual columns and values
                columns = columns.split(",")
                values = values.split(",")
                # Build the WHERE clause
                where_clause = " AND ".join(
                    [
                        f"{col.strip()}={val.strip()}"
                        for col, val in zip(columns, values)
                    ]
                )
                return f"SELECT * FROM {table_name} WHERE {where_clause}"

        elif cmd_type == "delete":
            table_name = parts[2]  # delete from <table_name> ...
            # Return a select statement that reads all data from the table
            return f"SELECT * FROM {table_name} "

        elif cmd_type == "update":
            table_name = parts[1]
            set_idx = parts.index("set")
            where_idx = parts.index("where")
            # Extract the column name from the `set` clause
            set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
            # Extract the condition after `where`
            where_clause = " ".join(parts[where_idx + 1 :])
            # Return a select statement that reads the updated data
            return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
        else:
            raise ValueError(f"Unsupported SQL command type: {cmd_type}")

    def __sql_parse(self, sql):
        sql = sql.strip()
        parsed = sqlparse.parse(sql)[0]
        sql_type = parsed.get_type()

        first_token = parsed.token_first(skip_ws=True, skip_cm=False)
        ttype = first_token.ttype
        print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
        return parsed, ttype, sql_type

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW INDEXES FROM `{table_name}`"))
        indexes = cursor.fetchall()
        return [(index[2], index[4]) for index in indexes]

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW CREATE TABLE `{table_name}`"))
        ans = cursor.fetchall()
        res = ans[0][1]
        res = re.sub(r"\s*ENGINE\s*=\s*InnoDB\s*", " ", res, flags=re.IGNORECASE)
        res = re.sub(
            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", res, flags=re.IGNORECASE
        )
        res = re.sub(r"\s*COLLATE\s*=\s*\w+\s*", " ", res, flags=re.IGNORECASE)
        return res

    def get_fields(self, table_name):
        """Get column fields about specified table."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
                    table_name
                )
            )
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_charset(self):
        """Get character_set."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SELECT @@character_set_database"))
        character_set = cursor.fetchone()[0]
        return character_set

    def get_collation(self):
        """Get collation."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SELECT @@collation_database"))
        collation = cursor.fetchone()[0]
        return collation

    def get_grants(self):
        """Get grant info."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW GRANTS"))
        grants = cursor.fetchall()
        return grants

    def get_users(self):
        """Get user info."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SELECT user, host FROM mysql.user"))
        users = cursor.fetchall()
        return [(user[0], user[1]) for user in users]

    def get_table_comments(self, database):
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{database}'""".format(
                    database
                )
            )
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]
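A rough usage sketch of the removed Database wrapper against a MySQL URI; the connection string and database name below are placeholders, not values from the diff:

# Illustrative sketch; URI and database name are placeholders.
db = Database.from_uri("mysql+pymysql://root:aa123456@127.0.0.1:3306/dbgpt")
session = db.get_session("dbgpt")              # switches to the schema and reflects its tables
print(db.get_table_info())                     # CREATE TABLE statements, plus optional index/sample blocks
rows = db.run(session, "SELECT 1", fetch="all")
# For SELECT statements the first element of `rows` is the tuple of column names.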
@@ -1,81 +0,0 @@
import re


def is_all_chinese(text):
    ### Determine whether the string is pure Chinese
    pattern = re.compile(r"^[一-龥]+$")
    match = re.match(pattern, text)
    return match is not None


def is_number_chinese(text):
    ### Determine whether the string is numbers and Chinese
    pattern = re.compile(r"^[\d一-龥]+$")
    match = re.match(pattern, text)
    return match is not None


def is_chinese_include_number(text):
    ### Determine whether the string is pure Chinese or Chinese containing numbers
    pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
    match = re.match(pattern, text)
    return match is not None


def is_scientific_notation(string):
    # Regular expression for scientific notation
    pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
    # Match the string against the pattern
    match = re.match(pattern, str(string))
    # Return whether the match succeeded
    if match is not None:
        return True
    else:
        return False


def extract_content(long_string, s1, s2, is_include: bool = False):
    # extract text
    match_map = {}
    start_index = long_string.find(s1)
    while start_index != -1:
        if is_include:
            end_index = long_string.find(s2, start_index + len(s1) + 1)
            extracted_content = long_string[start_index : end_index + len(s2)]
        else:
            end_index = long_string.find(s2, start_index + len(s1))
            extracted_content = long_string[start_index + len(s1) : end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map


def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
    # extract text open ending
    match_map = {}
    start_index = long_string.find(s1)
    while start_index != -1:
        if long_string.find(s2, start_index) <= 0:
            end_index = len(long_string)
        else:
            if is_include:
                end_index = long_string.find(s2, start_index + len(s1) + 1)
            else:
                end_index = long_string.find(s2, start_index + len(s1))
        if is_include:
            extracted_content = long_string[start_index : end_index + len(s2)]
        else:
            extracted_content = long_string[start_index + len(s1) : end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map


if __name__ == "__main__":
    s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
    s1 = "123"
    s2 = "456"

    print(extract_content_open_ending(s, s1, s2, True))
@@ -1,221 +0,0 @@
from __future__ import annotations

from abc import ABC, abstractmethod
import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
import asyncio

# Checking for type hints during runtime
if TYPE_CHECKING:
    from fastapi import FastAPI

logger = logging.getLogger(__name__)


class LifeCycle:
    """This class defines hooks for lifecycle events of a component."""

    def before_start(self):
        """Called before the component starts."""
        pass

    async def async_before_start(self):
        """Asynchronous version of before_start."""
        pass

    def after_start(self):
        """Called after the component has started."""
        pass

    async def async_after_start(self):
        """Asynchronous version of after_start."""
        pass

    def before_stop(self):
        """Called before the component stops."""
        pass

    async def async_before_stop(self):
        """Asynchronous version of before_stop."""
        pass


class ComponentType(str, Enum):
    WORKER_MANAGER = "dbgpt_worker_manager"
    WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
    MODEL_CONTROLLER = "dbgpt_model_controller"
    MODEL_REGISTRY = "dbgpt_model_registry"
    MODEL_API_SERVER = "dbgpt_model_api_server"
    MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
    AGENT_HUB = "dbgpt_agent_hub"
    EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
    TRACER = "dbgpt_tracer"
    TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
    RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
    AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
    AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"


class BaseComponent(LifeCycle, ABC):
    """Abstract Base Component class. All custom components should extend this."""

    name = "base_dbgpt_component"

    def __init__(self, system_app: Optional[SystemApp] = None):
        if system_app is not None:
            self.init_app(system_app)

    @abstractmethod
    def init_app(self, system_app: SystemApp):
        """Initialize the component with the main application.

        This method needs to be implemented by every component to define how it
        integrates with the main system app.
        """


T = TypeVar("T", bound=BaseComponent)

_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"


class SystemApp(LifeCycle):
    """Main System Application class that manages the lifecycle and registration of components."""

    def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
        self.components: Dict[
            str, BaseComponent
        ] = {}  # Dictionary to store registered components.
        self._asgi_app = asgi_app

    @property
    def app(self) -> Optional["FastAPI"]:
        """Returns the internal ASGI app."""
        return self._asgi_app

    def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
        """Register a new component by its type.

        Args:
            component (Type[BaseComponent]): The component class to register

        Returns:
            T: The instance of the registered component
        """
        instance = component(self, *args, **kwargs)
        self.register_instance(instance)
        return instance

    def register_instance(self, instance: T) -> T:
        """Register an already initialized component.

        Args:
            instance (T): The component instance to register

        Returns:
            T: The instance of the registered component
        """
        name = instance.name
        if isinstance(name, ComponentType):
            name = name.value
        if name in self.components:
            raise RuntimeError(
                f"Component name {name} already exists: {self.components[name]}"
            )
        logger.info(f"Register component with name {name} and instance: {instance}")
        self.components[name] = instance
        instance.init_app(self)
        return instance

    def get_component(
        self,
        name: Union[str, ComponentType],
        component_type: Type[T],
        default_component=_EMPTY_DEFAULT_COMPONENT,
        or_register_component: Type[BaseComponent] = None,
        *args,
        **kwargs,
    ) -> T:
        """Retrieve a registered component by its name and type.

        Args:
            name (Union[str, ComponentType]): Component name
            component_type (Type[T]): The type of the component to retrieve
            default_component: The default component instance to return if the name is not registered
            or_register_component (Type[BaseComponent]): A component class to register and return if the name is not registered

        Returns:
            T: The instance retrieved by component name
        """
        if isinstance(name, ComponentType):
            name = name.value
        component = self.components.get(name)
        if not component:
            if or_register_component:
                return self.register(or_register_component, *args, **kwargs)
            if default_component != _EMPTY_DEFAULT_COMPONENT:
                return default_component
            raise ValueError(f"No component found with name {name}")
        if not isinstance(component, component_type):
            raise TypeError(f"Component {name} is not of type {component_type}")
        return component

    def before_start(self):
        """Invoke the before_start hooks for all registered components."""
        for _, v in self.components.items():
            v.before_start()

    async def async_before_start(self):
        """Asynchronously invoke the before_start hooks for all registered components."""
        tasks = [v.async_before_start() for _, v in self.components.items()]
        await asyncio.gather(*tasks)

    def after_start(self):
        """Invoke the after_start hooks for all registered components."""
        for _, v in self.components.items():
            v.after_start()

    async def async_after_start(self):
        """Asynchronously invoke the after_start hooks for all registered components."""
        tasks = [v.async_after_start() for _, v in self.components.items()]
        await asyncio.gather(*tasks)

    def before_stop(self):
        """Invoke the before_stop hooks for all registered components."""
        for _, v in self.components.items():
            try:
                v.before_stop()
            except Exception as e:
                pass

    async def async_before_stop(self):
        """Asynchronously invoke the before_stop hooks for all registered components."""
        tasks = [v.async_before_stop() for _, v in self.components.items()]
        await asyncio.gather(*tasks)

    def _build(self):
        """Integrate lifecycle events with the internal ASGI app if available."""
        if not self.app:
            return

        @self.app.on_event("startup")
        async def startup_event():
            """ASGI app startup event handler."""

            async def _startup_func():
                try:
                    await self.async_after_start()
                except Exception as e:
                    logger.error(f"Error starting system app: {e}")
                    sys.exit(1)

            asyncio.create_task(_startup_func())
            self.after_start()

        @self.app.on_event("shutdown")
        async def shutdown_event():
            """ASGI app shutdown event handler."""
            await self.async_before_stop()
            self.before_stop()
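A condensed sketch of the component registration pattern this module provided; MyComponent is a hypothetical subclass used only for illustration:

# Illustrative sketch; MyComponent is hypothetical.
class MyComponent(BaseComponent):
    name = "my_component"

    def init_app(self, system_app: SystemApp):
        # Keep a reference back to the owning application.
        self.system_app = system_app

system_app = SystemApp()
system_app.register(MyComponent)                         # instantiate and register by type
same = system_app.get_component("my_component", MyComponent)
assert isinstance(same, MyComponent)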
@@ -1,16 +0,0 @@
import os
import random
import sys

from dotenv import load_dotenv

if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
    print("Setting random seed to 42")
    random.seed(42)

# Load the user's .env file into environment variables
load_dotenv(verbose=True, override=True)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
load_dotenv(os.path.join(ROOT_PATH, ".plugin_env"))

del load_dotenv
@@ -1,284 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import os
from typing import List, Optional, TYPE_CHECKING

from pilot.singleton import Singleton

if TYPE_CHECKING:
    from auto_gpt_plugin_template import AutoGPTPluginTemplate
    from pilot.component import SystemApp


class Config(metaclass=Singleton):
    """Configuration class to store the state of bools for different scripts access"""

    def __init__(self) -> None:
        """Initialize the Config class"""

        self.NEW_SERVER_MODE = False
        self.SERVER_LIGHT_MODE = False

        # Gradio language version: en, zh
        self.LANGUAGE = os.getenv("LANGUAGE", "en")
        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

        self.debug_mode = False
        self.skip_reprompt = False
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))

        # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))

        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
        )
        # User agent header to use when making HTTP requests
        # Some websites might just completely deny request with an error code if
        # no user agent was found.
        self.user_agent = os.getenv(
            "USER_AGENT",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
        )

        # This is a proxy server, just for test_py. we will remove this later.
        self.proxy_api_key = os.getenv("PROXY_API_KEY")
        self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")

        # In order to be compatible with the new and old model parameter design
        if self.bard_proxy_api_key:
            os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key

        # tongyi
        self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY")
        if self.tongyi_proxy_api_key:
            os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key

        # zhipu
        self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
        if self.zhipu_proxy_api_key:
            os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
            os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
                "ZHIPU_MODEL_VERSION"
            )

        # wenxin
        self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
        self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_API_SECRET")
        self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
        if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
            os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
            os.environ[
                "wenxin_proxyllm_proxy_api_secret"
            ] = self.wenxin_proxy_api_secret
            os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version

        # xunfei spark
        self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
        self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
        self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
        self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
        if self.spark_proxy_api_key and self.spark_proxy_api_secret:
            os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
            os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
            os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
            os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid

        # baichuan proxy
        self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
        self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
        self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME")
        if self.bc_proxy_api_key and self.bc_proxy_api_secret:
            os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
            os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
            os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version

        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

        self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
        self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
        self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")

        self.use_mac_os_tts = False
        self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")

        self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
        self.exit_key = os.getenv("EXIT_KEY", "n")
        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
        self.image_size = int(os.getenv("IMAGE_SIZE", 256))

        self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
        self.image_provider = os.getenv("IMAGE_PROVIDER")
        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
        self.huggingface_image_model = os.getenv(
            "HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
        )
        self.huggingface_audio_to_text_model = os.getenv(
            "HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
        )
        self.speak_mode = False

        from pilot.prompts.prompt_registry import PromptTemplateRegistry

        self.prompt_template_registry = PromptTemplateRegistry()
        ### Related configuration of built-in commands
        self.command_registry = []

        ### Related configuration of display commands
        self.command_disply = []

        disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
        if disabled_command_categories:
            self.disabled_command_categories = disabled_command_categories.split(",")
        else:
            self.disabled_command_categories = []

        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
        )
        ### Message history store directory
        self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")

        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in

        self.plugins: List["AutoGPTPluginTemplate"] = []
        self.plugins_openai = []
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"

        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
        if plugins_allowlist:
            self.plugins_allowlist = plugins_allowlist.split(",")
        else:
            self.plugins_allowlist = []

        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
        if plugins_denylist:
            self.plugins_denylist = plugins_denylist.split(",")
        else:
            self.plugins_denylist = []
        ### Native SQL Execution Capability Control Configuration
        self.NATIVE_SQL_CAN_RUN_DDL = (
            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
        )
        self.NATIVE_SQL_CAN_RUN_WRITE = (
            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
        )

        self.LOCAL_DB_MANAGE = None

        ### dbgpt meta info database connection configuration
        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
        if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
            self.LOCAL_DB_HOST = "127.0.0.1"

        self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME", "dbgpt")
        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
        self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))

        self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")

        ### LLM Model Service Configuration
        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
        self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")

        ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
        ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
        ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
        self.PROXYLLM_BACKEND = None
        if self.LLM_MODEL == "proxyllm":
            self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")

        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
        self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
        self.MODEL_SERVER = os.getenv(
            "MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
        )

        ### Vector Store Configuration
        self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
        self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
        self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
        self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
        self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
        if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
            self.IS_LOAD_8BIT = False
        # In order to be compatible with the new and old model parameter design
        os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
        os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)

        ### EMBEDDING Configuration
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
        self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
        # default recall similarity score, between 0 and 1
        self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
            os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
        )
        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
            os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
        )
        # Whether to enable Chat Knowledge Search Rewrite Mode
        self.KNOWLEDGE_SEARCH_REWRITE = (
            os.getenv("KNOWLEDGE_SEARCH_REWRITE", "False").lower() == "true"
        )
        # Control whether to display the source document of knowledge on the front end.
        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
            os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
        )

        ### SUMMARY_CONFIG Configuration
        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

        self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)

        ### Log level
        self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")

        self.SYSTEM_APP: Optional["SystemApp"] = None

        ### Temporary configuration
        self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"

        self.MODEL_CACHE_ENABLE: bool = (
            os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
        )
        self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
            "MODEL_CACHE_STORAGE_TYPE", "disk"
        )
        self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
            os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
        )
        self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
            "MODEL_CACHE_STORAGE_DISK_DIR"
        )

    def set_debug_mode(self, value: bool) -> None:
        """Set the debug mode value"""
        self.debug_mode = value

    def set_templature(self, value: int) -> None:
        """Set the temperature value."""
        self.temperature = value

    def set_speak_mode(self, value: bool) -> None:
        """Set the speak mode value."""
        self.speak_mode = value

    def set_last_plugin_return(self, value: bool) -> None:
        """Set the last plugin return value."""
        self.last_plugin_return = value
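Because Config uses the Singleton metaclass, every call site shares one instance populated from environment variables. A brief hypothetical sketch (actual values depend on the local environment):

# Illustrative sketch; printed values depend on the environment.
cfg = Config()
print(cfg.LLM_MODEL, cfg.EMBEDDING_MODEL)   # e.g. "vicuna-13b-v1.5", "text2vec"
cfg.set_templature(0.5)                     # original method name, spelling preserved
assert Config() is cfg                      # the Singleton metaclass returns the same instance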
@@ -1,176 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
from functools import cache
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
|
||||
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
|
||||
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
|
||||
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
|
||||
|
||||
current_directory = os.getcwd()
|
||||
|
||||
new_directory = PILOT_PATH
|
||||
os.chdir(new_directory)
|
||||
|
||||
|
||||
@cache
|
||||
def get_device() -> str:
|
||||
try:
|
||||
import torch
|
||||
|
||||
return (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return "cpu"
|
||||
|
||||
|
||||
LLM_MODEL_CONFIG = {
|
||||
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
|
||||
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
|
||||
"vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
|
||||
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
|
||||
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
|
||||
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
|
||||
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
|
||||
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
|
||||
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
|
||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
|
||||
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
|
||||
# https://huggingface.co/THUDM/chatglm3-6b
|
||||
"chatglm3-6b": os.path.join(MODEL_PATH, "chatglm3-6b"),
|
||||
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
||||
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
||||
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
||||
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
|
||||
"proxyllm": "chatgpt_proxyllm",
|
||||
"chatgpt_proxyllm": "chatgpt_proxyllm",
|
||||
"bard_proxyllm": "bard_proxyllm",
|
||||
"claude_proxyllm": "claude_proxyllm",
|
||||
"wenxin_proxyllm": "wenxin_proxyllm",
|
||||
"tongyi_proxyllm": "tongyi_proxyllm",
|
||||
"zhipu_proxyllm": "zhipu_proxyllm",
|
||||
"bc_proxyllm": "bc_proxyllm",
|
||||
"spark_proxyllm": "spark_proxyllm",
|
||||
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
|
||||
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
|
||||
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
|
||||
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
||||
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
|
||||
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
||||
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
|
||||
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
"qwen-7b-chat": os.path.join(MODEL_PATH, "Qwen-7B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int8
|
||||
"qwen-7b-chat-int8": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int8"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int4
|
||||
"qwen-7b-chat-int4": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int4"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat
|
||||
"qwen-14b-chat": os.path.join(MODEL_PATH, "Qwen-14B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int8
|
||||
"qwen-14b-chat-int8": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int8"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int4
|
||||
"qwen-14b-chat-int4": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int4"),
|
||||
# https://huggingface.co/Qwen/Qwen-72B-Chat
|
||||
"qwen-72b-chat": os.path.join(MODEL_PATH, "Qwen-72B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int8
|
||||
"qwen-72b-chat-int8": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int8"),
|
||||
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int4
|
||||
"qwen-72b-chat-int4": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int4"),
|
||||
# https://huggingface.co/Qwen/Qwen-1_8B-Chat
|
||||
"qwen-1.8b-chat": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int8
|
||||
"qwen-1.8b-chat-int8": os.path.join(MODEL_PATH, "wen-1_8B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int4
"qwen-1.8b-chat-int4": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat-Int4"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
# wget https://huggingface.co/TheBloke/vicuna-13B-v1.5-GGUF/resolve/main/vicuna-13b-v1.5.Q4_K_M.gguf -O models/ggml-model-q4_0.gguf
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.gguf"),
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
"codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
# https://huggingface.co/microsoft/Orca-2-7b
"orca-2-7b": os.path.join(MODEL_PATH, "Orca-2-7b"),
# https://huggingface.co/microsoft/Orca-2-13b
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
# https://huggingface.co/openchat/openchat_3.5
"openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
# https://huggingface.co/hfl/chinese-alpaca-2-7b
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
# https://huggingface.co/hfl/chinese-alpaca-2-13b
"chinese-alpaca-2-13b": os.path.join(MODEL_PATH, "chinese-alpaca-2-13b"),
# https://huggingface.co/THUDM/codegeex2-6b
"codegeex2-6b": os.path.join(MODEL_PATH, "codegeex2-6b"),
# https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
"zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
# https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
"mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
# https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
"xwin-lm-7b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-7B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-13B-V0.1
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
# https://huggingface.co/01-ai/Yi-34B-Chat-8bits
"yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
# https://huggingface.co/01-ai/Yi-34B-Chat-4bits
"yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}

EMBEDDING_MODEL_CONFIG = {
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
# https://huggingface.co/moka-ai/m3e-base
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
# https://huggingface.co/moka-ai/m3e-large
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
# https://huggingface.co/BAAI/bge-large-en
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
}

# Load model config
ISDEBUG = False

VECTOR_SEARCH_TOP_K = 10
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "data"
)
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
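For reference, a lookup against these tables has to handle two kinds of values: local models map to a checkpoint directory under MODEL_PATH, while proxy models map to a plain alias string. Below is a minimal sketch of such a lookup; resolve_model_path is a hypothetical helper written for this note, not part of the configuration module.

import os

def resolve_model_path(model_name: str) -> str:
    # Hypothetical helper: look up an LLM_MODEL_CONFIG entry and warn when a
    # local checkpoint directory is missing; proxy aliases are returned as-is.
    try:
        value = LLM_MODEL_CONFIG[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None
    if value.startswith(MODEL_PATH) and not os.path.exists(value):
        print(f"Warning: checkpoint for {model_name} not found at {value}")
    return value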
@@ -1 +0,0 @@
from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
@@ -1,59 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""Base class that all concrete database connectors should implement."""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional


class BaseConnect(ABC):
    def get_connect(self, db_name: str):
        pass

    def get_table_names(self) -> Iterable[str]:
        pass

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

    def get_example_data(self, table: str, count: int = 3):
        pass

    def get_database_list(self):
        pass

    def get_database_names(self):
        pass

    def get_table_comments(self, db_name):
        pass

    def run(self, session, command: str, fetch: str = "all") -> List:
        pass

    def run_to_df(self, command: str, fetch: str = "all"):
        pass

    def get_users(self):
        pass

    def get_grants(self):
        pass

    def get_collation(self):
        pass

    def get_charset(self):
        pass

    def get_fields(self, table_name):
        pass

    def get_show_create_table(self, table_name):
        pass

    def get_indexes(self, table_name):
        pass
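As a rough sketch of how a concrete connector fills in this interface, the class below wires two of the methods to Python's built-in sqlite3 module. It is a hypothetical example written for illustration only, not the project's own SQLiteConnect.

import sqlite3
from typing import Iterable, List


class InMemorySQLiteConnect(BaseConnect):
    """Hypothetical connector backed by an in-memory SQLite database."""

    db_type: str = "sqlite-demo"

    def __init__(self):
        self._conn = sqlite3.connect(":memory:")

    def get_table_names(self) -> Iterable[str]:
        cursor = self._conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        return [row[0] for row in cursor.fetchall()]

    def run(self, session, command: str, fetch: str = "all") -> List:
        # The session argument only mirrors the interface; sqlite3 manages its own connection.
        cursor = self._conn.execute(command)
        return cursor.fetchall() if fetch == "all" else [cursor.fetchone()]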
@@ -1,119 +0,0 @@
from typing import Optional, Any

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
    """
    Spark Connect supports operating on a variety of data sources through the DataFrame interface.
    A DataFrame can be operated on using relational transformations and can also be used to create a temporary view.
    Registering a DataFrame as a temporary view allows you to run SQL queries over its data.
    Currently supported data sources: parquet, jdbc, orc, libsvm, csv, text, json.
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional[Any] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a Spark DataFrame from the data source path."""
        from pyspark.sql import SparkSession

        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            print("load spark datasource error: " + str(e))

    def create_df(self, path):
        """Create a Spark DataFrame from the data source path (currently supports parquet, jdbc, orc, libsvm, csv, text, json).

        return: Spark DataFrame
        reference: https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
        """
        extension = (
            "text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1]
        )
        return self.spark_session.read.load(
            path, format=extension, inferSchema="true", header="true"
        )

    def run(self, sql):
        print(f"spark sql to run is {sql}")
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        first_row = df.first()
        rows = [first_row.asDict().keys()]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        return "ans"

    def get_fields(self):
        """Get column meta about dataframe."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""
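Assuming PySpark is installed and a local CSV file exists, the connector above would be used roughly as follows; the file path is a placeholder.

# Hypothetical usage; the CSV path is a placeholder.
connector = SparkConnect.from_file_path("./examples/users.csv")
rows = connector.run("SELECT * FROM temp LIMIT 10")  # "temp" is the registered view name
print(connector.table_simple_info())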
@@ -1,17 +0,0 @@
from pydantic import BaseModel, Field


class DBConfig(BaseModel):
    db_type: str
    db_name: str
    file_path: str = ""
    db_host: str = ""
    db_port: int = 0
    db_user: str = ""
    db_pwd: str = ""
    comment: str = ""


class DbTypeInfo(BaseModel):
    db_type: str
    is_file_db: bool = False
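For illustration, a URL-style database and a file-based database would be described with these models roughly like this; hosts, credentials, and paths are placeholders.

# Hypothetical examples; credentials and paths are placeholders.
mysql_cfg = DBConfig(
    db_type="mysql",
    db_name="orders",
    db_host="127.0.0.1",
    db_port=3306,
    db_user="root",
    db_pwd="secret",
    comment="demo MySQL instance",
)
duckdb_cfg = DBConfig(db_type="duckdb", db_name="local", file_path="./data/local.db")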
@@ -1,244 +0,0 @@
from sqlalchemy import Column, Integer, String, Index, Text, text
from sqlalchemy import UniqueConstraint

from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import (
    Base,
    engine,
    session,
    META_DATA_DATABASE,
)


class ConnectConfigEntity(Base):
    """db connect config entity"""

    __tablename__ = "connect_config"
    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )

    db_type = Column(String(255), nullable=False, comment="db type")
    db_name = Column(String(255), nullable=False, comment="db name")
    db_path = Column(String(255), nullable=True, comment="file db path")
    db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
    db_port = Column(String(255), nullable=True, comment="db connect port(not file db)")
    db_user = Column(String(255), nullable=True, comment="db user")
    db_pwd = Column(String(255), nullable=True, comment="db password")
    comment = Column(Text, nullable=True, comment="db comment")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")

    __table_args__ = (
        UniqueConstraint("db_name", name="uk_db"),
        Index("idx_q_db_type", "db_type"),
        {"mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
    )


class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
    """db connect config dao"""

    def __init__(self):
        super().__init__(
            database=META_DATA_DATABASE,
            orm_base=Base,
            db_engine=engine,
            session=session,
        )

    def update(self, entity: ConnectConfigEntity):
        """update db connect info"""
        session = self.get_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def delete(self, db_name: str):
        """delete db connect info"""
        session = self.get_session()
        if db_name is None:
            raise Exception("db_name is None")

        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
        db_connect.delete()
        session.commit()
        session.close()

    def get_by_names(self, db_name: str) -> ConnectConfigEntity:
        """get db connect info by name"""
        session = self.get_session()
        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
        result = db_connect.first()
        session.close()
        return result

    def add_url_db(
        self,
        db_name,
        db_type,
        db_host: str,
        db_port: int,
        db_user: str,
        db_pwd: str,
        comment: str = "",
    ):
        """
        add db connect info
        Args:
            db_name: db name
            db_type: db type
            db_host: db host
            db_port: db port
            db_user: db user
            db_pwd: db password
            comment: comment
        """
        try:
            session = self.get_session()

            from sqlalchemy import text

            insert_statement = text(
                """
                INSERT INTO connect_config (
                    db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
                ) VALUES (
                    :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
                )
                """
            )

            params = {
                "db_name": db_name,
                "db_type": db_type,
                "db_path": "",
                "db_host": db_host,
                "db_port": db_port,
                "db_user": db_user,
                "db_pwd": db_pwd,
                "comment": comment,
            }
            session.execute(insert_statement, params)
            session.commit()
            session.close()
        except Exception as e:
            print("add db connect info error!" + str(e))

    def update_db_info(
        self,
        db_name,
        db_type,
        db_path: str = "",
        db_host: str = "",
        db_port: int = 0,
        db_user: str = "",
        db_pwd: str = "",
        comment: str = "",
    ):
        """update db connect info"""
        old_db_conf = self.get_db_config(db_name)
        if old_db_conf:
            try:
                session = self.get_session()
                if not db_path:
                    update_statement = text(
                        f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
                    )
                else:
                    update_statement = text(
                        f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
                    )
                session.execute(update_statement)
                session.commit()
                session.close()
            except Exception as e:
                print("edit db connect info error!" + str(e))
            return True
        raise ValueError(f"{db_name} has no config info!")

    def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
        """add file db connect info"""
        try:
            session = self.get_session()
            insert_statement = text(
                """
                INSERT INTO connect_config(
                    db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
                ) VALUES (
                    :db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
                )
                """
            )
            params = {
                "db_name": db_name,
                "db_type": db_type,
                "db_path": db_path,
                "db_host": "",
                "db_port": 0,
                "db_user": "",
                "db_pwd": "",
                "comment": comment,
            }

            session.execute(insert_statement, params)

            session.commit()
            session.close()
        except Exception as e:
            print("add db connect info error!" + str(e))

    def get_db_config(self, db_name):
        """get db config by name"""
        session = self.get_session()
        if db_name:
            select_statement = text(
                """
                SELECT
                    *
                FROM
                    connect_config
                WHERE
                    db_name = :db_name
                """
            )
            params = {"db_name": db_name}
            result = session.execute(select_statement, params)
        else:
            raise ValueError("Cannot get database by name: " + db_name)

        fields = [field[0] for field in result.cursor.description]
        row_dict = {}
        row_1 = list(result.cursor.fetchall()[0])
        for i, field in enumerate(fields):
            row_dict[field] = row_1[i]
        return row_dict

    def get_db_list(self):
        """get db list"""
        session = self.get_session()
        result = session.execute(text("SELECT * FROM connect_config"))

        fields = [field[0] for field in result.cursor.description]
        data = []
        for row in result.cursor.fetchall():
            row_dict = {}
            for i, field in enumerate(fields):
                row_dict[field] = row[i]
            data.append(row_dict)
        return data

    def delete_db(self, db_name):
        """delete db connect info"""
        session = self.get_session()
        delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
        params = {"db_name": db_name}
        session.execute(delete_statement, params)
        session.commit()
        session.close()
        return True
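Note that update_db_info above interpolates values directly into the SQL string, while add_url_db binds parameters. A sketch of the non-file branch rewritten with bound parameters is shown below; it is an illustrative alternative, not part of the code in this diff.

from sqlalchemy import text


def update_url_db_info(session, db_name, db_type, db_host, db_port, db_user, db_pwd, comment=""):
    # Hypothetical parameterized variant of the non-file branch of update_db_info.
    update_statement = text(
        "UPDATE connect_config SET db_type=:db_type, db_host=:db_host, "
        "db_port=:db_port, db_user=:db_user, db_pwd=:db_pwd, comment=:comment "
        "WHERE db_name=:db_name"
    )
    session.execute(
        update_statement,
        {
            "db_type": db_type,
            "db_host": db_host,
            "db_port": db_port,
            "db_user": db_user,
            "db_pwd": db_pwd,
            "comment": comment,
            "db_name": db_name,
        },
    )
    session.commit()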
@@ -1,147 +0,0 @@
import os
import duckdb

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
table_name = "connect_config"


class DuckdbConnectConfig:
    def __init__(self):
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_config_tables()

    def __init_config_tables(self):
        # check config table
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            # create config table
            self.connect.execute(
                "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def add_url_db(
        self,
        db_name,
        db_type,
        db_host: str,
        db_port: int,
        db_user: str,
        db_pwd: str,
        comment: str = "",
    ):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error1!" + str(e))

    def update_db_info(
        self,
        db_name,
        db_type,
        db_path: str = "",
        db_host: str = "",
        db_port: int = 0,
        db_user: str = "",
        db_pwd: str = "",
        comment: str = "",
    ):
        old_db_conf = self.get_db_config(db_name)
        if old_db_conf:
            try:
                cursor = self.connect.cursor()
                if not db_path:
                    cursor.execute(
                        f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
                    )
                else:
                    cursor.execute(
                        f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
                    )
                cursor.commit()
                self.connect.commit()
            except Exception as e:
                print("edit db connect info error2!" + str(e))
            return True
        raise ValueError(f"{db_name} has no config info!")

    def get_file_db_name(self, path):
        try:
            conn = duckdb.connect(path)
            result = conn.execute("SELECT current_database()").fetchone()[0]
            return result
        except Exception as e:
            raise ValueError("Unusable duckdb database path: " + path) from e

    def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, db_path, "", 0, "", "", comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error2!" + str(e))

    def delete_db(self, db_name):
        cursor = self.connect.cursor()
        cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
        cursor.commit()
        return True

    def get_db_config(self, db_name):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if db_name:
                cursor.execute(
                    "SELECT * FROM connect_config where db_name=? ", [db_name]
                )
            else:
                raise ValueError("Cannot get database by name: " + db_name)

            fields = [field[0] for field in cursor.description]
            row_dict = {}
            row_1 = list(cursor.fetchall()[0])
            for i, field in enumerate(fields):
                row_dict[field] = row_1[i]
            return row_dict
        return None

    def get_db_list(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute("SELECT * FROM connect_config ")

            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)
            return data

        return []

    def get_db_names(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute("SELECT db_name FROM connect_config ")
            data = []
            for row in cursor.fetchall():
                data.append(row[0])
            return data
        return []
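Used directly, the DuckDB-backed store behaves roughly as follows; the database name and credentials are placeholders.

# Hypothetical usage; requires the duckdb package, values are placeholders.
config_store = DuckdbConnectConfig()
config_store.add_url_db("orders", "mysql", "127.0.0.1", 3306, "root", "secret")
print(config_store.get_db_names())                      # e.g. ["orders"]
print(config_store.get_db_config("orders")["db_host"])  # "127.0.0.1"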
@@ -1,160 +0,0 @@
import threading
import asyncio

from pilot.configs.config import Config
from pilot.connections import ConnectConfigDao
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import ExecutorFactory

from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect

from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase
from pilot.connections.rdbms.conn_starrocks import StarRocksConnect
from pilot.connections.rdbms.conn_doris import DorisConnect
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig
from pilot.connections.conn_spark import SparkConnect
from pilot.summary.db_summary_client import DBSummaryClient

CFG = Config()


class ConnectManager:
    """db connect manager"""

    def get_all_subclasses(self, cls):
        subclasses = cls.__subclasses__()
        for subclass in subclasses:
            subclasses += self.get_all_subclasses(subclass)
        return subclasses

    def get_all_completed_types(self):
        chat_classes = self.get_all_subclasses(BaseConnect)
        support_types = []
        for cls in chat_classes:
            if cls.db_type:
                support_types.append(DBType.of_db_type(cls.db_type))
        return support_types

    def get_cls_by_dbtype(self, db_type):
        chat_classes = self.get_all_subclasses(BaseConnect)
        result = None
        for cls in chat_classes:
            if cls.db_type == db_type:
                result = cls
        if not result:
            raise ValueError("Unsupported Db Type!" + db_type)
        return result

    def __init__(self, system_app: SystemApp):
        """metadata database management initialization"""
        # self.storage = DuckdbConnectConfig()
        self.storage = ConnectConfigDao()
        self.db_summary_client = DBSummaryClient(system_app)

    def get_connect(self, db_name):
        db_config = self.storage.get_db_config(db_name)
        db_type = DBType.of_db_type(db_config.get("db_type"))
        connect_instance = self.get_cls_by_dbtype(db_type.value())
        if db_type.is_file_db():
            db_path = db_config.get("db_path")
            return connect_instance.from_file_path(db_path)
        else:
            db_host = db_config.get("db_host")
            db_port = db_config.get("db_port")
            db_user = db_config.get("db_user")
            db_pwd = db_config.get("db_pwd")
            return connect_instance.from_uri_db(
                host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name
            )

    def test_connect(self, db_info: DBConfig):
        try:
            db_type = DBType.of_db_type(db_info.db_type)
            connect_instance = self.get_cls_by_dbtype(db_type.value())
            if db_type.is_file_db():
                db_path = db_info.file_path
                return connect_instance.from_file_path(db_path)
            else:
                db_name = db_info.db_name
                db_host = db_info.db_host
                db_port = db_info.db_port
                db_user = db_info.db_user
                db_pwd = db_info.db_pwd
                return connect_instance.from_uri_db(
                    host=db_host,
                    port=db_port,
                    user=db_user,
                    pwd=db_pwd,
                    db_name=db_name,
                )
        except Exception as e:
            print(f"{db_info.db_name} Test connect Failure!{str(e)}")
            raise ValueError(f"{db_info.db_name} Test connect Failure!{str(e)}")

    def get_db_list(self):
        return self.storage.get_db_list()

    def get_db_names(self):
        return self.storage.get_by_name()

    def delete_db(self, db_name: str):
        return self.storage.delete_db(db_name)

    def edit_db(self, db_info: DBConfig):
        return self.storage.update_db_info(
            db_info.db_name,
            db_info.db_type,
            db_info.file_path,
            db_info.db_host,
            db_info.db_port,
            db_info.db_user,
            db_info.db_pwd,
            db_info.comment,
        )

    async def async_db_summary_embedding(self, db_name, db_type):
        # Execute the code that needs to run asynchronously here
        self.db_summary_client.db_summary_embedding(db_name, db_type)

    def add_db(self, db_info: DBConfig):
        print(f"add_db:{db_info.__dict__}")
        try:
            db_type = DBType.of_db_type(db_info.db_type)
            if db_type.is_file_db():
                self.storage.add_file_db(
                    db_info.db_name, db_info.db_type, db_info.file_path
                )
            else:
                self.storage.add_url_db(
                    db_info.db_name,
                    db_info.db_type,
                    db_info.db_host,
                    db_info.db_port,
                    db_info.db_user,
                    db_info.db_pwd,
                    db_info.comment,
                )
            # async embedding
            executor = CFG.SYSTEM_APP.get_component(
                ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
            ).create()
            executor.submit(
                self.db_summary_client.db_summary_embedding,
                db_info.db_name,
                db_info.db_type,
            )
        except Exception as e:
            raise ValueError("Add db connect info error!" + str(e))

        return True
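Putting the pieces together, registering a database and then opening a connection through the manager looks roughly like this; the SystemApp instance and the connection details are assumptions made for the example.

# Hypothetical usage; system_app and the connection details are placeholders.
manager = ConnectManager(system_app)
manager.add_db(
    DBConfig(
        db_type="mysql",
        db_name="orders",
        db_host="127.0.0.1",
        db_port=3306,
        db_user="root",
        db_pwd="secret",
    )
)
connect = manager.get_connect("orders")
print(connect.get_table_names())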
Some files were not shown because too many files have changed in this diff.