chore: Add pylint for DB-GPT core lib (#1076)

Fangyin Cheng 2024-01-16 17:36:26 +08:00 committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

40
.flake8 Normal file
View File

@ -0,0 +1,40 @@
[flake8]
exclude =
.eggs/
build/
*/tests/*
*_private
max-line-length = 88
inline-quotes = "
ignore =
C408
C417
E121
E123
E126
E203
E226
E24
E704
W503
W504
W605
I
N
B001
B002
B003
B004
B005
B007
B008
B009
B010
B011
B012
B013
B014
B015
B016
B017
avoid-escape = no
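
Most of these entries keep flake8 from fighting Black or common framework idioms: E203 and W503/W504 clash with Black's slice and line-break formatting, max-line-length 88 and double inline-quotes match Black's output, and the flake8-bugbear B-codes are silenced because patterns such as FastAPI's `Depends()` trip B008 (function call in an argument default). A small, hypothetical snippet this configuration would pass:

```python
# Hypothetical code that this .flake8 accepts without warnings.
from fastapi import Depends, FastAPI

app = FastAPI()


def get_db() -> dict:
    return {}


@app.get("/items")
async def list_items(db: dict = Depends(get_db)):  # B008 ignored: call in default
    return dict(items=[])  # C408 ignored: dict() call instead of a literal
```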

13
.isort.cfg Normal file
View File

@ -0,0 +1,13 @@
[settings]
# This is to make isort compatible with Black. See
# https://black.readthedocs.io/en/stable/the_black_code_style.html#how-black-wraps-lines.
line_length=88
profile=black
multi_line_output=3
include_trailing_comma=True
use_parentheses=True
float_to_top=True
filter_files=True
skip_glob=examples/notebook/*
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER,AFTERRAY
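
With `profile=black` and `line_length=88`, isort orders imports the way Black formats them (`multi_line_output=3` is vertical-hanging-indent, with a trailing comma). The trailing AFTERRAY entry in `sections` appears to be carried over from Ray's isort configuration; without a matching `known_afterray` option it has no effect. An illustrative module header under this configuration, using names from elsewhere in this commit:

```python
from __future__ import annotations  # FUTURE

import logging  # STDLIB
import os

import requests  # THIRDPARTY

from dbgpt.core.interface.message import (  # FIRSTPARTY, wrapped in Black style
    AIMessage,
    BaseMessage,
    SystemMessage,
)

from .base import BaseOperator  # LOCALFOLDER (inside a package)
```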

20
.mypy.ini Normal file
View File

@ -0,0 +1,20 @@
[mypy]
exclude = /tests/
# plugins = pydantic.mypy
[mypy-graphviz.*]
ignore_missing_imports = True
[mypy-cachetools.*]
ignore_missing_imports = True
[mypy-coloredlogs.*]
ignore_missing_imports = True
[mypy-termcolor.*]
ignore_missing_imports = True
[mypy-pydantic.*]
strict_optional = False
ignore_missing_imports = True
follow_imports = skip
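
Each `[mypy-*]` block silences "missing library stubs" errors for an untyped third-party package, so dbgpt's own annotations are still checked while `/tests/` is excluded; pydantic additionally gets `strict_optional = False` and `follow_imports = skip`. A small sketch of what passes and what still fails under this setup:

```python
import coloredlogs  # no stub error: ignore_missing_imports = True above


def install_logging(level: str) -> None:
    coloredlogs.install(level=level)


def broken(flag: bool) -> int:
    return "oops" if flag else 0  # error: mypy still checks dbgpt's own code
```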

View File

@ -20,4 +20,22 @@ repos:
stages: [commit]
pass_filenames: false
args: []
- id: python-test-doc
name: Python Doc Test
entry: make test-doc
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-lint-mypy
name: Python Lint mypy
entry: make mypy
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []

View File

@ -14,7 +14,7 @@ setup: $(VENV)/bin/activate
$(VENV)/bin/activate: $(VENV)/.venv-timestamp
$(VENV)/.venv-timestamp: setup.py
$(VENV)/.venv-timestamp: setup.py requirements
# Create new virtual environment if setup.py has changed
python3 -m venv $(VENV)
$(VENV_BIN)/pip install --upgrade pip
@ -46,15 +46,14 @@ fmt: setup ## Format Python code
# $(VENV_BIN)/blackdoc .
$(VENV_BIN)/blackdoc dbgpt
$(VENV_BIN)/blackdoc examples
# TODO: Type checking of Python code.
# https://github.com/python/mypy
# $(VENV_BIN)/mypy dbgpt
# TODO: uUse flake8 to enforce Python style guide.
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
# $(VENV_BIN)/flake8 dbgpt
$(VENV_BIN)/flake8 dbgpt/core/
# TODO: More package checks with flake8.
.PHONY: pre-commit
pre-commit: fmt test ## Run formatting and unit tests before committing
pre-commit: fmt test test-doc mypy ## Run formatting and unit tests before committing
test: $(VENV)/.testenv ## Run unit tests
$(VENV_BIN)/pytest dbgpt
@ -64,6 +63,12 @@ test-doc: $(VENV)/.testenv ## Run doctests
# -k "not test_" skips tests that are not doctests.
$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
# TODO: More package checks with mypy.
.PHONY: coverage
coverage: setup ## Run tests and report coverage
$(VENV_BIN)/pytest dbgpt --cov=dbgpt

View File

@ -1,22 +1,22 @@
import logging
import os
import uuid
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns
from matplotlib.font_manager import FontManager
from pandas import DataFrame
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.util.string_utils import is_scientific_notation
from ...command_mange import command
matplotlib.use("Agg")
import logging
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.util.string_utils import is_scientific_notation
logger = logging.getLogger(__name__)

View File

@ -192,7 +192,8 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
with engine_no_db.connect() as conn:
conn.execute(
DDL(
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE "
f"utf8mb4_unicode_ci"
)
)
logger.info(f"Database {db_name} successfully created")
@ -218,26 +219,31 @@ class WebServerParameters(BaseParameters):
controller_addr: Optional[str] = field(
default=None,
metadata={
"help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
"help": "The Model controller address to connect. If None, read model "
"controller address from environment key `MODEL_SERVER`."
},
)
model_name: str = field(
default=None,
metadata={
"help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.",
"help": "The default model name to use. If None, read model name from "
"environment key `LLM_MODEL`.",
"tags": "fixed",
},
)
share: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
"help": "Whether to create a publicly shareable link for the interface. "
"Creates an SSH tunnel to make your UI accessible from anywhere. "
},
)
remote_embedding: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
"help": "Whether to enable remote embedding models. If it is True, you need"
" to start a embedding model through `dbgpt start worker --worker_type "
"text2vec --model_name xxx --model_path xxx`"
},
)
log_level: Optional[str] = field(
@ -286,3 +292,10 @@ class WebServerParameters(BaseParameters):
"help": "The directories to search awel files, split by `,`",
},
)
default_thread_pool_size: Optional[int] = field(
default=None,
metadata={
"help": "The default thread pool size, If None, "
"use default config of python thread pool",
},
)

View File

@ -25,7 +25,9 @@ def initialize_components(
from dbgpt.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
system_app.register(
DefaultExecutorFactory, max_workers=param.default_thread_pool_size
)
system_app.register_instance(controller)
from dbgpt.serve.agent.hub.controller import module_agent
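
Wiring `max_workers` to the new `default_thread_pool_size` parameter makes the size of the shared executor configurable; when it is None, `ThreadPoolExecutor` falls back to Python's default (`min(32, os.cpu_count() + 4)` on 3.8+). A hedged sketch of how the registered factory is consumed, mirroring the `get_component` call that appears later in this commit (constructor arguments assumed):

```python
from dbgpt.component import ComponentType, SystemApp
from dbgpt.util.executor_utils import DefaultExecutorFactory

system_app = SystemApp()
system_app.register(DefaultExecutorFactory, max_workers=8)

factory = system_app.get_component(
    ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
)
executor = factory.create()  # a shared ThreadPoolExecutor-backed executor
```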

View File

@ -3,8 +3,6 @@ import os
import sys
from typing import List
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
@ -41,6 +39,10 @@ from dbgpt.util.utils import (
setup_logging,
)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static")
CFG = Config()

View File

@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urljoin
import requests
from prettytable import PrettyTable
from dbgpt.app.knowledge.request.request import (
ChunkQueryRequest,
@ -193,9 +194,6 @@ def knowledge_init(
return
from prettytable import PrettyTable
class _KnowledgeVisualizer:
def __init__(self, api_address: str, out_format: str):
self.client = KnowledgeApiClient(api_address)

View File

@ -4,13 +4,14 @@
import os
import sys
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
from dbgpt.model.cluster import run_worker_manager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
CFG = Config()
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)

View File

@ -313,8 +313,9 @@ class BaseChat(ABC):
)
### store current conversation
span.end(metadata={"error": str(e)})
# self.memory.append(self.current_message)
self.current_message.end_current_round()
await blocking_func_to_async(
self._executor, self.current_message.end_current_round
)
async def nostream_call(self):
payload = await self._build_model_request()
@ -381,8 +382,9 @@ class BaseChat(ABC):
)
span.end(metadata={"error": str(e)})
### store dialogue
# self.memory.append(self.current_message)
self.current_message.end_current_round()
await blocking_func_to_async(
self._executor, self.current_message.end_current_round
)
return self.current_ai_response()
async def get_llm_response(self):
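
`end_current_round` persists the conversation and may block on storage, so it is now dispatched to the executor instead of stalling the event loop. The real helper lives in `dbgpt.util.executor_utils`; a minimal sketch of the behavior relied on here, assuming it defers to `loop.run_in_executor`:

```python
import asyncio
import functools
from concurrent.futures import Executor
from typing import Any, Callable


async def blocking_func_to_async(
    executor: Executor, func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    # Run the blocking callable in the pool; the event loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor, functools.partial(func, *args, **kwargs)
    )
```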

View File

@ -104,28 +104,37 @@ class BaseComponent(LifeCycle, ABC):
@classmethod
def get_instance(
cls,
cls: Type[T],
system_app: SystemApp,
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Type[BaseComponent] = None,
or_register_component: Optional[Type[T]] = None,
*args,
**kwargs,
) -> BaseComponent:
) -> T:
"""Get the current component instance.
Args:
system_app (SystemApp): The system app
default_component : The default component instance if not retrieve by name
or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
or_register_component (Type[T]): The new component to register if not retrieve by name
Returns:
BaseComponent: The component instance
T: The component instance
"""
# Check for keyword argument conflicts
if "default_component" in kwargs:
raise ValueError(
"default_component argument given in both fixed and **kwargs"
)
if "or_register_component" in kwargs:
raise ValueError(
"or_register_component argument given in both fixed and **kwargs"
)
kwargs["default_component"] = default_component
kwargs["or_register_component"] = or_register_component
return system_app.get_component(
cls.name,
cls,
default_component=default_component,
or_register_component=or_register_component,
*args,
**kwargs,
)
@ -159,11 +168,11 @@ class SystemApp(LifeCycle):
"""Returns the internal AppConfig."""
return self._app_config
def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
def register(self, component: Type[T], *args, **kwargs) -> T:
"""Register a new component by its type.
Args:
component (Type[BaseComponent]): The component class to register
component (Type[T]): The component class to register
Returns:
T: The instance of registered component
@ -198,7 +207,7 @@ class SystemApp(LifeCycle):
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Type[BaseComponent] = None,
or_register_component: Optional[Type[T]] = None,
*args,
**kwargs,
) -> T:
@ -208,7 +217,7 @@ class SystemApp(LifeCycle):
name (Union[str, ComponentType]): Component name
component_type (Type[T]): The type of current retrieve component
default_component : The default component instance if not retrieve by name
or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
or_register_component (Type[T]): The new component to register if not retrieve by name
Returns:
T: The instance retrieved by component name
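
The payoff of the `Type[T]` signature shows up at call sites: the instance comes back typed as the concrete class instead of `BaseComponent`, and the new guards raise `ValueError` when `default_component` or `or_register_component` arrives both positionally and via `**kwargs`. A hypothetical call site:

```python
from dbgpt.component import SystemApp
from dbgpt.util.executor_utils import DefaultExecutorFactory

system_app = SystemApp()

# Registered on first use, then returned typed as DefaultExecutorFactory.
factory = DefaultExecutorFactory.get_instance(
    system_app, or_register_component=DefaultExecutorFactory
)
executor = factory.create()
```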

View File

@ -1,11 +1,13 @@
from dbgpt.core.interface.cache import (
"""The core module contains the core interfaces and classes for dbgpt."""
from dbgpt.core.interface.cache import ( # noqa: F401
CacheClient,
CacheConfig,
CacheKey,
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.llm import (
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
MessageConverter,
@ -16,7 +18,7 @@ from dbgpt.core.interface.llm import (
ModelRequest,
ModelRequestContext,
)
from dbgpt.core.interface.message import (
from dbgpt.core.interface.message import ( # noqa: F401
AIMessage,
BaseMessage,
ConversationIdentifier,
@ -29,8 +31,11 @@ from dbgpt.core.interface.message import (
StorageConversation,
SystemMessage,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.prompt import (
from dbgpt.core.interface.output_parser import ( # noqa: F401
BaseOutputParser,
SQLOutputParser,
)
from dbgpt.core.interface.prompt import ( # noqa: F401
BasePromptTemplate,
ChatPromptTemplate,
HumanPromptTemplate,
@ -40,8 +45,8 @@ from dbgpt.core.interface.prompt import (
StoragePromptTemplate,
SystemPromptTemplate,
)
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.storage import (
from dbgpt.core.interface.serialization import Serializable, Serializer # noqa: F401
from dbgpt.core.interface.storage import ( # noqa: F401
DefaultStorageItemAdapter,
InMemoryStorage,
QuerySpec,

View File

@ -1,27 +1,34 @@
"""Example selector base class"""
from abc import ABC
from enum import Enum
from typing import List
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
class ExampleType(Enum):
"""Example type"""
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"
class ExampleSelector(BaseModel, ABC):
"""Example selector base class"""
examples_record: List[dict]
use_example: bool = False
type: str = ExampleType.ONE_SHOT.value
def examples(self, count: int = 2):
"""Return examples"""
if ExampleType.ONE_SHOT.value == self.type:
return self.__one_show_context()
return self.__one_shot_context()
else:
return self.__few_shot_context(count)
def __few_shot_context(self, count: int = 2) -> List[dict]:
def __few_shot_context(self, count: int = 2) -> Optional[List[dict]]:
"""
Use 2 or more examples, default 2
Returns: example text
@ -31,14 +38,14 @@ class ExampleSelector(BaseModel, ABC):
return need_use
return None
def __one_show_context(self) -> dict:
def __one_shot_context(self) -> Optional[dict]:
"""
Use one examples
Returns:
"""
if self.use_example:
need_use = self.examples_record[:1]
need_use = self.examples_record[-1]
return need_use
return None
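
Two fixes land here: the misspelled `__one_show_context` becomes `__one_shot_context`, and one-shot now returns the most recent record as a dict (`examples_record[-1]`) rather than a one-element list slice. Illustrative usage with hypothetical records:

```python
selector = ExampleSelector(
    examples_record=[
        {"user": "How many users signed up today?", "sql": "SELECT ..."},
        {"user": "Top 10 products by sales?", "sql": "SELECT ..."},
    ],
    use_example=True,
    type=ExampleType.ONE_SHOT.value,
)
example = selector.examples()  # the last record (a dict), no longer [record]
```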

View File

@ -1,8 +1,12 @@
"""Prompt template registry.
This module is deprecated. we will remove it in the future.
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Optional
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
@ -14,15 +18,15 @@ class PromptTemplateRegistry:
"""
def __init__(self) -> None:
self.registry = defaultdict(dict)
self.registry = defaultdict(dict) # type: ignore
def register(
self,
prompt_template,
language: str = "en",
is_default: bool = False,
model_names: List[str] = None,
scene_name: str = None,
model_names: Optional[List[str]] = None,
scene_name: Optional[str] = None,
) -> None:
"""Register prompt template with scene name, language
registry dict format:
@ -43,7 +47,7 @@ class PromptTemplateRegistry:
if not scene_name:
raise ValueError("Prompt template scene name cannot be empty")
if not model_names:
model_names: List[str] = [_DEFAULT_MODEL_KEY]
model_names = [_DEFAULT_MODEL_KEY]
scene_registry = self.registry[scene_name]
_register_scene_prompt_template(
scene_registry, prompt_template, language, model_names
@ -64,7 +68,7 @@ class PromptTemplateRegistry:
scene_name: str,
language: str,
model_name: str,
proxyllm_backend: str = None,
proxyllm_backend: Optional[str] = None,
):
"""Get prompt template with scene name, language and model name
proxyllm_backend: see CFG.PROXYLLM_BACKEND
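
Although the module is marked deprecated, the signature cleanup (explicit `Optional` defaults instead of bare `None` annotations, plus the `type: ignore` on the defaultdict) keeps it mypy-clean until it is removed. A hypothetical registration round-trip; `my_prompt_template` stands in for a real template object:

```python
registry = PromptTemplateRegistry()
registry.register(
    my_prompt_template,  # assumed to exist
    language="en",
    is_default=True,
    scene_name="chat_with_db_execute",
)
template = registry.get_prompt_template(
    scene_name="chat_with_db_execute",
    language="en",
    model_name="vicuna-13b-v1.5",
)
```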

View File

@ -1,9 +1,10 @@
"""Agentic Workflow Expression Language (AWEL)
"""Agentic Workflow Expression Language (AWEL).
Note:
AWEL is still an experimental feature and only opens the lowest level API.
The stability of this API cannot be guaranteed at present.
Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow
expression language specially designed for large model application development. It
provides great functionality and flexibility. Through the AWEL API, you can focus on
the development of business logic for LLMs applications without paying attention to
cumbersome model and environment details.
"""
@ -71,10 +72,12 @@ __all__ = [
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
"_is_async_iterator",
]
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
"""Initialize AWEL."""
from .dag.base import DAGVar
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner
@ -92,13 +95,13 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
def setup_dev_environment(
dags: List[DAG],
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
host: str = "127.0.0.1",
port: int = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.
"""Run AWEL in development environment.
Just using in development environment, not production environment.
@ -107,9 +110,11 @@ def setup_dev_environment(
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and open it automatically.
logger_filename (Optional[str], optional): The logger filename.
Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph.
Defaults to True. If True, the DAG graph will be saved to a file and open
it automatically.
"""
import uvicorn
from fastapi import FastAPI
@ -138,7 +143,9 @@ def setup_dev_environment(
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
f"Visualize DAG {str(dag)} failed: {e}, if your system has no "
f"graphviz, you can install it by `pip install graphviz` or "
f"`sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
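
A small sketch of the development entrypoint documented above, wiring one HTTP trigger into a DAG (the endpoint, request shape, and `HttpTrigger` arguments are illustrative):

```python
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment


async def say_hello(body) -> str:
    return f"Hello, {body}!"


with DAG("dev_example") as dag:
    trigger = HttpTrigger("/examples/hello")
    task = MapOperator(map_function=say_hello)
    trigger >> task

# Serves on 127.0.0.1:5555; with show_dag_graph=True the graphviz rendering
# of the DAG is written to a file first.
setup_dev_environment([dag], port=5555)
```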

View File

@ -1,7 +1,10 @@
"""Base classes for AWEL."""
from abc import ABC, abstractmethod
class Trigger(ABC):
"""Base class for trigger."""
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -0,0 +1 @@
"""The module of DAGs."""

View File

@ -1,3 +1,7 @@
"""The base module of DAG.
DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import logging
@ -6,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast
from dbgpt.component import SystemApp
@ -27,86 +31,108 @@ def _is_async_context():
class DependencyMixin(ABC):
"""The mixin class for DAGNode.
This class defines the interface for setting upstream and downstream nodes.
And it also implements the operator << and >> for setting upstream
and downstream nodes.
"""
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no upstream nodes are provided or if an argument is
not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no downstream nodes are provided or if an argument is
not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
"""Set upstream nodes for current node.
Implements: self << nodes.
Example:
.. code-block:: python
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
"""Set downstream nodes for current node.
Example:
Implements: self >> nodes.
.. code-block:: python
Examples:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
"""Set upstream nodes for current node.
Implements: [node] >> self
"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
"""Set downstream nodes for current node.
Implements: [node] << self
"""
self.__rshift__(nodes)
return self
class DAGVar:
"""The DAGVar is used to store the current DAG context."""
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
_async_local: contextvars.ContextVar = contextvars.ContextVar(
"current_dag_stack", default=deque()
)
_system_app: Optional[SystemApp] = None
# The executor for current DAG, this is used run some sync tasks in async DAG
_executor: Optional[Executor] = None
@classmethod
def enter_dag(cls, dag) -> None:
"""Enter a DAG context.
Args:
dag (DAG): The DAG to enter
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@ -119,6 +145,7 @@ class DAGVar:
@classmethod
def exit_dag(cls) -> None:
"""Exit a DAG context."""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@ -134,6 +161,11 @@ class DAGVar:
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
"""Get the current DAG.
Returns:
Optional[DAG]: The current DAG
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@ -147,36 +179,56 @@ class DAGVar:
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
def get_current_system_app(cls) -> Optional[SystemApp]:
"""Get the current system app.
Returns:
Optional[SystemApp]: The current system app
"""
# if not cls._system_app:
# raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
"""Set the current system app.
Args:
system_app (SystemApp): The system app to set
"""
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
logger.warning("System APP has already set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
def get_executor(cls) -> Optional[Executor]:
"""Get the current executor.
Returns:
Optional[Executor]: The current executor
"""
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
"""Set the current executor.
Args:
executor (Executor): The executor to set
"""
cls._executor = executor
class DAGLifecycle:
"""The lifecycle of DAG"""
"""The lifecycle of DAG."""
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute before DAG run."""
pass
async def after_dag_end(self):
"""The callback after DAG end,
"""Execute after DAG end.
This method may be called multiple times, please make sure it is idempotent.
"""
@ -184,6 +236,8 @@ class DAGLifecycle:
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
"""The base class of DAGNode."""
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
executor: Optional[Executor] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
Args:
dag (Optional["DAG"], optional): The DAG to add this node to.
Defaults to None.
node_id (Optional[str], optional): The node id. Defaults to None.
node_name (Optional[str], optional): The node name. Defaults to None.
system_app (Optional[SystemApp], optional): The system app.
Defaults to None.
executor (Optional[Executor], optional): The executor. Defaults to None.
"""
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
@property
def node_id(self) -> str:
"""Return the node id of current DAGNode."""
if not self._node_id:
raise ValueError("Node id not set for current DAGNode")
return self._node_id
@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""
"""Whether current DAGNode is in dev mode."""
@property
def system_app(self) -> SystemApp:
def system_app(self) -> Optional[SystemApp]:
"""Return the system app of current DAGNode."""
return self._system_app
def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
"""Set system app for current DAGNode.
Args:
system_app (SystemApp): The system app
@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._system_app = system_app
def set_node_id(self, node_id: str) -> None:
"""Set node id for current DAGNode.
Args:
node_id (str): The node id
"""
self._node_id = node_id
def __hash__(self) -> int:
"""Return the hash value of current DAGNode.
If the node_id is not None, return the hash value of node_id.
"""
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
"""Return whether the current DAGNode is equal to other DAGNode."""
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
def node_name(self) -> Optional[str]:
"""Return the node name of current DAGNode.
Returns:
Optional[str]: The node name of current DAGNode
"""
return self._node_name
@property
def dag(self) -> "DAG":
def dag(self) -> Optional["DAG"]:
"""Return the DAG of current DAGNode.
Returns:
Optional["DAG"]: The DAG of current DAGNode
"""
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set upstream nodes for current node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
"""
self.set_dependency(nodes)
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set downstream nodes for current node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
"""
self.set_dependency(nodes, is_upstream=False)
@property
def upstream(self) -> List["DAGNode"]:
"""Return the upstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The upstream nodes of current DAGNode
"""
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
"""Return the downstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The downstream nodes of current DAGNode
"""
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
"""Set dependency for current node.
Args:
nodes (DependencyType): The nodes to set dependency to current node.
is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.
"""
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
"all nodes to set dependency to current node must be instance "
"of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
nodes = cast(Sequence[DAGNode], nodes)
dags = set([node.dag for node in nodes if node.dag]) # noqa: C403
if self.dag:
dags.add(self.dag)
if not dags:
@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
node._upstream.append(self)
def __repr__(self):
"""Return the representation of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_name and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
return f"{cls_name}"
def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()
@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:
class DAGContext:
"""The context of current DAG, created when the DAG is running
"""The context of current DAG, created when the DAG is running.
Every DAG has been triggered will create a new DAGContext.
"""
@ -329,22 +447,32 @@ class DAGContext:
def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Dict[str, TaskContext] = None,
node_name_to_ids: Dict[str, str] = None,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
"""Return the task outputs of current DAG.
Just use for internal for now.
Returns:
@ -354,18 +482,28 @@ class DAGContext:
@property
def current_task_context(self) -> TaskContext:
"""Return the current task context."""
if not self._curr_task_ctx:
raise RuntimeError("Current task context not set")
return self._curr_task_ctx
@property
def streaming_call(self) -> bool:
"""Whether the current DAG is streaming call"""
"""Whether the current DAG is streaming call."""
return self._streaming_call
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
"""Set the current task context.
When the task is running, the current task context
will be set to the task context.
TODO: We should support parallel task running in the future.
"""
self._curr_task_ctx = _curr_task_ctx
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
"""Get the task output by task name.
Args:
task_name (str): The task name
@ -376,22 +514,41 @@ class DAGContext:
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if node_id:
if not node_id:
raise ValueError(f"Task name {task_name} not exists in DAG")
return self._task_outputs.get(node_id).task_output
task_output = self._task_outputs.get(node_id)
if not task_output:
raise ValueError(f"Task output for task {task_name} not exists")
return task_output.task_output
async def get_from_share_data(self, key: str) -> Any:
"""Get share data by key.
Args:
key (str): The share data key
Returns:
Any: The share data, you can cast it to the real type
"""
return self._share_data.get(key)
async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by key.
Args:
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether overwrite the share data if the key
already exists. Defaults to None.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
"""Get share data by task name and key.
Args:
task_name (str): The task name
@ -409,14 +566,14 @@ class DAGContext:
async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by task name and key
"""Save share data by task name and key.
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether overwrite the share data if the key already exists.
Defaults to None.
overwrite (bool): Whether overwrite the share data if the key
already exists. Defaults to None.
Raises:
ValueError: If the share data key already exists and overwrite is not True
@ -429,15 +586,22 @@ class DAGContext:
class DAG:
"""The DAG class.
Manage the DAG nodes and the relationship between them.
"""
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None
self._root_nodes: List[DAGNode] = []
self._leaf_nodes: List[DAGNode] = []
self._trigger_nodes: List[DAGNode] = []
self._resource_group: Optional[ResourceGroup] = resource_group
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
@ -448,22 +612,26 @@ class DAG:
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
node_id = node.node_id
if not node_id:
raise ValueError("Node id can't be None")
self.node_map[node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
self._root_nodes = []
self._leaf_nodes = []
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
"""Return the dag id of current DAG."""
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
nodes: Set[DAGNode] = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
@ -474,7 +642,7 @@ class DAG:
@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
"""Return the root nodes of current DAG.
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
@ -485,7 +653,7 @@ class DAG:
@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
"""Return the leaf nodes of current DAG.
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
@ -496,7 +664,7 @@ class DAG:
@property
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
"""Return the trigger nodes of current DAG.
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
@ -506,34 +674,42 @@ class DAG:
return self._trigger_nodes
async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
"""Execute after DAG end."""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
"""Print the DAG tree"""
"""Print the DAG tree""" # noqa: D400
_print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
"""Visualize the DAG.
Args:
view (bool, optional): Whether view the DAG graph. Defaults to True,
if True, it will open the graph file with your default viewer.
"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit a DAG context."""
DAGVar.exit_dag()
def __repr__(self):
"""Return the representation of current DAG."""
return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
nodes: Set[DAGNode] = set()
if not node:
return nodes
nodes.add(node)
@ -553,7 +729,7 @@ def _print_dag(
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Dict[str, Any] = None,
level_dict: Optional[Dict[int, Any]] = None,
):
if level_dict is None:
level_dict = {}
@ -606,7 +782,7 @@ def _handle_dag_nodes(
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
filename = kwargs["filename"]
del kwargs["filename"]
if not "directory" in kwargs:
if "directory" not in kwargs:
from dbgpt.configs.model_config import LOGDIR
kwargs["directory"] = LOGDIR

View File

@ -1,3 +1,8 @@
"""DAGManager is a component of AWEL, it is used to manage DAGs.
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
to TriggerManager.
"""
import logging
from typing import Dict, List
@ -10,24 +15,35 @@ logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
"""The component of DAGManager."""
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
"""Initialize a DAGManager.
Args:
system_app (SystemApp): The system app.
dag_dirs (List[str]): The directories to load DAGs.
"""
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
def init_app(self, system_app: SystemApp):
"""Initialize the DAGManager."""
self.system_app = system_app
def load_dags(self):
"""Load DAGs from dag_dirs."""
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
self.dag_map[dag_id] = dag
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager
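
How the manager is wired at startup, sketched under the assumption that the import path matches this file's location and with an illustrative DAG directory:

```python
from dbgpt.component import SystemApp
from dbgpt.core.awel.dag.dag_manager import DAGManager

system_app = SystemApp()
dag_manager = DAGManager(system_app, dag_dirs=["/path/to/awel/dags"])
system_app.register_instance(dag_manager)

# Loads every DAG found in dag_dirs, raising on duplicate DAG ids, and hands
# the collected trigger nodes to the DefaultTriggerManager.
dag_manager.load_dags()
```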

View File

@ -1,3 +1,8 @@
"""DAG loader.
DAGLoader will load DAGs from dag_dirs or other sources.
Now only support load DAGs from local files.
"""
import hashlib
import logging
import os
@ -12,16 +17,26 @@ logger = logging.getLogger(__name__)
class DAGLoader(ABC):
"""Abstract base class representing a loader for loading DAGs."""
@abstractmethod
def load_dags(self) -> List[DAG]:
"""Load dags"""
"""Load dags."""
class LocalFileDAGLoader(DAGLoader):
"""DAG loader for loading DAGs from local files."""
def __init__(self, dag_dirs: List[str]) -> None:
"""Initialize a LocalFileDAGLoader.
Args:
dag_dirs (List[str]): The directories to load DAGs.
"""
self._dag_dirs = dag_dirs
def load_dags(self) -> List[DAG]:
"""Load dags from local files."""
dags = []
for filepath in self._dag_dirs:
if not os.path.exists(filepath):
@ -70,7 +85,7 @@ def _load_modules_from_file(filepath: str):
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message

View File

@ -0,0 +1 @@
"""The module of operator."""

View File

@ -1,7 +1,7 @@
"""Base classes for operators that can be executed within a workflow."""
import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
from inspect import signature
from types import FunctionType
from typing import (
Any,
@ -9,7 +9,6 @@ from typing import (
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
@ -21,7 +20,6 @@ from dbgpt.util.executor_utils import (
AsyncToSyncIterator,
BlockingFunction,
DefaultExecutorFactory,
ExecutorFactory,
blocking_func_to_async,
)
@ -54,13 +52,15 @@ class WorkflowRunner(ABC, Generic[T]):
node (RunnableDAGNode): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data pass to root operator node.
streaming_call (bool): Whether the call is a streaming call.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
DAGContext: The context after executing the workflow, containing the final
state and data.
"""
default_runner: WorkflowRunner = None
default_runner: Optional[WorkflowRunner] = None
class BaseOperatorMeta(ABCMeta):
@ -68,8 +68,7 @@ class BaseOperatorMeta(ABCMeta):
@classmethod
def _apply_defaults(cls, func: F) -> F:
sig_cache = signature(func)
# sig_cache = signature(func)
@functools.wraps(func)
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
@ -81,7 +80,7 @@ class BaseOperatorMeta(ABCMeta):
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
@ -107,9 +106,10 @@ class BaseOperatorMeta(ABCMeta):
real_obj = func(self, *args, **kwargs)
return real_obj
return cast(T, apply_defaults)
return cast(F, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
"""Create a new BaseOperator class with default arguments."""
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
return new_cls
@ -126,13 +126,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
task_id: Optional[str] = None,
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
runner: Optional[WorkflowRunner] = None,
**kwargs,
) -> None:
"""Initializes a BaseOperator with an optional workflow runner.
"""Create a BaseOperator with an optional workflow runner.
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
runner (WorkflowRunner, optional): The runner used to execute the workflow.
Defaults to None.
"""
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
if not runner:
@ -141,19 +142,24 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
runner = DefaultWorkflowRunner()
self._runner: WorkflowRunner = runner
self._dag_ctx: DAGContext = None
self._dag_ctx: Optional[DAGContext] = None
@property
def current_dag_context(self) -> DAGContext:
"""Return the current DAG context."""
if not self._dag_ctx:
raise ValueError("DAGContext is not set")
return self._dag_ctx
@property
def dev_mode(self) -> bool:
"""Whether the operator is in dev mode.
In production mode, the default runner is not None.
Returns:
bool: Whether the operator is in dev mode. True if the default runner is None.
bool: Whether the operator is in dev mode. True if the
default runner is None.
"""
return default_runner is None
@ -186,7 +192,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
Returns:
OUT: The output of the node after execution.
"""
@ -196,7 +203,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return out_ctx.current_task_context.task_output.output
def _blocking_call(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
self,
call_data: Optional[CALL_DATA] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT:
"""Execute the node and return the output.
@ -213,6 +222,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop:
loop = get_or_create_event_loop()
loop = cast(asyncio.BaseEventLoop, loop)
return loop.run_until_complete(self.call(call_data))
async def call_stream(
@ -226,7 +236,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
dag_ctx (DAGContext): The context of the DAG when this node is run,
Defaults to None.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
@ -237,7 +248,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return out_ctx.current_task_context.task_output.output_stream
def _blocking_call_stream(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
self,
call_data: Optional[CALL_DATA] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
@ -259,9 +272,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
"""Execute a blocking function asynchronously.
In AWEL, the operators are executed asynchronously. However,
some functions are blocking, we run them in a separate thread.
Args:
func (BlockingFunction): The blocking function to be executed.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
"""
if not self._executor:
raise ValueError("Executor is not set")
return await blocking_func_to_async(self._executor, func, *args, **kwargs)
def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""
global default_runner
default_runner = runner

View File

@ -1,7 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
@ -13,7 +13,17 @@ from typing import (
)
from ..dag.base import DAGContext
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
from ..task.base import (
IN,
OUT,
InputContext,
InputSource,
JoinFunc,
MapFunc,
ReduceFunc,
TaskContext,
TaskOutput,
)
from .base import BaseOperator
logger = logging.getLogger(__name__)
@ -25,7 +35,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function, **kwargs):
def __init__(self, combine_function: JoinFunc, **kwargs):
"""Create a JoinDAGNode with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
"""
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
@ -33,6 +48,7 @@ class JoinOperator(BaseOperator, Generic[OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
@ -50,8 +66,10 @@ class JoinOperator(BaseOperator, Generic[OUT]):
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
def __init__(self, reduce_function=None, **kwargs):
"""Initializes a ReduceStreamOperator with a combine function.
"""Operator that reduces inputs using a custom reduce function."""
def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
"""Create a ReduceStreamOperator with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
@ -89,6 +107,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
return reduce_output
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
"""Reduce the input stream to a single value."""
raise NotImplementedError
@ -99,8 +118,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
passes the transformed data downstream.
"""
def __init__(self, map_function=None, **kwargs):
"""Initializes a MapDAGNode with a mapping function.
def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
"""Create a MapDAGNode with a mapping function.
Args:
map_function: A function that defines how to map the input data.
@ -133,13 +152,18 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
if not call_data and not curr_task_ctx.task_input.check_single_parent():
num_parents = len(curr_task_ctx.task_input.parent_outputs)
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent,"
f"now number of parents: {num_parents}"
)
map_function = self.map_function or self.map
if call_data:
call_data = await curr_task_ctx._call_data_to_output()
output = await call_data.map(map_function)
wrapped_call_data = await curr_task_ctx._call_data_to_output()
if not wrapped_call_data:
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
)
output: TaskOutput[OUT] = await wrapped_call_data.map(map_function)
curr_task_ctx.set_task_output(output)
return output
@ -150,6 +174,7 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
return output
async def map(self, input_value: IN) -> OUT:
"""Map the input data to a new value."""
raise NotImplementedError
@ -161,6 +186,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
If a branch function returns True, the corresponding task will be executed.
otherwise, the corresponding task will be skipped, and the output of
this skip node will be set to `SKIP_DATA`
"""
def __init__(
@ -168,11 +198,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
**kwargs,
):
"""
Initializes a BranchDAGNode with a branching function.
"""Create a BranchDAGNode with a branching function.
Args:
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]):
Dict of function that defines the branching condition.
Raises:
ValueError: If the branch_function is not callable.
@ -183,7 +213,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
if not callable(branch_function):
raise ValueError("branch_function must be callable")
if isinstance(value, BaseOperator):
branches[branch_function] = value.node_name or value.node_name
if not value.node_name:
raise ValueError("branch node name must be set")
branches[branch_function] = value.node_name
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@ -210,7 +242,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches = await self.branches()
branch_func_tasks = []
branch_nodes: List[str] = []
branch_nodes: List[Union[BaseOperator, str]] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
branch_func_tasks.append(
@ -225,20 +257,25 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
node_name = branch_nodes[i]
branch_out = ctx.parent_outputs[0].task_output
logger.info(
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
f"branch_input_ctxs {i} result {branch_out.output}, "
f"is_empty: {branch_out.is_empty}"
)
if ctx.parent_outputs[0].task_output.is_empty:
if ctx.parent_outputs[0].task_output.is_none:
logger.info(f"Skip node name {node_name}")
skip_node_names.append(node_name)
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
"""Return branch logic based on input data."""
raise NotImplementedError
class InputOperator(BaseOperator, Generic[OUT]):
"""Operator node that reads data from an input source."""
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
"""Create an InputDAGNode with an input source."""
super().__init__(**kwargs)
self._input_source = input_source
@ -250,7 +287,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
class TriggerOperator(InputOperator, Generic[OUT]):
"""Operator node that triggers the DAG to run."""
def __init__(self, **kwargs) -> None:
"""Create a TriggerDAGNode."""
from ..task.task_impl import SimpleCallDataInputSource
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
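
A hypothetical pair of operators built on the classes above: `MapOperator` subclassed through its `map` hook, and `JoinOperator` combining two upstream outputs (export path assumed):

```python
from dbgpt.core.awel import DAG, JoinOperator, MapOperator


class DoubleOperator(MapOperator[int, int]):
    """Double every value that flows through this node."""

    async def map(self, input_value: int) -> int:
        return input_value * 2


with DAG("join_example"):
    double_a = DoubleOperator(task_name="double_a")
    double_b = DoubleOperator(task_name="double_b")
    total = JoinOperator(combine_function=lambda a, b: a + b, task_name="total")
    double_a >> total
    double_b >> total
```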

View File

@ -1,3 +1,4 @@
"""The module of stream operator."""
from abc import ABC, abstractmethod
from typing import AsyncIterator, Generic
@ -7,12 +8,18 @@ from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
call_data = curr_task_ctx.call_data
if call_data:
call_data = await curr_task_ctx._call_data_to_output()
output = await call_data.streamify(self.streamify)
wrapped_call_data = await curr_task_ctx._call_data_to_output()
if not wrapped_call_data:
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
)
output = await wrapped_call_data.streamify(self.streamify)
curr_task_ctx.set_task_output(output)
return output
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
@ -23,26 +30,28 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
@abstractmethod
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
"""Convert a value of IN to an AsyncIterator[OUT]
"""Convert a value of IN to an AsyncIterator[OUT].
Args:
input_value (IN): The data of parent operator's output
Example:
Examples:
.. code-block:: python
.. code-block:: python
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
"""
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
"""An abstract operator that converts a value of AsyncIterator[IN] to an OUT."""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.unstreamify(self.unstreamify)
curr_task_ctx.set_task_output(output)
@ -56,24 +65,30 @@ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
input_value (AsyncIterator[IN])): The data of parent operator's output
Example:
.. code-block:: python
.. code-block:: python
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
"""
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
"""Streaming to other streaming data.
An abstract operator that transforms a value of
AsyncIterator[IN] to another AsyncIterator[OUT].
"""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.transform_stream(self.transform_stream)
curr_task_ctx.set_task_output(output)
return output
@ -81,19 +96,18 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def transform_stream(
self, input_value: AsyncIterator[IN]
) -> AsyncIterator[OUT]:
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT].
Args:
input_value (AsyncIterator[IN]): The data of parent operator's output
Example:
Examples:
.. code-block:: python
.. code-block:: python
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""

View File

@ -0,0 +1,4 @@
"""The module of AWEL resource.
Not implemented yet.
"""

View File

@ -1,8 +1,15 @@
"""Base class for resource group."""
from abc import ABC, abstractmethod
class ResourceGroup(ABC):
"""Base class for resource group.
A resource group is a group of resources that are related to each other.
It contains all the resources needed to run a workflow.
"""
@property
@abstractmethod
def name(self) -> str:
"""The name of current resource group"""
"""Return the name of current resource group."""

View File

@ -0,0 +1,4 @@
"""The module to run AWEL operators.
You can implement your own runner by inheriting the `WorkflowRunner` class.
"""

View File

@ -1,33 +1,38 @@
"""Job manager for DAG."""
import asyncio
import logging
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, cast
from ..dag.base import DAG, DAGLifecycle
from ..dag.base import DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager(DAGLifecycle):
"""Job manager for DAG.
This class is used to manage the DAG lifecycle.
"""
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
id2call_data: Dict[str, Optional[Dict]],
node_name_to_ids: Dict[str, str],
) -> None:
"""Create a job manager.
Args:
root_nodes (List[BaseOperator]): The root nodes of the DAG.
all_nodes (List[BaseOperator]): All nodes of the DAG.
end_node (BaseOperator): The end node of the DAG.
id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
node_name_to_ids (Dict[str, str]): The node name to node id mapping.
"""
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
"""Build a job manager from the end node.
This will get all upstream nodes from the end node, and build a job manager.
Args:
end_node (BaseOperator): The end node of the DAG.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
"""
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
"""Get the call data by node id.
Args:
node_id (str): The node id.
"""
return self._id2node_data.get(node_id)
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute the callback before DAG run."""
tasks = []
for node in self._all_nodes:
tasks.append(node.before_dag_run())
await asyncio.gather(*tasks)
async def after_dag_end(self):
"""The callback after DAG end"""
"""Execute the callback after DAG end."""
tasks = []
for node in self._all_nodes:
tasks.append(node.after_dag_end())
@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
) -> Dict[str, Optional[Dict]]:
id2call_data: Dict[str, Optional[Dict]] = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
@ -82,7 +101,8 @@ def _save_call_data(
for node in root_nodes:
node_id = node.node_id
logger.debug(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
f"Save call data to node {node.node_id}, call_data: "
f"{call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
@ -91,13 +111,11 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
"""Build all nodes from the end node."""
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
if isinstance(end_node, BaseOperator) and not end_node._node_id:
end_node.set_node_id(str(uuid.uuid4()))
nodes.append(end_node)
for node in end_node.upstream:
node = cast(BaseOperator, node)
nodes += _build_from_end_node(node)
return nodes
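The intended call sequence, sketched with only the names that appear above (end_node and root_node are placeholders for operators of an existing DAG):

async def drive(end_node, root_node):
    # Hypothetical driver code; a real runner performs these steps in order.
    job_manager = JobManager.build_from_end_node(end_node, call_data={"data": 42})
    await job_manager.before_dag_run()  # fan out DAGLifecycle hooks to all nodes
    root_call_data = job_manager.get_call_data_by_id(root_node.node_id)
    # ... execute the nodes from the roots down to end_node ...
    await job_manager.after_dag_end()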

View File

@ -1,12 +1,16 @@
"""Local runner for workflow.
This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
"""The default workflow runner."""
async def execute_workflow(
self,
node: BaseOperator,
@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
"""Execute the workflow.
Args:
node (BaseOperator): The end node of the workflow.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
streaming_call (bool, optional): Whether the call is a streaming call.
Defaults to False.
exist_dag_ctx (Optional[DAGContext], optional): The existing DAG context.
Defaults to None.
"""
# Save node output
# dag = node.dag
job_manager = JobManager.build_from_end_node(node, call_data)
@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
)
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
skip_node_ids: Set[str] = set()
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
await job_manager.before_dag_run()
await self._execute_node(
@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
system_app: SystemApp,
system_app: Optional[SystemApp],
):
# Skip nodes that have already been executed
if node.node_id in node_outputs:
@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx: DefaultTaskContext = DefaultTaskContext(
node.node_id, TaskState.INIT, task_output=None
)
current_call_data = job_manager.get_call_data_by_id(node.node_id)
if current_call_data:
task_ctx.set_call_data(current_call_data)
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
node_outputs[node.node_id] = task_ctx
return
try:
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
f"Begin run operator, node id: {node.node_id}, node name: "
f"{node.node_name}, cls: {node}"
)
if system_app is not None and node.system_app is None:
node.set_system_app(system_app)
@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
if not skip_nodes:
return
for child in branch_node.downstream:
child = cast(BaseOperator, child)
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
child = cast(BaseOperator, child)
_skip_downstream_by_id(child, skip_node_ids)
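Usage mirrors the operator tests later in this change:

async def smoke_test():
    runner = DefaultWorkflowRunner()
    input_node = InputOperator(SimpleInputSource("hello"), task_id="demo")
    dag_ctx = await runner.execute_workflow(input_node)
    assert dag_ctx.current_task_context.current_state == TaskState.SUCCESS
    assert dag_ctx.current_task_context.task_output.output == "hello"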

View File

@ -0,0 +1 @@
"""The module of Task."""

View File

@ -1,8 +1,10 @@
"""Base classes for task-related objects."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
T = TypeVar("T")
class _EMPTY_DATA_TYPE:
def __bool__(self):
return False
EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]
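Why a dedicated sentinel type rather than None: None, 0 and "" are all legitimate task outputs, so emptiness needs its own object. A self-contained illustration of the pattern (standard library only, not DB-GPT code):

class _Sentinel:
    def __bool__(self):
        return False  # falsy, like the absence it represents


EMPTY = _Sentinel()


def output_of(slot):
    # Only the sentinel object itself means "nothing was produced".
    if slot is EMPTY:
        raise ValueError("no output data")
    return slot


print(output_of(0))     # 0 is real data
print(output_of(None))  # None is real data too
# output_of(EMPTY) would raise ValueError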
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
@ -33,8 +53,8 @@ class TaskState(str, Enum):
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
This class encapsulates the output of a task and provides methods to access the
output data. It can be subclassed to implement specific output behaviors.
"""
@property
@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
return False
@property
def output(self) -> Optional[T]:
def is_none(self) -> bool:
"""Check if the output is None.
Returns:
bool: True if the output is None, False otherwise.
"""
return False
@property
def output(self) -> T:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
T: The output of the task.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
def output_stream(self) -> AsyncIterator[T]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
AsyncIterator[T]: An asynchronous iterator over the output. A ValueError
is raised if the output is empty.
"""
raise NotImplementedError
@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
"""Create new output object."""
async def map(self, map_func) -> "TaskOutput[T]":
async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
TaskOutput[OUT]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
"""Apply a reducing function to the task's output.
Stream TaskOutput to Nonstream TaskOutput.
Convert a stream TaskOutput to a non-stream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
TaskOutput[OUT]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
transform_func (StreamFunc): Function to transform a T value into an
AsyncIterator[OUT].
Returns:
TaskOutput[T]: The result of applying the transform function.
@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
self, transform_func: TransformFunc
) -> "TaskOutput[OUT]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
into a T value.
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if current output meet the condition, False otherwise.
TaskOutput[T]: The current output if the condition is met,
an empty output otherwise.
"""
raise NotImplementedError
@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):
Returns:
InputContext: The InputContext of current task.
Raises:
Exception: If the InputContext is not set.
"""
@abstractmethod
@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
"""Set current task state.
Args:
task_state (TaskState): The task state to be set.
@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
"""Create new task context.
Returns:
TaskContext: A new instance of a TaskContext.
@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
"""Update metadata with key and value.
Args:
key (str): The key of metadata
@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
"""Return the call data for current data."""
return self.metadata.get("call_data")
@abstractmethod
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Get the call data for current data."""
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
"""Save the call data for current task."""
self.update_metadata("call_data", call_data)
@ -315,7 +348,8 @@ class InputContext(ABC):
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
filter_func (Callable[[Any], bool]): A function that returns True for
inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
@ -323,13 +357,15 @@ class InputContext(ABC):
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate function is False
predicate_func (PredicateFunc): A function that returns True for inputs
that satisfy the predicate.
failed_value (Any): The value to set when the predicate function returns
False.
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""

View File

@ -1,3 +1,7 @@
"""The default implementation of Task.
This implementation can run workflows on the local machine.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
@ -8,15 +12,32 @@ from typing import (
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
from .base import (
_EMPTY_DATA_TYPE,
EMPTY_DATA,
OUT,
PLACEHOLDER_DATA,
SKIP_DATA,
InputContext,
InputSource,
MapFunc,
PredicateFunc,
ReduceFunc,
StreamFunc,
T,
TaskContext,
TaskOutput,
TaskState,
TransformFunc,
UnStreamFunc,
)
logger = logging.getLogger(__name__)
@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
"""The default implementation of TaskOutput.
It wraps non-stream data and provides some basic data operations.
"""
def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
"""Create a SimpleTaskOutput.
Args:
data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
"""Return the output data."""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(T, self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Args:
output_data (T | AsyncIterator[T]): The output data.
"""
if _is_async_iterator(output_data):
raise ValueError(
f"Can not set stream data {output_data} to SimpleTaskOutput"
)
self._data = cast(T, output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
"""Create new output object with empty data."""
return SimpleTaskOutput()
@property
def is_empty(self) -> bool:
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
async def _apply_func(self, func) -> Any:
"""Apply the function to current output data."""
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output.
Args:
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[OUT]: The result of applying the mapping function.
"""
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def check_condition(self, condition_func) -> TaskOutput[OUT]:
"""Check the condition function."""
out = await self._apply_func(condition_func)
if out:
return SimpleTaskOutput(PLACEHOLDER_DATA)
return SimpleTaskOutput(EMPTY_DATA)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a stream output.
Args:
transform_func (StreamFunc): A function to transform the task's output to a
stream output.
Returns:
TaskOutput[OUT]: The result of transforming the task's output to a stream
output.
"""
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
"""The default stream implementation of TaskOutput."""
def __init__(
self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
) -> None:
"""Create a SimpleStreamTaskOutput.
Args:
data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
Defaults to EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
"""Return True if the output data is a stream."""
return True
@property
def is_empty(self) -> bool:
return not self._data
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
"""Return the output data.
Returns:
AsyncIterator[T]: The output data.
Raises:
ValueError: If the output data is empty.
"""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(AsyncIterator[T], self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Raises:
ValueError: If the output data is not a stream.
"""
if not _is_async_iterator(output_data):
raise ValueError(
f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
)
self._data = cast(AsyncIterator[T], output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
"""Create new output object with empty data."""
return SimpleStreamTaskOutput()
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output."""
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
async def new_iter() -> AsyncIterator[OUT]:
async for out in self.output_stream:
if is_async:
out = await map_func(out)
new_out: OUT = await map_func(out)
else:
out = map_func(out)
yield out
new_out = cast(OUT, map_func(out))
yield new_out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
"""Apply a reduce function to the task's output."""
out = await _reduce_stream(self.output_stream, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a non-stream output."""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = transform_func(self.output_stream)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out: AsyncIterator[OUT] = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
return SimpleStreamTaskOutput(out)
@ -145,20 +262,34 @@ def _is_async_iterator(obj):
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
"""Return data with task context."""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
"""Read data with task context.
Args:
task_ctx (TaskContext): The task context.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
output: TaskOutput = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
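A subclass only has to provide _read_data. A minimal sketch, with names invented for illustration:

from typing import Any


class StaticInputSource(BaseInputSource):
    """Hypothetical source that always yields the same payload."""

    def __init__(self, payload: Any) -> None:
        super().__init__()
        self._payload = payload

    def _read_data(self, task_ctx: TaskContext) -> Any:
        # Non-stream data, so read() wraps it in a SimpleTaskOutput.
        return self._payload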
@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
self._data = data
@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):
class SimpleCallDataInputSource(BaseInputSource):
"""The implementation of InputSource for call data."""
def __init__(self) -> None:
"""Create a SimpleCallDataInputSource."""
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data from task context.
Returns:
Any: The data.
Raises:
ValueError: If the call data is empty.
"""
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
if data == EMPTY_DATA:
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
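A sketch of how call data reaches this source at run time; the wiring matches the HTTP trigger later in this change, which calls the end node with call_data={"data": body}:

async def demo():
    node = InputOperator(SimpleCallDataInputSource(), task_id="call_data_input")
    # _read_data() will see task_ctx.call_data == {"data": "hello"}
    return await node.call(call_data={"data": "hello"})  # -> "hello"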
class DefaultTaskContext(TaskContext, Generic[T]):
"""The default implementation of TaskContext."""
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
self,
task_id: str,
task_state: TaskState,
task_output: Optional[TaskOutput[T]] = None,
) -> None:
"""Create a DefaultTaskContext.
Args:
task_id (str): The task id.
task_state (TaskState): The task state.
task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
to None.
"""
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
self._output: Optional[TaskOutput[T]] = task_output
self._task_input: Optional[InputContext] = None
self._metadata: Dict[str, Any] = {}
@property
def task_id(self) -> str:
"""Return the task id."""
return self._task_id
@property
def task_input(self) -> InputContext:
"""Return the task input."""
if not self._task_input:
raise ValueError("No input for current task context")
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
def set_task_input(self, input_ctx: InputContext) -> None:
"""Save the task input to current task."""
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
"""Return the task output.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the task output is empty.
"""
if not self._output:
raise ValueError("No output for current task context")
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
"""Save the task output to current task.
Args:
task_output (TaskOutput): The task output.
"""
self._output = task_output
@property
def current_state(self) -> TaskState:
"""Return the current task state."""
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
"""Save the current task state to current task."""
self._task_state = task_state
def new_ctx(self) -> TaskContext:
"""Create new task context with empty output."""
if not self._output:
raise ValueError("No output for current task context")
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata.
"""
return self._metadata
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Return the call data of current task.
Returns:
Optional[TaskOutput[T]]: The call data.
"""
call_data = self.call_data
if not call_data:
return None
@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):
class DefaultInputContext(InputContext):
"""The default implementation of InputContext.
It wraps all inputs from parent tasks and provides some basic data operations.
"""
def __init__(self, outputs: List[TaskContext]) -> None:
"""Create a DefaultInputContext.
Args:
outputs (List[TaskContext]): The outputs from parent tasks.
"""
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
"""Return the outputs from parent tasks.
Returns:
List[TaskContext]: The outputs from parent tasks.
"""
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
"""Apply the function to all parent outputs.
Args:
func (Callable[[Any], Any]): The function to apply.
apply_type (str, optional): The apply type. Defaults to "map".
Returns:
Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
results of applying the function.
"""
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
"""Apply a mapping function to all parent outputs."""
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
"""Apply a mapping function to all parent outputs.
The parent outputs will be unpacked and passed to the mapping function.
Args:
map_func (Callable[..., Any]): The mapping function.
Returns:
InputContext: The new input context.
"""
if not self._outputs:
return DefaultInputContext([])
# Some parent may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
# Skip empty parent
continue
not_empty_idx = i
break
# Are all outputs empty?
is_stream = self._outputs[not_empty_idx].task_output.is_stream
if is_stream:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must have the same output format to map_all"
)
if is_stream and not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must have the same output format to map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
f"Current map_all map_res: {map_res}, is steam: "
f"{single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
"""Apply a reduce function to all parent outputs."""
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
"The output in all tasks must has same output format of stream to apply"
" reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
"""Filter all parent outputs."""
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Apply a predicate function to all parent outputs."""
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx = cast(TaskContext, task_ctx)
if not results[i].is_empty:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:

View File

@ -66,10 +66,10 @@ async def _create_input_node(**kwargs):
else:
outputs = kwargs.get("outputs", ["Hello."])
nodes = []
for output in outputs:
for i, output in enumerate(outputs):
print(f"output: {output}")
input_source = SimpleInputSource(output)
input_node = InputOperator(input_source)
input_node = InputOperator(input_source, task_id="input_node_" + str(i))
nodes.append(input_node)
yield nodes

View File

@ -26,7 +26,7 @@ from .conftest import (
@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
input_node = InputOperator(SimpleInputSource("hello"))
input_node = InputOperator(SimpleInputSource("hello"), task_id="112232")
res: DAGContext[str] = await runner.execute_workflow(input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert res.current_task_context.task_output.output == "hello"
@ -36,7 +36,9 @@ async def test_input_node(runner: WorkflowRunner):
yield i
num_iter = 10
steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
steam_input_node = InputOperator(
SimpleInputSource(new_steam_iter(num_iter)), task_id="112232"
)
res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
output_steam = res.current_task_context.task_output.output_stream

View File

@ -0,0 +1 @@
"""The trigger module of AWEL."""

View File

@ -1,3 +1,4 @@
"""Base class for all trigger classes."""
from __future__ import annotations
from abc import ABC, abstractmethod
@ -6,6 +7,11 @@ from ..operator.common_operator import TriggerOperator
class Trigger(TriggerOperator, ABC):
"""Base class for all trigger classes.
Only the HTTP trigger is supported for now.
"""
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -1,10 +1,11 @@
"""Http trigger for AWEL."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
@ -13,29 +14,35 @@ from ..operator.base import BaseOperator
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter
RequestBody = Union[Type[Request], Type[BaseModel], str]
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
"""Http trigger for AWEL.
Http trigger is used to trigger a DAG by an HTTP request.
"""
def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
streaming_response: bool = False,
streaming_predict_func: Optional[StreamingPredictFunc] = None,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
status_code: Optional[int] = 200,
router_tags: Optional[List[str]] = None,
router_tags: Optional[List[str | Enum]] = None,
**kwargs,
) -> None:
"""Initialize a HttpTrigger."""
super().__init__(**kwargs)
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
@ -49,15 +56,21 @@ class HttpTrigger(Trigger):
self._router_tags = router_tags
self._response_headers = response_headers
self._response_media_type = response_media_type
self._end_node: BaseOperator = None
self._end_node: Optional[BaseOperator] = None
async def trigger(self) -> None:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
def mount_to_router(self, router: "APIRouter") -> None:
"""Mount the trigger to a router.
Args:
router (APIRouter): The router to mount the trigger.
"""
from fastapi import Depends
methods = self._methods if isinstance(self._methods, list) else [self._methods]
methods = [self._methods] if isinstance(self._methods, str) else self._methods
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
@ -87,7 +100,8 @@ class HttpTrigger(Trigger):
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
f"mount router function {dynamic_route_function}({function_name}), "
f"endpoint: {self._endpoint}, methods: {methods}"
)
router.api_route(
@ -100,17 +114,27 @@ class HttpTrigger(Trigger):
async def _parse_request_body(
request: Request, request_body_cls: Optional[Type[BaseModel]]
request: Request, request_body_cls: Optional[RequestBody]
):
if not request_body_cls:
return None
if request.method == "POST":
json_data = await request.json()
return request_body_cls(**json_data)
elif request.method == "GET":
return request_body_cls(**request.query_params)
else:
if request_body_cls == Request:
return request
if request.method == "POST":
if request_body_cls == str:
bytes_body = await request.body()
str_body = bytes_body.decode("utf-8")
return str_body
elif issubclass(request_body_cls, BaseModel):
json_data = await request.json()
return request_body_cls(**json_data)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
elif request.method == "GET":
if issubclass(request_body_cls, BaseModel):
return request_body_cls(**request.query_params)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
async def _trigger_dag(
@ -123,10 +147,10 @@ async def _trigger_dag(
from fastapi import BackgroundTasks
from fastapi.responses import StreamingResponse
end_node = dag.leaf_nodes
if len(end_node) != 1:
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
end_node = cast(BaseOperator, leaf_nodes[0])
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
@ -141,7 +165,7 @@ async def _trigger_dag(
}
generator = await end_node.call_stream(call_data={"data": body})
background_tasks = BackgroundTasks()
background_tasks.add_task(end_node.dag._after_dag_end)
background_tasks.add_task(dag._after_dag_end)
return StreamingResponse(
generator,
headers=headers,

View File

@ -1,41 +1,63 @@
"""Trigger manager for AWEL.
After DB-GPT starts, the trigger manager is initialized and registers all triggers.
"""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter
from dbgpt.component import BaseComponent, ComponentType, SystemApp
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
"""Base class for trigger manager."""
@abstractmethod
def register_trigger(self, trigger: Any) -> None:
""" "Register a trigger to current manager"""
"""Register a trigger to current manager."""
class HttpTriggerManager(TriggerManager):
"""Http trigger manager.
Register all http triggers to a router.
"""
def __init__(
self,
router: Optional["APIRouter"] = None,
router_prefix: Optional[str] = "/api/v1/awel/trigger",
router_prefix: str = "/api/v1/awel/trigger",
) -> None:
"""Initialize a HttpTriggerManager.
Args:
router (Optional["APIRouter"], optional): The router. Defaults to None.
If None, will create a new FastAPI router.
router_prefix (str, optional): The router prefix. Defaults
to "/api/v1/awel/trigger".
"""
if not router:
from fastapi import APIRouter
router = APIRouter()
self._router_prefix = router_prefix
self._router = router
self._trigger_map = {}
self._trigger_map: Dict[str, Trigger] = {}
def register_trigger(self, trigger: Any) -> None:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
if not isinstance(trigger, HttpTrigger):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger: HttpTrigger = trigger
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
trigger.mount_to_router(self._router)
@ -45,23 +67,32 @@ class HttpTriggerManager(TriggerManager):
logger.info(
f"Include router {self._router} to prefix path {self._router_prefix}"
)
system_app.app.include_router(
self._router, prefix=self._router_prefix, tags=["AWEL"]
)
app = system_app.app
if not app:
raise RuntimeError("System app not initialized")
app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])
class DefaultTriggerManager(TriggerManager, BaseComponent):
"""Default trigger manager for AWEL.
Manage all trigger managers. Only the HTTP trigger is supported now.
"""
name = ComponentType.AWEL_TRIGGER_MANAGER
def __init__(self, system_app: SystemApp | None = None):
"""Initialize a DefaultTriggerManager."""
self.system_app = system_app
self.http_trigger = HttpTriggerManager()
super().__init__(None)
def init_app(self, system_app: SystemApp):
"""Initialize the trigger manager."""
self.system_app = system_app
def register_trigger(self, trigger: Any) -> None:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
if isinstance(trigger, HttpTrigger):
@ -71,4 +102,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
raise ValueError(f"Unsupport trigger: {trigger}")
def after_register(self) -> None:
self.http_trigger._init_app(self.system_app)
"""After register, init the trigger manager."""
if self.system_app:
self.http_trigger._init_app(self.system_app)

View File

@ -0,0 +1,4 @@
"""The core interface of DB-GPT.
Just include the core interface to keep our dependencies clean.
"""

View File

@ -1,3 +1,10 @@
"""The cache interface.
The cache interface is used to cache LLM results and embedding results.
Maybe we can cache more server results in the future.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
@ -10,17 +17,23 @@ V = TypeVar("V")
class RetrievalPolicy(str, Enum):
"""The retrieval policy of the cache."""
EXACT_MATCH = "exact_match"
SIMILARITY_MATCH = "similarity_match"
class CachePolicy(str, Enum):
"""The cache policy of the cache."""
LRU = "lru"
FIFO = "fifo"
@dataclass
class CacheConfig:
"""The cache config."""
retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
cache_policy: Optional[CachePolicy] = CachePolicy.LRU
@ -30,7 +43,8 @@ class CacheKey(Serializable, ABC, Generic[K]):
Supported cache keys:
- The LLM cache key: Include user prompt and the parameters to LLM.
- The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
- The embedding model cache key: Include the texts to embed and the parameters
of the embedding model.
"""
@abstractmethod
@ -76,7 +90,8 @@ class CacheClient(ABC, Generic[K, V]):
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
Optional[CacheValue[V]]: The value retrieved according to the key. If the
cache key does not exist, None is returned.
"""
@abstractmethod
@ -110,8 +125,8 @@ class CacheClient(ABC, Generic[K, V]):
@abstractmethod
def new_key(self, **kwargs) -> CacheKey[K]:
"""Create a cache key with params"""
"""Create a cache key with params."""
@abstractmethod
def new_value(self, **kwargs) -> CacheValue[K]:
"""Create a cache key with params"""
"""Create a cache key with params."""

View File

@ -1,3 +1,5 @@
"""The interface for LLM."""
import collections
import copy
import logging
@ -31,7 +33,8 @@ class ModelInferenceMetrics:
"""The timestamp (in milliseconds) when the model inference ends."""
current_time_ms: Optional[int] = None
"""The current timestamp (in milliseconds) when the model inference return partially output(stream)."""
"""The current timestamp (in milliseconds) when the model inference return
partially output(stream)."""
first_token_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first token is generated."""
@ -64,6 +67,14 @@ class ModelInferenceMetrics:
def create_metrics(
last_metrics: Optional["ModelInferenceMetrics"] = None,
) -> "ModelInferenceMetrics":
"""Create metrics for model inference.
Args:
last_metrics(ModelInferenceMetrics): The last metrics.
Returns:
ModelInferenceMetrics: The metrics for model inference.
"""
start_time_ms = last_metrics.start_time_ms if last_metrics else None
first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
first_completion_time_ms = (
@ -100,15 +111,21 @@ class ModelInferenceMetrics:
)
def to_dict(self) -> Dict:
"""Convert the model inference metrics to dict."""
return asdict(self)
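Chaining create_metrics lets streaming responses carry their timings forward, for example:

# Sketch: the first snapshot records nothing yet; later snapshots inherit
# the start/first-token timestamps from the previous one.
first = ModelInferenceMetrics.create_metrics()
partial = ModelInferenceMetrics.create_metrics(last_metrics=first)
print(partial.to_dict())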
@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
stream: Optional[bool] = False
"""A class to represent the context of a LLM model request."""
stream: bool = False
"""Whether to return a stream of responses."""
cache_enable: bool = False
"""Whether to enable the cache for the model inference"""
user_name: Optional[str] = None
"""The user name of the model request."""
@ -129,8 +146,6 @@ class ModelRequestContext:
request_id: Optional[str] = None
"""The request id of the model inference."""
cache_enable: Optional[bool] = False
"""Whether to enable the cache for the model inference"""
@dataclass
@ -141,27 +156,31 @@ class ModelOutput:
text: str
"""The generated text."""
error_code: int
"""The error code of the model inference. If the model inference is successful, the error code is 0."""
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
"""The error code of the model inference. If the model inference is successful,
the error code is 0."""
model_context: Optional[Dict] = None
finish_reason: Optional[str] = None
usage: Optional[Dict[str, Any]] = None
metrics: Optional[ModelInferenceMetrics] = None
"""Some metrics for model inference"""
def to_dict(self) -> Dict:
"""Convert the model output to dict."""
return asdict(self)
_ModelMessageType = Union[ModelMessage, Dict[str, Any]]
_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]
@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
"""The model request."""
model: str
"""The name of the model."""
messages: List[_ModelMessageType]
messages: _ModelMessageType
"""The input messages."""
temperature: Optional[float] = None
@ -189,28 +208,42 @@ class ModelRequest:
@property
def stream(self) -> bool:
"""Whether to return a stream of responses."""
return self.context and self.context.stream
return bool(self.context and self.context.stream)
def copy(self):
def copy(self) -> "ModelRequest":
"""Copy the model request.
Returns:
ModelRequest: The copied model request.
"""
new_request = copy.deepcopy(self)
# Transform messages to List[ModelMessage]
new_request.messages = list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
new_request.messages,
)
)
new_request.messages = new_request.get_messages()
return new_request
def to_dict(self) -> Dict[str, Any]:
"""Convert the model request to dict.
Returns:
Dict[str, Any]: The model request in dict.
"""
new_request = copy.deepcopy(self)
new_request.messages = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), new_request.messages)
)
new_messages = []
for message in new_request.messages:
if isinstance(message, dict):
new_messages.append(message)
else:
new_messages.append(message.dict())
new_request.messages = new_messages
# Skip None fields
return {k: v for k, v in asdict(new_request).items() if v is not None}
def to_trace_metadata(self):
def to_trace_metadata(self) -> Dict[str, Any]:
"""Convert the model request to trace metadata.
Returns:
Dict[str, Any]: The trace metadata.
"""
metadata = self.to_dict()
metadata["prompt"] = self.messages_to_string()
return metadata
@ -218,16 +251,19 @@ class ModelRequest:
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
If the messages are not a list of ModelMessage, they will be converted to a
list of ModelMessage.
Returns:
List[ModelMessage]: The messages.
"""
return list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
self.messages,
)
)
messages = []
for message in self.messages:
if isinstance(message, dict):
messages.append(ModelMessage(**message))
else:
messages.append(message)
return messages
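Both message forms round-trip through get_messages, for example (the model name is arbitrary):

request = ModelRequest(
    model="example-model",
    messages=[{"role": "human", "content": "Hello"}],
)
typed = request.get_messages()
assert typed[0].role == "human" and typed[0].content == "Hello"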
def get_single_user_message(self) -> Optional[ModelMessage]:
"""Get the single user message.
@ -245,20 +281,35 @@ class ModelRequest:
model: str,
messages: List[ModelMessage],
context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
stream: Optional[bool] = False,
echo: Optional[bool] = False,
stream: bool = False,
echo: bool = False,
**kwargs,
):
"""Build a model request.
Args:
model(str): The model name.
messages(List[ModelMessage]): The messages.
context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]):
The context.
stream(bool): Whether to return a stream of responses. Defaults to False.
echo(bool): Whether to echo the input messages. Defaults to False.
**kwargs: Other arguments.
"""
if not context:
context = ModelRequestContext(stream=stream)
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
context = ModelRequestContext(**context_dict)
elif not isinstance(context, ModelRequestContext):
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
if context_dict:
context = ModelRequestContext(**context_dict)
else:
context = ModelRequestContext(stream=stream)
return ModelRequest(
model=model,
messages=messages,
@ -292,7 +343,6 @@ class ModelRequest:
ValueError: If the message role is not supported
Examples:
.. code-block:: python
from dbgpt.core.interface.message import (
@ -337,7 +387,7 @@ class ModelRequest:
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
prompt_roles: Optional[List[str]] = field(
prompt_roles: List[str] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
@ -356,7 +406,8 @@ class ModelExtraMedata(BaseParameters):
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
"help": "The chat template, see: "
"https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
@ -403,19 +454,19 @@ class ModelMetadata(BaseParameters):
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
"""Create a new model metadata from a dict."""
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
class MessageConverter(ABC):
"""An abstract class for message converter.
r"""An abstract class for message converter.
Different LLMs may have different message formats, this class is used to convert the messages
to the format of the LLM.
Different LLMs may have different message formats, this class is used to convert
the messages to the format of the LLM.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
@ -425,7 +476,8 @@ class MessageConverter(ABC):
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
... # Convert the messages, merge system messages to the last user message.
... # Convert the messages, merge system messages to the last user
... # message.
... system_message = None
... other_messages = []
... sep = "\\n"
@ -478,6 +530,7 @@ class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
"""Create a new default message converter."""
self._prompt_sep = prompt_sep
def convert(
@ -493,7 +546,8 @@ class DefaultMessageConverter(MessageConverter):
2. Move the last user's message to the end of the list
3. Convert the messages to no system message if the model does not support system message
3. Convert the messages to no system message if the model does not support
system message
Args:
messages(List[ModelMessage]): The messages.
@ -520,10 +574,11 @@ class DefaultMessageConverter(MessageConverter):
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages to no system message.
r"""Convert the messages to no system message.
Examples:
>>> # Convert the messages to no system message, just merge system messages to the last user message
>>> # Convert the messages to no system message, just merge system messages
>>> # to the last user message
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
@ -550,7 +605,7 @@ class DefaultMessageConverter(MessageConverter):
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... content="You are a helpful assistant\nWho are you",
... ),
... ]
"""
@ -562,7 +617,8 @@ class DefaultMessageConverter(MessageConverter):
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
# Not support system message, append system message to the last user message
# System messages are not supported; append this one to the last user
# message
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
@ -578,7 +634,8 @@ class DefaultMessageConverter(MessageConverter):
system_message_str = system_messages[0].content
if system_message_str and result_messages:
# Not support system messages, merge system messages to the last user message
# System messages are not supported; merge them into the last user
# message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
@ -587,10 +644,9 @@ class DefaultMessageConverter(MessageConverter):
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Move the last user message to the end of the list.
"""Try to move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
@ -660,7 +716,7 @@ class LLMClient(ABC):
@property
def cache(self) -> collections.abc.MutableMapping:
"""The cache object to cache the model metadata.
"""Return the cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
@ -677,7 +733,8 @@ class LLMClient(ABC):
"""Generate a response for a given model request.
Sometimes, different LLMs may have different message formats,
you can use the message converter to convert the messages to the format of the LLM.
you can use the message converter to convert the messages to the format of the
LLM.
Args:
request(ModelRequest): The model request.
@ -697,7 +754,8 @@ class LLMClient(ABC):
"""Generate a stream of responses for a given model request.
Sometimes, different LLMs may have different message formats,
you can use the message converter to convert the messages to the format of the LLM.
you can use the message converter to convert the messages to the format of the
LLM.
Args:
request(ModelRequest): The model request.
@ -733,6 +791,7 @@ class LLMClient(ABC):
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Convert the message.
If no message converter is provided, the original request will be returned.
Args:
@ -746,14 +805,15 @@ class LLMClient(ABC):
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
new_messages = message_converter.convert(request.messages, model_metadata)
new_messages = message_converter.convert(request.get_messages(), model_metadata)
new_request.messages = new_messages
return new_request
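A hedged usage sketch of this conversion path through generate (assuming generate accepts the optional message_converter the docstrings above describe; llm_client and the model name are placeholders):

.. code-block:: python

    import asyncio

    from dbgpt.core.interface.llm import DefaultMessageConverter, ModelRequest
    from dbgpt.core.interface.message import ModelMessage


    async def ask(llm_client, model: str, question: str) -> str:
        # llm_client is any concrete LLMClient implementation; it is assumed
        # to exist and is not constructed here.
        request = ModelRequest.build_request(
            model, messages=[ModelMessage.build_human_message(question)]
        )
        # The converter merges system messages for models without a system role
        output = await llm_client.generate(
            request, message_converter=DefaultMessageConverter()
        )
        return output.text


    # asyncio.run(ask(my_client, "vicuna-13b-v1.5", "Who are you?"))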
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
If the model metadata is not in the cache, it will be fetched from the llm server.
If the model metadata is not in the cache, it will be fetched from the
llm server.
Returns:
List[ModelMetadata]: A list of model metadata.

View File

@ -1,8 +1,10 @@
"""The conversation and message module."""
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.storage import (
@ -29,11 +31,11 @@ class BaseMessage(BaseModel, ABC):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model"""
"""Whether the message will be passed to the model."""
return True
def to_dict(self) -> Dict:
"""Convert to dict
"""Convert to dict.
Returns:
Dict: The dict object
@ -47,7 +49,7 @@ class BaseMessage(BaseModel, ABC):
@staticmethod
def messages_to_string(messages: List["BaseMessage"]) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (List[BaseMessage]): The messages
@ -92,7 +94,7 @@ class ViewMessage(BaseMessage):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
"""Whether the message will be passed to the model.
The view message will not be passed to the model
"""
@ -109,7 +111,7 @@ class SystemMessage(BaseMessage):
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
"""Type of ModelMessage role."""
SYSTEM = "system"
HUMAN = "human"
@ -118,7 +120,7 @@ class ModelMessageRoleType:
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Type of message used in the interaction between dbgpt-server and llm-server."""
"""Similar to openai's message format"""
role: str
@ -127,7 +129,7 @@ class ModelMessage(BaseModel):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
"""Whether the message will be passed to the model.
The view message will not be passed to the model
@ -142,6 +144,14 @@ class ModelMessage(BaseModel):
@staticmethod
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
"""Convert BaseMessage format to current ModelMessage format.
Args:
messages (List[BaseMessage]): The base messages
Returns:
List[ModelMessage]: The model messages
"""
result = []
for message in messages:
content, round_index = message.content, message.round_index
@ -173,7 +183,7 @@ class ModelMessage(BaseModel):
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
) -> List["ModelMessage"]:
"""Openai message format to current ModelMessage format"""
"""Convert OpenAI message format to current ModelMessage format."""
if isinstance(messages, str):
return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
result = []
@ -202,8 +212,11 @@ class ModelMessage(BaseModel):
convert_to_compatible_format: bool = False,
support_system_role: bool = True,
) -> List[Dict[str, str]]:
"""Convert to common message format(e.g. OpenAI message format) and
huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
"""Convert to common message format.
Convert to common message format (e.g. OpenAI message format) and
huggingface [Templates of Chat Models]
(https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
@ -243,15 +256,38 @@ class ModelMessage(BaseModel):
@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
"""Convert to dict list.
Args:
messages (List["ModelMessage"]): The model messages
Returns:
List[Dict[str, str]]: The dict list
"""
return list(map(lambda m: m.dict(), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":
"""Build human message.
Args:
content (str): The content
Returns:
ModelMessage: The model message
"""
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
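A quick sketch of these helpers together (pydantic fills in field defaults such as round_index; output shapes are illustrative):

.. code-block:: python

    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    msgs = [
        ModelMessage(
            role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant"
        ),
        ModelMessage.build_human_message("Who are you?"),
    ]
    # Plain dicts, e.g. for an OpenAI-style HTTP request body
    payload = ModelMessage.to_dict_list(msgs)
    # Compact human-readable form for logging (see get_printable_message below)
    print(ModelMessage.get_printable_message(msgs))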
@staticmethod
def get_printable_message(messages: List["ModelMessage"]) -> str:
"""Get the printable message"""
"""Get the printable message.
Args:
messages (List["ModelMessage"]): The model messages
Returns:
str: The printable message
"""
str_msg = ""
for message in messages:
curr_message = (
@ -263,7 +299,7 @@ class ModelMessage(BaseModel):
@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (List[ModelMessage]): The messages
@ -287,12 +323,12 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
def _messages_to_str(
messages: List[Union[BaseMessage, ModelMessage]],
messages: Union[List[BaseMessage], List[ModelMessage]],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (List[Union[BaseMessage, ModelMessage]]): The messages
@ -343,21 +379,27 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
def parse_model_messages(
messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]:
"""
Parse model messages to extract the user prompt, system messages, and a history of conversation.
) -> Tuple[str, List[str], List[List[str]]]:
"""Parse model messages.
This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
the current user prompt. System messages are extracted separately, and the conversation history is compiled into
pairs of human and AI messages.
Parse model messages to extract the user prompt, system messages, and a history of
conversation.
This function analyzes a list of ModelMessage objects, identifying the role of
each message (e.g., human, system, ai) and categorizing them accordingly. The
last message is expected to be from the user (human) and is treated as the
current user prompt. System messages are extracted separately, and the
conversation history is compiled into pairs of human and AI messages.
Args:
messages (List[ModelMessage]): List of messages from a chat conversation.
Returns:
tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
The conversation history is a list of message pairs, each containing a user message and the corresponding AI response.
tuple: A tuple containing the user prompt, list of system messages, and the
conversation history.
The conversation history is a list of message pairs, each containing a
user message and the corresponding AI response.
Examples:
.. code-block:: python
@ -399,7 +441,6 @@ def parse_model_messages(
# system_messages: ["Error 404"]
# history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
"""
system_messages: List[str] = []
history_messages: List[List[str]] = [[]]
@ -420,27 +461,30 @@ def parse_model_messages(
class OnceConversation:
"""All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""Once conversation.
All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
def __init__(
self,
chat_mode: str,
user_name: str = None,
sys_code: str = None,
summary: str = None,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
summary: Optional[str] = None,
**kwargs,
):
"""Create a new conversation."""
self.chat_mode: str = chat_mode
self.user_name: str = user_name
self.sys_code: str = sys_code
self.summary: str = summary
self.user_name: Optional[str] = user_name
self.sys_code: Optional[str] = sys_code
self.summary: Optional[str] = summary
self.messages: List[BaseMessage] = kwargs.get("messages", [])
self.start_date: str = kwargs.get("start_date", "")
# After each complete round of dialogue, the current value will be increased by 1
# After each complete round of dialogue, the current value will be
# increased by 1
self.chat_order: int = int(kwargs.get("chat_order", 0))
self.model_name: str = kwargs.get("model_name", "")
self.param_type: str = kwargs.get("param_type", "")
@ -460,10 +504,9 @@ class OnceConversation:
self.messages.append(message)
def start_new_round(self) -> None:
"""Start a new round of conversation
"""Start a new round of conversation.
Example:
>>> conversation = OnceConversation("chat_normal")
>>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0
@ -473,7 +516,8 @@ class OnceConversation:
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> # Now the chat order will be 1, then we start a new round of conversation
>>> # Now the chat order will be 1, then we start a new round of
>>> # conversation
>>> conversation.start_new_round()
>>> # Now the chat order will be 2
>>> assert conversation.chat_order == 2
@ -485,7 +529,7 @@ class OnceConversation:
self.chat_order += 1
def end_current_round(self) -> None:
"""End the current round of conversation
"""Execute the end of the current round of conversation.
We do nothing here, just for the interface.
"""
@ -494,7 +538,7 @@ class OnceConversation:
def add_user_message(
self, message: str, check_duplicate_type: Optional[bool] = False
) -> None:
"""Add a user message to the conversation
"""Save a user message to the conversation.
Args:
message (str): The message content
@ -514,11 +558,12 @@ class OnceConversation:
def add_ai_message(
self, message: str, update_if_exist: Optional[bool] = False
) -> None:
"""Add an AI message to the conversation
"""Save an AI message to current conversation.
Args:
message (str): The message content
update_if_exist (bool): Whether to update the message if the message type is duplicate
update_if_exist (bool): Whether to update the message if the message type
is duplicate
"""
if not update_if_exist:
self._append_message(AIMessage(content=message))
@ -530,51 +575,57 @@ class OnceConversation:
self._append_message(AIMessage(content=message))
def _update_ai_message(self, new_message: str) -> None:
"""
"""Update all AI messages to the new message.
Used to update the message content during streaming output.
Args:
new_message:
Returns:
new_message (str): The new message
"""
for item in self.messages:
if item.type == "ai":
item.content = new_message
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
"""Save a view message to current conversation."""
self._append_message(ViewMessage(content=message))
def add_system_message(self, message: str) -> None:
"""Add a system message to the store"""
"""Save a system message to current conversation."""
self._append_message(SystemMessage(content=message))
def set_start_time(self, datatime: datetime):
"""Set the start time of the conversation."""
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
self.start_date = dt_str
def clear(self) -> None:
"""Remove all messages from the store"""
"""Remove all messages from the store."""
self.messages.clear()
def get_latest_user_message(self) -> Optional[HumanMessage]:
"""Get the latest user message"""
"""Get the latest user message."""
for message in self.messages[::-1]:
if isinstance(message, HumanMessage):
return message
return None
def get_system_messages(self) -> List[SystemMessage]:
"""Get the latest user message"""
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
"""Get the system messages.
Returns:
List[SystemMessage]: The system messages
"""
return cast(
List[SystemMessage],
list(filter(lambda x: isinstance(x, SystemMessage), self.messages)),
)
def _to_dict(self) -> Dict:
return _conversation_to_dict(self)
def from_conversation(self, conversation: OnceConversation) -> None:
"""Load the conversation from the storage"""
"""Load the conversation from the storage."""
self.chat_mode = conversation.chat_mode
self.messages = conversation.messages
self.start_date = conversation.start_date
@ -592,7 +643,7 @@ class OnceConversation:
self._message_index = conversation._message_index
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
"""Get the messages by round index.
Args:
round_index (int): The round index
@ -603,7 +654,7 @@ class OnceConversation:
return list(filter(lambda x: x.round_index == round_index, self.messages))
def get_latest_round(self) -> List[BaseMessage]:
"""Get the latest round messages
"""Get the latest round messages.
Returns:
List[BaseMessage]: The messages
@ -611,7 +662,7 @@ class OnceConversation:
return self.get_messages_by_round(self.chat_order)
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
"""Get the messages with round count
"""Get the messages with round count.
If the round count is 1, the history messages will not be included.
@ -660,16 +711,19 @@ class OnceConversation:
return messages
def get_model_messages(self) -> List[ModelMessage]:
"""Get the model messages
"""Get the model messages.
Model messages just include human, ai and system messages.
Model messages maybe include the history messages, The order of the messages is the same as the order of
Model messages may include the history messages. The order of the messages is
the same as the order of the messages in the conversation; the last message is
the latest message.
If you want to hand the message with your own logic, you can override this method.
If you want to handle the messages with your own logic, you can override this
method.
Examples:
If you not need the history messages, you can override this method like this:
If you do not need the history messages, you can override this method
like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
@ -681,7 +735,8 @@ class OnceConversation:
)
return messages
If you want to add the one round history messages, you can override this method like this:
If you want to add one round of history messages, you can override this
method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
@ -717,7 +772,7 @@ class OnceConversation:
def get_history_message(
self, include_system_message: bool = False
) -> List[BaseMessage]:
"""Get the history message
"""Get the history message.
System messages are not included by default.
@ -729,46 +784,60 @@ class OnceConversation:
"""
messages = []
for message in self.messages:
if message.pass_to_model:
if include_system_message:
messages.append(message)
elif message.type != "system":
messages.append(message)
if message.pass_to_model and (
include_system_message or message.type != "system"
):
messages.append(message)
return messages
class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""
"""Conversation identifier."""
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
"""Create a conversation identifier.
Args:
conv_uid (str): The conversation uid
identifier_type (str): The identifier type
"""
self.conv_uid = conv_uid
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
"""Return the str identifier."""
return f"{self.identifier_type}:{self.conv_uid}"
def to_dict(self) -> Dict:
"""Convert to dict."""
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
class MessageIdentifier(ResourceIdentifier):
"""Message identifier"""
"""Message identifier."""
identifier_split = "___"
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
"""Create a message identifier."""
self.conv_uid = conv_uid
self.index = index
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
"""Return the str identifier."""
return (
f"{self.identifier_type}{self.identifier_split}{self.conv_uid}"
f"{self.identifier_split}{self.index}"
)
@staticmethod
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
"""Convert from str identifier
"""Convert from str identifier.
Args:
str_identifier (str): The str identifier
@ -782,6 +851,7 @@ class MessageIdentifier(ResourceIdentifier):
return MessageIdentifier(parts[1], int(parts[2]))
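A round-trip sketch of the identifier formats defined here (the uid and index values are illustrative):

.. code-block:: python

    from dbgpt.core.interface.message import (
        ConversationIdentifier,
        MessageIdentifier,
    )

    conv_id = ConversationIdentifier("conv-123")
    assert conv_id.str_identifier == "conversation:conv-123"

    msg_id = MessageIdentifier("conv-123", 5)
    assert msg_id.str_identifier == "message___conv-123___5"
    # from_str_identifier restores the conversation uid and message index
    restored = MessageIdentifier.from_str_identifier(msg_id.str_identifier)
    assert (restored.conv_uid, restored.index) == ("conv-123", 5)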
def to_dict(self) -> Dict:
"""Convert to dict."""
return {
"conv_uid": self.conv_uid,
"index": self.index,
@ -790,17 +860,31 @@ class MessageIdentifier(ResourceIdentifier):
class MessageStorageItem(StorageItem):
"""The message storage item.
Keep the message detail and the message index.
"""
@property
def identifier(self) -> MessageIdentifier:
"""Return the identifier."""
return self._id
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
"""Create a message storage item.
Args:
conv_uid (str): The conversation uid
index (int): The message index
message_detail (Dict): The message detail
"""
self.conv_uid = conv_uid
self.index = index
self.message_detail = message_detail
self._id = MessageIdentifier(conv_uid, index)
def to_dict(self) -> Dict:
"""Convert to dict."""
return {
"conv_uid": self.conv_uid,
"index": self.index,
@ -808,7 +892,8 @@ class MessageStorageItem(StorageItem):
}
def to_message(self) -> BaseMessage:
"""Convert to message object
"""Convert to message object.
Returns:
BaseMessage: The message object
@ -818,7 +903,7 @@ class MessageStorageItem(StorageItem):
return _message_from_dict(self.message_detail)
def merge(self, other: "StorageItem") -> None:
"""Merge the other message to self
"""Merge the other message to self.
Args:
other (StorageItem): The other message
@ -829,16 +914,20 @@ class MessageStorageItem(StorageItem):
class StorageConversation(OnceConversation, StorageItem):
"""All the information of a conversation, the current single service in memory,
"""The storage conversation.
All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
@property
def identifier(self) -> ConversationIdentifier:
"""Return the identifier."""
return self._id
def to_dict(self) -> Dict:
"""Convert to dict."""
dict_data = self._to_dict()
messages: Dict = dict_data.pop("messages")
message_ids = []
@ -859,7 +948,7 @@ class StorageConversation(OnceConversation, StorageItem):
return dict_data
def merge(self, other: "StorageItem") -> None:
"""Merge the other conversation to self
"""Merge the other conversation to self.
Args:
other (StorageItem): The other conversation
@ -871,17 +960,18 @@ class StorageConversation(OnceConversation, StorageItem):
def __init__(
self,
conv_uid: str,
chat_mode: str = None,
user_name: str = None,
sys_code: str = None,
message_ids: List[str] = None,
summary: str = None,
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
chat_mode: str = "chat_normal",
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
message_ids: Optional[List[str]] = None,
summary: Optional[str] = None,
save_message_independent: bool = True,
conv_storage: Optional[StorageInterface] = None,
message_storage: Optional[StorageInterface] = None,
load_message: bool = True,
**kwargs,
):
"""Create a conversation."""
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
self.conv_uid = conv_uid
self._message_ids = message_ids
@ -905,7 +995,7 @@ class StorageConversation(OnceConversation, StorageItem):
@property
def message_ids(self) -> List[str]:
"""Get the message ids
"""Return the message ids.
Returns:
List[str]: The message ids
@ -913,7 +1003,7 @@ class StorageConversation(OnceConversation, StorageItem):
return self._message_ids if self._message_ids else []
def end_current_round(self) -> None:
"""End the current round of conversation
"""End the current round of conversation.
Save the conversation to the storage after a round of conversation
"""
@ -926,7 +1016,7 @@ class StorageConversation(OnceConversation, StorageItem):
]
def save_to_storage(self) -> None:
"""Save the conversation to the storage"""
"""Save the conversation to the storage."""
# Save messages first
message_list = self._get_message_items()
self._message_ids = [
@ -943,7 +1033,7 @@ class StorageConversation(OnceConversation, StorageItem):
def load_from_storage(
self, conv_storage: StorageInterface, message_storage: StorageInterface
) -> None:
"""Load the conversation from the storage
"""Load the conversation from the storage.
Warning: This will overwrite the current conversation.
@ -952,7 +1042,7 @@ class StorageConversation(OnceConversation, StorageItem):
message_storage (StorageInterface): The storage interface
"""
# Load conversation first
conversation: StorageConversation = conv_storage.load(
conversation: Optional[StorageConversation] = conv_storage.load(
self._id, StorageConversation
)
if conversation is None:
@ -988,18 +1078,18 @@ class StorageConversation(OnceConversation, StorageItem):
def _append_additional_kwargs(
self, conversation: StorageConversation, messages: List[BaseMessage]
) -> None:
"""Parse the additional kwargs and append to the conversation
"""Parse the additional kwargs and append to the conversation.
Args:
conversation (StorageConversation): The conversation
messages (List[BaseMessage]): The messages
"""
param_type = None
param_value = None
param_type = ""
param_value = ""
for message in messages[::-1]:
if message.additional_kwargs:
param_type = message.additional_kwargs.get("param_type")
param_value = message.additional_kwargs.get("param_value")
param_type = message.additional_kwargs.get("param_type", "")
param_value = message.additional_kwargs.get("param_value", "")
break
if not conversation.param_type:
conversation.param_type = param_type
@ -1007,7 +1097,7 @@ class StorageConversation(OnceConversation, StorageItem):
conversation.param_value = param_value
def delete(self) -> None:
"""Delete all the messages and conversation from the storage"""
"""Delete all the messages and conversation."""
# Delete messages first
message_list = self._get_message_items()
message_ids = [message.identifier for message in message_list]
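A minimal end-to-end sketch of this storage lifecycle, assuming InMemoryStorage (from dbgpt.core.interface.storage, imported elsewhere in this patch) can be constructed with its defaults:

.. code-block:: python

    from dbgpt.core.interface.message import StorageConversation
    from dbgpt.core.interface.storage import InMemoryStorage

    conv_storage = InMemoryStorage()
    message_storage = InMemoryStorage()

    conv = StorageConversation(
        conv_uid="conv-123",
        conv_storage=conv_storage,
        message_storage=message_storage,
    )
    conv.start_new_round()
    conv.add_user_message("hello")
    conv.add_ai_message("hi")
    conv.end_current_round()  # persists the round to both storages

    # A new instance with the same uid loads the saved messages from storage
    loaded = StorageConversation(
        conv_uid="conv-123",
        conv_storage=conv_storage,
        message_storage=message_storage,
    )
    assert len(loaded.messages) == 2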
@ -1055,13 +1145,13 @@ def _conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def _conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation(
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
once.get("chat_mode", ""), once.get("user_name"), once.get("sys_code")
)
conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order"))
conversation.chat_order = int(once.get("chat_order", 0))
conversation.param_type = once.get("param_type", "")
conversation.param_value = once.get("param_value", "")
conversation.model_name = once.get("model_name", "proxyllm")
@ -1093,7 +1183,8 @@ def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessa
def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
"""Append the view message to the messages
"""Append the view message to the messages.
Just for show in DB-GPT-Web.
If already have view message, do nothing.

View File

@ -0,0 +1 @@
"""The module include all core operators of DB-GPT."""

View File

@ -1,5 +1,9 @@
"""The chat history prompt composer operator.
We can wrap some atomic operators into a complex operator.
"""
import dataclasses
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from dbgpt.core import (
ChatPromptTemplate,
@ -51,6 +55,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
"""Create a new chat history prompt composer operator."""
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
@ -61,7 +66,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
self._sub_compose_dag = self._build_composer_dag()
async def map(self, input_value: ChatComposerInput) -> ModelRequest:
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
"""Compose the chat history prompt."""
end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
# Sub dag, use the same dag context in the parent dag
return await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
@ -82,7 +88,9 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
)
model_request_build_task = JoinOperator(self._build_model_request)
model_request_build_task: JoinOperator[ModelRequest] = JoinOperator(
combine_function=self._build_model_request
)
# Build composer dag
(
@ -113,5 +121,6 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
return ModelRequest.build_request(messages=messages, **model_dict)
async def after_dag_end(self):
"""Execute after dag end."""
# Should call after_dag_end() of sub dag
await self._sub_compose_dag._after_dag_end()
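The sub-DAG pattern used here (wire operators inside a private DAG, then call its leaf node with call_data) can be sketched standalone. This assumes the usual AWEL primitives (InputOperator reading call_data["data"] via SimpleCallDataInputSource, MapOperator's map_function argument, and >> for wiring) behave as in the repo's examples:

.. code-block:: python

    import asyncio

    from dbgpt.core.awel import (
        DAG,
        InputOperator,
        JoinOperator,
        MapOperator,
        SimpleCallDataInputSource,
    )

    with DAG("tiny_compose_dag"):
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        upper_task = MapOperator(map_function=lambda s: s.upper())
        exclaim_task = MapOperator(map_function=lambda s: s + "!")
        # JoinOperator combines both branches, like _build_model_request above
        join_task = JoinOperator(combine_function=lambda a, b: f"{a} | {b}")
        input_task >> upper_task >> join_task
        input_task >> exclaim_task >> join_task

    print(asyncio.run(join_task.call(call_data={"data": "hello"})))
    # Expected (branch order per the wiring above): HELLO | hello!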

View File

@ -1,9 +1,12 @@
"""The LLM operator."""
import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchOperator,
DAGContext,
@ -32,11 +35,13 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
"""Build the model request from the input value."""
def __init__(self, model: Optional[str] = None, **kwargs):
"""Create a new request builder operator."""
self._model = model
super().__init__(**kwargs)
async def map(self, input_value: RequestInput) -> ModelRequest:
req_dict = {}
"""Transform the input value to a model request."""
req_dict: Dict[str, Any] = {}
if not input_value:
raise ValueError("input_value is not set")
if isinstance(input_value, str):
@ -47,7 +52,9 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
req_dict = {"messages": [input_value]}
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
req_dict = {"messages": input_value}
elif dataclasses.is_dataclass(input_value):
elif dataclasses.is_dataclass(input_value) and not isinstance(
input_value, type
):
req_dict = dataclasses.asdict(input_value)
elif isinstance(input_value, BaseModel):
req_dict = input_value.dict()
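A standalone sketch of the builder (assuming RequestBuilderOperator is exported from dbgpt.core.operator like the other operators in this patch, and that a single operator can be called directly as in the docstring examples):

.. code-block:: python

    import asyncio

    from dbgpt.core.operator import RequestBuilderOperator

    builder = RequestBuilderOperator(model="vicuna-13b-v1.5")
    # A bare string input becomes a single human message in the request
    request = asyncio.run(builder.call("Who are you?"))
    assert request.model == "vicuna-13b-v1.5"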
@ -90,6 +97,7 @@ class BaseLLM:
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
def __init__(self, llm_client: Optional[LLMClient] = None):
"""Create a new LLM operator."""
self._llm_client = llm_client
@property
@ -118,10 +126,19 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
"""Create a new LLM operator."""
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
"""Generate the model output.
Args:
request (ModelRequest): The model request.
Returns:
ModelOutput: The model output.
"""
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
@ -142,15 +159,23 @@ class BaseStreamingLLMOperator(
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
"""Create a streaming operator for an LLM.
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
"""
super().__init__(llm_client=llm_client)
BaseOperator.__init__(self, **kwargs)
async def streamify( # type: ignore
self, request: ModelRequest # type: ignore
) -> AsyncIterator[ModelOutput]: # type: ignore
"""Streamify the request."""
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
model_output = None
async for output in self.llm_client.generate_stream(request):
async for output in self.llm_client.generate_stream(request): # type: ignore
model_output = output
yield output
if model_output:
@ -160,10 +185,17 @@ class BaseStreamingLLMOperator(
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
"""Branch operator for LLM.
This operator will branch the workflow based on the stream flag of the request.
This operator will branch the workflow based on
the stream flag of the request.
"""
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
"""Create a new LLM branch operator.
Args:
stream_task_name (str): The name of the streaming task.
no_stream_task_name (str): The name of the non-streaming task.
"""
super().__init__(**kwargs)
if not stream_task_name:
raise ValueError("stream_task_name is not set")
@ -172,18 +204,22 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
self._stream_task_name = stream_task_name
self._no_stream_task_name = no_stream_task_name
async def branches(self) -> Dict[BranchFunc[ModelRequest], str]:
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""
Return a dict of branch function and task name.
Returns:
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name.
the key is a predicate function, the value is the task name. If the predicate function returns True,
we will run the corresponding task.
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
name. The key is a predicate function, the value is the task name.
If the predicate function returns True, we will run the corresponding
task.
"""
async def check_stream_true(r: ModelRequest) -> bool:
# If stream is true, we will run the streaming task. otherwise, we will run the non-streaming task.
# If stream is true, we will run the streaming task. Otherwise, we will run
# the non-streaming task.
return r.stream
return {
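Usage is then a matter of matching task names to downstream operators. A sketch (assuming LLMBranchOperator is exported from dbgpt.core.operator; the names below are illustrative and must equal the task_name of the corresponding operators in the DAG):

.. code-block:: python

    from dbgpt.core.operator import LLMBranchOperator

    # Requests with stream=True are routed to "streaming_llm_task";
    # all other requests go to "llm_task".
    branch_task = LLMBranchOperator(
        stream_task_name="streaming_llm_task",
        no_stream_task_name="llm_task",
    )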

View File

@ -1,3 +1,4 @@
"""The message operator."""
import uuid
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, Union
@ -36,6 +37,7 @@ class BaseConversationOperator(BaseOperator, ABC):
check_storage: bool = True,
**kwargs,
):
"""Create a new BaseConversationOperator."""
self._check_storage = check_storage
self._storage = storage
self._message_storage = message_storage
@ -102,7 +104,8 @@ class PreChatHistoryLoadOperator(
):
"""The operator to prepare the storage conversation.
In DB-GPT, conversation record and the messages in the conversation are stored in the storage,
In DB-GPT, the conversation record and the messages in the conversation are
stored in the storage,
and they can be stored in different storages (for high performance).
This operator just loads the conversation and messages from storage.
@ -115,6 +118,7 @@ class PreChatHistoryLoadOperator(
include_system_message: bool = False,
**kwargs,
):
"""Create a new PreChatHistoryLoadOperator."""
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
self._include_system_message = include_system_message
@ -139,7 +143,8 @@ class PreChatHistoryLoadOperator(
chat_mode = input_value.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
# Create a new storage conversation, this will load the conversation from
# storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.conv_uid,
@ -167,14 +172,21 @@ class PreChatHistoryLoadOperator(
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
"""The base conversation mapper operator."""
def __init__(
self, message_mapper: Optional[_MultiRoundMessageMapper] = None, **kwargs
):
"""Create a new ConversationMapperOperator."""
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
"""Map the input value to a ModelRequest."""
return await self.map_messages(input_value)
async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage."""
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
@ -184,11 +196,11 @@ class ConversationMapperOperator(
) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage.
By default, just merge all multi round messages to a list of BaseMessage according origin order.
By default, just merge all multi-round messages to a list of BaseMessage
according to the original order.
And you can override this method to implement your own logic.
Examples:
Merge multi-round messages to a list of BaseMessage according to the original order.
>>> from dbgpt.core.interface.message import (
@ -215,7 +227,8 @@ class ConversationMapperOperator(
... AIMessage(content="Just a joke.", round_index=2),
... ]
Map multi round messages to a list of BaseMessage just keep the last one round.
Map multi-round messages to a list of BaseMessage, keeping only the last
round.
>>> class MyMapper(ConversationMapperOperator):
... def __init__(self, **kwargs):
@ -234,13 +247,16 @@ class ConversationMapperOperator(
... ]
Args:
messages_by_round (List[List[BaseMessage]]):
The messages grouped by round.
"""
# Just merge and return
return _merge_multi_round_messages(messages_by_round)
class BufferedConversationMapperOperator(ConversationMapperOperator):
"""
"""Buffered conversation mapper operator.
The buffered conversation mapper operator which can be configured to keep
a certain number of starting and/or ending rounds of a conversation.
@ -249,39 +265,44 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
keep_end_rounds (Optional[int]): Number of final rounds to keep.
Examples:
# Keeping the first 2 and the last 1 rounds of a conversation
import asyncio
from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
.. code-block:: python
operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1)
messages = [
# Assume each HumanMessage and AIMessage belongs to separate rounds
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# This will keep rounds 1, 2, and 3
assert asyncio.run(operator.map_messages(messages)) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# Keeping the first 2 and the last 1 rounds of a conversation
import asyncio
from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
operator = BufferedConversationMapperOperator(
keep_start_rounds=2, keep_end_rounds=1
)
messages = [
# Assume each HumanMessage and AIMessage belongs to separate rounds
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# This will keep rounds 1, 2, and 3
assert asyncio.run(operator.map_messages(messages)) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
"""
def __init__(
self,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
message_mapper: _MultiRoundMessageMapper = None,
message_mapper: Optional[_MultiRoundMessageMapper] = None,
**kwargs,
):
"""Create a new BufferedConversationMapperOperator."""
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
@ -311,10 +332,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
def _filter_round_messages(
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
"""Filters the messages to keep only the specified starting and/or ending rounds.
"""Return a filtered list of messages.
Filters the messages to keep only the specified starting and/or ending rounds.
Examples:
>>> from dbgpt.core import AIMessage, HumanMessage
>>> from dbgpt.core.operator import BufferedConversationMapperOperator
>>> messages = [
@ -395,15 +417,18 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ]
Args:
messages_by_round (List[List[BaseMessage]]): The messages grouped by round.
messages_by_round (List[List[BaseMessage]]):
The messages grouped by round.
Returns:
List[List[BaseMessage]]: Filtered list of messages.
Returns:
List[List[BaseMessage]]: Filtered list of messages.
"""
total_rounds = len(messages_by_round)
if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total rounds
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
return messages_by_round
return (
messages_by_round[: self._keep_start_rounds]
@ -423,14 +448,16 @@ EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]
class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
"""The token buffered conversation mapper operator.
If the token count of the messages is greater than the max token limit, we will evict the messages by round.
If the token count of the messages is greater than the max token limit, we will
evict the messages by round.
Args:
model (str): The model name.
llm_client (LLMClient): The LLM client.
max_token_limit (int): The max token limit.
eviction_policy (EvictionPolicyType): The eviction policy.
message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after all messages are handled.
message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after
all messages are handled.
"""
def __init__(
@ -438,10 +465,11 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
model: str,
llm_client: LLMClient,
max_token_limit: int = 2000,
eviction_policy: EvictionPolicyType = None,
message_mapper: _MultiRoundMessageMapper = None,
eviction_policy: Optional[EvictionPolicyType] = None,
message_mapper: Optional[_MultiRoundMessageMapper] = None,
**kwargs,
):
"""Create a new TokenBufferedConversationMapperOperator."""
if max_token_limit < 0:
raise ValueError("Max token limit can't be negative")
self._model = model
@ -452,6 +480,7 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
super().__init__(**kwargs)
async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage."""
eviction_policy = self._eviction_policy or self.eviction_policy
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round))
@ -459,7 +488,8 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
current_tokens = await self._llm_client.count_token(self._model, messages_str)
while current_tokens > self._max_token_limit:
# Evict the messages by round after all tokens are not greater than the max token limit
# Evict the messages by round until the total token count is not greater
# than the max token limit
# TODO: We should find a high performance way to do this
messages_by_round = eviction_policy(messages_by_round)
messages_str = _messages_to_str(
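A custom eviction policy is just a function over rounds matching the EvictionPolicyType alias above; a sketch that drops the oldest remaining round on each call:

.. code-block:: python

    from typing import List

    from dbgpt.core.interface.message import BaseMessage


    def drop_oldest_round(
        messages_by_round: List[List[BaseMessage]],
    ) -> List[List[BaseMessage]]:
        # Called repeatedly until the token count fits the limit; each call
        # removes the oldest remaining round.
        return messages_by_round[1:]

Passing eviction_policy=drop_oldest_round to the operator's constructor replaces the default policy.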

View File

@ -1,8 +1,8 @@
"""The prompt operator."""
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import (
BasePromptTemplate,
ChatPromptTemplate,
ModelMessage,
ModelMessageRoleType,
@ -13,7 +13,13 @@ from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operator.llm_operator import BaseLLM
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
from dbgpt.core.interface.prompt import (
BaseChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
MessageType,
PromptTemplate,
)
from dbgpt.util.function_utils import rearrange_args_by_type
@ -21,6 +27,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""
def __init__(self, check_storage: bool, **kwargs):
"""Create a new prompt builder operator."""
super().__init__(check_storage=check_storage, **kwargs)
async def format_prompt(
@ -39,10 +46,10 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
kwargs.update(prompt_dict)
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
messages = ModelMessage.from_base_messages(messages)
model_messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(messages)
return messages
await self.start_new_round_conv(model_messages)
return model_messages
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
"""Start a new round conversation.
@ -50,7 +57,6 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
Args:
messages (List[ModelMessage]): The messages.
"""
last_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
last_user_message = message.content
break
if not last_user_message:
raise ValueError("No user message")
raise ValueError("No user message")
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()
if not storage_conv:
return
# Start new round
@ -66,13 +74,17 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
storage_conv.add_user_message(last_user_message)
async def after_dag_end(self):
"""The callback after DAG end"""
# TODO remove this to start_new_round()
"""Execute after the DAG finished."""
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()
if not storage_conv:
return
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
model_output: Optional[
ModelOutput
] = await self.current_dag_context.get_from_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
)
if model_output:
@ -82,7 +94,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
storage_conv.end_current_round()
PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
PromptTemplateType = Union[ChatPromptTemplate, PromptTemplate, MessageType, str]
class PromptBuilderOperator(
@ -91,7 +103,6 @@ class PromptBuilderOperator(
"""The operator to build the prompt with static prompt.
Examples:
.. code-block:: python
import asyncio
@ -119,7 +130,8 @@ class PromptBuilderOperator(
ChatPromptTemplate(
messages=[
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
"Please write a {dialect} SQL count the length of a"
" field"
)
]
)
@ -131,7 +143,8 @@ class PromptBuilderOperator(
"You are a {dialect} SQL expert"
),
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
"Please write a {dialect} SQL count the length of a"
" field"
),
],
)
@ -171,17 +184,18 @@ class PromptBuilderOperator(
"""
def __init__(self, prompt: PromptTemplateType, **kwargs):
"""Create a new prompt builder operator."""
if isinstance(prompt, str):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt)]
)
elif isinstance(prompt, BasePromptTemplate) and not isinstance(
prompt, ChatPromptTemplate
):
elif isinstance(prompt, PromptTemplate):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt.template)]
)
elif isinstance(prompt, MessageType):
elif isinstance(
prompt, (BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage)
):
prompt = ChatPromptTemplate(messages=[prompt])
self._prompt = prompt
@ -190,6 +204,7 @@ class PromptBuilderOperator(
@rearrange_args_by_type
async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
"""Format the prompt."""
return await self.format_prompt(self._prompt, prompt_dict)
@ -202,6 +217,7 @@ class DynamicPromptBuilderOperator(
"""
def __init__(self, **kwargs):
"""Create a new dynamic prompt builder operator."""
super().__init__(check_storage=False, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)
@ -209,20 +225,37 @@ class DynamicPromptBuilderOperator(
async def merge_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Merge the prompt and history."""
return await self.format_prompt(prompt, prompt_dict)
class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with static prompt.
The prompt will be passed to this operator.
"""
def __init__(
self,
prompt: ChatPromptTemplate,
history_key: Optional[str] = None,
history_key: str = "chat_history",
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
"""Create a new history prompt builder operator.
Args:
prompt (ChatPromptTemplate): The prompt.
history_key (str, optional): The key of history in prompt dict. Defaults
to "chat_history".
check_storage (bool, optional): Whether to check the storage.
Defaults to True.
str_history (bool, optional): Whether to convert the history to string.
Defaults to False.
"""
self._prompt = prompt
self._history_key = history_key
self._str_history = str_history
@ -233,6 +266,7 @@ class HistoryPromptBuilderOperator(
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Merge the prompt and history."""
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:
@ -250,11 +284,12 @@ class HistoryDynamicPromptBuilderOperator(
def __init__(
self,
history_key: Optional[str] = None,
history_key: str = "chat_history",
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
"""Create a new history dynamic prompt builder operator."""
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
@ -267,6 +302,7 @@ class HistoryDynamicPromptBuilderOperator(
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
"""Merge the prompt and history."""
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:

View File

@ -1,3 +1,4 @@
"""The Abstract Retriever Operator."""
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
@ -16,7 +17,8 @@ class RetrieverOperator(MapOperator[IN, OUT]):
Returns:
OUT: The output value.
"""
# The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async.
# The retrieve function is blocking, so we need to wrap it in a
# blocking_func_to_async.
return await self.blocking_func_to_async(self.retrieve, input_value)
@abstractmethod
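Subclasses only implement the synchronous retrieve method; a toy sketch (the import path and class name are assumptions for illustration):

.. code-block:: python

    from typing import List

    from dbgpt.core.interface.operator.retriever import RetrieverOperator


    class KeywordRetrieverOperator(RetrieverOperator[str, List[str]]):
        """Toy retriever: blocking keyword lookup over an in-memory corpus."""

        def __init__(self, corpus: List[str], **kwargs):
            super().__init__(**kwargs)
            self._corpus = corpus

        def retrieve(self, input_value: str) -> List[str]:
            # Plain blocking code; the base class wraps it with
            # blocking_func_to_async when the DAG executes.
            return [d for d in self._corpus if input_value.lower() in d.lower()]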

View File

@ -1,10 +1,15 @@
"""The output parser is used to parse the output of an LLM call.
TODO: Make this more general and clear.
"""
from __future__ import annotations
import json
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from typing import Any, TypeVar, Union
from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
@ -22,11 +27,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
"""
def __init__(self, is_stream_out: bool = True, **kwargs):
"""Create a new output parser."""
super().__init__(**kwargs)
self.is_stream_out = is_stream_out
self.data_schema = None
def update(self, data_schema):
"""Update the data schema.
TODO: Remove this method.
"""
self.data_schema = data_schema
def __post_process_code(self, code):
@ -40,9 +50,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return code
def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
data = _parse_model_response(chunk)
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""Parse the output of an LLM call.
Args:
chunk (ResponseTye): The output of an LLM call.
skip_echo_len (int): The length of the prompt to skip.
"""
data = _parse_model_response(chunk)
# TODO: Multi mode output handler, rewrite this for multi model, use adapter
# mode.
model_context = data.get("model_context")
has_echo = False
if model_context and "prompt_echo_len_char" in model_context:
@ -65,6 +82,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return output
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
resp_obj_ex = _parse_model_response(response)
if isinstance(resp_obj_ex, str):
resp_obj_ex = json.loads(resp_obj_ex)
@ -89,7 +107,8 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return ai_response
else:
raise ValueError(
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
f"Model server error!code={resp_obj_ex['error_code']}, error msg is "
f"{resp_obj_ex['text']}"
)
def _illegal_json_ends(self, s):
@ -117,7 +136,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
temp_json = self._illegal_json_ends(temp_json)
return temp_json
except Exception as e:
except Exception:
raise ValueError("Failed to find a valid json in LLM response: " + temp_json)
def _json_interception(self, s, is_json_array: bool = False):
@ -150,17 +169,17 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
break
assert count == 0
return s[i : j + 1]
except Exception as e:
except Exception:
return ""
def parse_prompt_response(self, model_out_text) -> T:
"""
parse model out text to prompt define response
def parse_prompt_response(self, model_out_text) -> Any:
"""Parse the model output text into the prompt-defined response.
Args:
model_out_text:
model_out_text: The output of an LLM call.
Returns:
Any: The parsed output of an LLM call.
"""
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
@ -194,12 +213,15 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
def parse_view_response(
self, ai_text, data, parse_prompt_response: Any = None
) -> str:
"""
parse the ai response info to user view
"""Parse the AI response info to user view.
Args:
text:
ai_text (str): The output of an LLM call.
data (dict): The data has been handled by some scene.
parse_prompt_response (Any): The prompt response has been parsed.
Returns:
str: The parsed output of an LLM call.
"""
return ai_text
@ -240,10 +262,14 @@ def _parse_model_response(response: ResponseTye):
class SQLOutputParser(BaseOutputParser):
"""Parse the SQL output of an LLM call."""
def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
model_out_text = super().parse_model_nostream_resp(response, sep)
clean_str = super().parse_prompt_response(model_out_text)
return json.loads(clean_str, strict=True)
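To see the JSON extraction above in context, a sketch (the exact cleanup branches of parse_prompt_response are elided in this hunk, so the expected result follows the _json_interception logic):

.. code-block:: python

    import json

    from dbgpt.core.interface.output_parser import SQLOutputParser

    parser = SQLOutputParser()
    model_out_text = (
        'Sure, here is the SQL: {"sql": "SELECT COUNT(*) FROM user"} Hope it helps.'
    )
    # Extracts the balanced JSON object embedded in the model text
    clean_str = parser.parse_prompt_response(model_out_text)
    print(json.loads(clean_str))
    # Expected: {'sql': 'SELECT COUNT(*) FROM user'}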

View File

@ -1,3 +1,5 @@
"""The prompt template interface."""
from __future__ import annotations
import dataclasses
@ -7,9 +9,7 @@ from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
@ -42,64 +42,32 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
class BasePromptTemplate(BaseModel):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: Optional[str]
class PromptTemplate(BasePromptTemplate):
"""Prompt template."""
template: str
"""The prompt template."""
template_format: Optional[str] = "f-string"
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_key: str = "response"
template_is_strict: bool = True
"""A strict template will check the template arguments."""
response_format: Optional[str] = None
response_key: Optional[str] = "response"
template_scene: Optional[str] = None
template_is_strict: Optional[bool] = True
"""strict template will check template args"""
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
template_define: Optional[str] = None
"""The definition of this template."""
"""Use stream out by default."""
stream_out: bool = True
""""""
output_parser: BaseOutputParser = None
""""""
sep: str = "###"
example_selector: ExampleSelector = None
need_historical_messages: bool = False
temperature: float = 0.6
max_new_tokens: int = 1024
class Config:
"""Configuration for this pydantic object."""
@ -111,13 +79,38 @@ class PromptTemplate(BasePromptTemplate):
"""Return the prompt type key."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
@classmethod
def from_template(
cls, template: str, template_format: str = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
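Taken together, from_template derives input_variables from the template string and format renders it with the (strict by default) formatter. A hedged usage sketch against the class as defined in this file:

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptTemplate

    tmpl = PromptTemplate.from_template("Tell me a {adjective} joke about {topic}.")
    print(tmpl.input_variables)  # ['adjective', 'topic']
    print(tmpl.format(adjective="short", topic="databases"))
    # With template_is_strict=True, passing an unknown kwarg here would raise.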
class BaseChatPromptTemplate(BaseModel, ABC):
"""The base chat prompt template."""
prompt: BasePromptTemplate
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects."""
"""Return a list of the names of the variables the prompt template expects."""
return self.prompt.input_variables
@abstractmethod
@ -128,14 +121,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
def from_template(
cls,
template: str,
template_format: Optional[str] = "f-string",
template_format: str = "f-string",
response_format: Optional[str] = None,
response_key: Optional[str] = "response",
response_key: str = "response",
template_is_strict: bool = True,
**kwargs: Any,
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
prompt = BasePromptTemplate.from_template(
prompt = PromptTemplate.from_template(
template,
template_format,
response_format=response_format,
@ -149,6 +142,11 @@ class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]
@ -157,20 +155,31 @@ class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]
class MessagesPlaceholder(BaseChatPromptTemplate):
class MessagesPlaceholder(BaseModel):
"""The messages placeholder template.
Mostly used for the chat history.
"""
variable_name: str
prompt: BasePromptTemplate = None
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Just return the messages from the kwargs with the variable name.
Returns:
List[BaseMessage]: The messages.
"""
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
@ -185,7 +194,7 @@ class MessagesPlaceholder(BaseChatPromptTemplate):
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects.
"""Return a list of the names of the variables the prompt template expects.
Returns:
List[str]: The input variables.
@ -193,10 +202,26 @@ class MessagesPlaceholder(BaseChatPromptTemplate):
return [self.variable_name]
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
MessageType = Union[BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage]
class ChatPromptTemplate(BasePromptTemplate):
"""The chat prompt template.
Examples:
.. code-block:: python
prompt_template = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(
"You are a helpful AI assistant."
),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{question}"),
]
)
"""
messages: List[MessageType]
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
@ -205,12 +230,7 @@ class ChatPromptTemplate(BasePromptTemplate):
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, BaseChatPromptTemplate):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
elif isinstance(message, MessagesPlaceholder):
elif isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
@ -227,7 +247,7 @@ class ChatPromptTemplate(BasePromptTemplate):
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, BaseChatPromptTemplate):
if isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
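So the pre-validator gathers input_variables from every sub-template, and format_messages hands each template or placeholder only the kwargs it declares. A hedged sketch that exercises the docstring example above; the message classes follow the imports at the top of this file, with AIMessage assumed to live alongside them:

.. code-block:: python

    from dbgpt.core.interface.message import AIMessage, HumanMessage
    from dbgpt.core.interface.prompt import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )

    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful AI assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{question}"),
        ]
    )
    history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
    messages = prompt.format_messages(chat_history=history, question="Who are you?")
    # -> [SystemMessage, HumanMessage, AIMessage, HumanMessage]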
@ -235,6 +255,8 @@ class ChatPromptTemplate(BasePromptTemplate):
@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):
"""The identifier of a prompt template."""
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
prompt_name: str
prompt_language: Optional[str] = None
@ -242,6 +264,7 @@ class PromptTemplateIdentifier(ResourceIdentifier):
model: Optional[str] = None
def __post_init__(self):
"""Post init method."""
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")
@ -256,11 +279,13 @@ class PromptTemplateIdentifier(ResourceIdentifier):
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
f"identifier_split {self.identifier_split} is not allowed in "
f"prompt_name, prompt_language, sys_code, model"
)
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key
for key in [
@ -273,6 +298,11 @@ class PromptTemplateIdentifier(ResourceIdentifier):
)
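str_identifier simply joins the non-None fields with the reserved separator, which is why the separator is rejected inside any field above. A hedged sketch, assuming the join follows the declared field order:

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptTemplateIdentifier

    ident = PromptTemplateIdentifier(
        prompt_name="hello", prompt_language="en", model="vicuna-13b-v1.5"
    )
    # sys_code is None, so it is skipped in the join.
    print(ident.str_identifier)  # hello___$$$$___en___$$$$___vicuna-13b-v1.5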
def to_dict(self) -> Dict:
"""Convert the identifier to a dict.
Returns:
Dict: The dict of the identifier.
"""
return {
"prompt_name": self.prompt_name,
"prompt_language": self.prompt_language,
@ -283,6 +313,8 @@ class PromptTemplateIdentifier(ResourceIdentifier):
@dataclasses.dataclass
class StoragePromptTemplate(StorageItem):
"""The storage prompt template."""
prompt_name: str
content: Optional[str] = None
prompt_language: Optional[str] = None
@ -297,25 +329,28 @@ class StoragePromptTemplate(StorageItem):
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = PromptTemplateIdentifier(
prompt_name=self.prompt_name,
prompt_language=self.prompt_language,
sys_code=self.sys_code,
model=self.model,
)
self._check() # Assuming _check() is a method you need to call after initialization
# Assuming _check() is a method you need to call after initialization
self._check()
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
[] if not self.input_variables else self.input_variables.strip().split(",")
)
template_format = self.prompt_format or "f-string"
return PromptTemplate(
input_variables=input_variables,
template=self.content,
template_scene=self.chat_scene,
prompt_name=self.prompt_name,
template_format=self.prompt_format,
# prompt_name=self.prompt_name,
template_format=template_format,
)
@staticmethod
@ -335,12 +370,18 @@ class StoragePromptTemplate(StorageItem):
Args:
prompt_template (PromptTemplate): The prompt template to convert from.
prompt_name (str): The name of the prompt.
prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
user_name (Optional[str], optional): The username of the prompt. Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
model (Optional[str], optional): The model name of the prompt. Defaults to None.
prompt_language (Optional[str], optional): The language of the prompt.
Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt.
Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt.
Defaults to None.
user_name (Optional[str], optional): The username of the prompt.
Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt.
Defaults to None.
model (Optional[str], optional): The model name of the prompt.
Defaults to None.
kwargs (Dict): Other params to build the storage prompt template.
"""
input_variables = prompt_template.input_variables or kwargs.get(
@ -365,6 +406,7 @@ class StoragePromptTemplate(StorageItem):
@property
def identifier(self) -> PromptTemplateIdentifier:
"""Return the identifier of the storage prompt template."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
@ -375,11 +417,17 @@ class StoragePromptTemplate(StorageItem):
"""
if not isinstance(other, StoragePromptTemplate):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
f"Cannot merge {type(other)} into {type(self)} because they are not "
f"the same type."
)
self.from_object(other)
def to_dict(self) -> Dict:
"""Convert the storage prompt template to a dict.
Returns:
Dict: The dict of the storage prompt template.
"""
return {
"prompt_name": self.prompt_name,
"content": self.content,
@ -422,7 +470,6 @@ class PromptManager:
Simple wrapper for the storage interface.
Examples:
.. code-block:: python
# Default use InMemoryStorage
@ -458,13 +505,14 @@ class PromptManager:
def __init__(
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
):
"""Create a new prompt manager."""
if storage is None:
storage = InMemoryStorage()
self._storage = storage
@property
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
"""The storage interface for prompt templates."""
"""Return the storage interface for prompt templates."""
return self._storage
def prefer_query(
@ -477,11 +525,12 @@ class PromptManager:
) -> List[StoragePromptTemplate]:
"""Query prompt templates from storage with prefer params.
Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model).
This method will query prompt templates with prefer params first, if not found, will query all prompt templates.
Sometimes, we want to query prompt templates with prefer params (e.g. some
language or some model).
This method will query prompt templates with prefer params first; if not found,
it will query all prompt templates.
Examples:
Query a prompt template.
.. code-block:: python
@ -500,7 +549,8 @@ class PromptManager:
.. code-block:: python
# First query with prompt name "hello" exactly.
# Second filter with prompt language "zh-cn", if not found, will return all prompt templates.
# Second filter with prompt language "zh-cn", if not found, will return
# all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_prompt_language="zh-cn"
)
@ -510,17 +560,22 @@ class PromptManager:
.. code-block:: python
# First query with prompt name "hello" exactly.
# Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates.
# Second filter with model "vicuna-13b-v1.5", if not found, will return
# all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_model="vicuna-13b-v1.5"
)
Args:
prompt_name (str): The name of the prompt template.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
kwargs (Dict): Other query params(If some key and value not None, wo we query it exactly).
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the
prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template.
Defaults to None.
kwargs (Dict): Other query params (if some key and value are not None, we
query it exactly).
"""
query_spec = QuerySpec(
conditions={
@ -559,7 +614,6 @@ class PromptManager:
"""Save a prompt template to storage.
Examples:
.. code-block:: python
prompt_template = PromptTemplate(
@ -618,15 +672,17 @@ class PromptManager:
if exist_prompt_template:
return exist_prompt_template
self.save(prompt_template, prompt_name, **kwargs)
return self.storage.load(
prompt = self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
if not prompt:
raise ValueError("Can't read prompt from storage")
return prompt
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.
Examples:
List all prompt templates.
.. code-block:: python
@ -656,7 +712,6 @@ class PromptManager:
"""Delete a prompt template from storage.
Examples:
Delete a prompt template.
.. code-block:: python
@ -673,9 +728,12 @@ class PromptManager:
Args:
prompt_name (str): The name of the prompt template.
prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
model (Optional[str], optional): The model of the prompt template. Defaults to None.
prompt_language (Optional[str], optional): The language of the prompt
template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
model (Optional[str], optional): The model of the prompt template.
Defaults to None.
"""
identifier = PromptTemplateIdentifier(
prompt_name=prompt_name,
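A hedged end-to-end sketch of the PromptManager API refactored above, using the default InMemoryStorage backend; the save kwargs are assumed to flow through to StoragePromptTemplate.from_prompt_template:

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptManager, PromptTemplate

    pm = PromptManager()  # defaults to InMemoryStorage
    tmpl = PromptTemplate.from_template("Hello {name}!")
    pm.save(tmpl, prompt_name="hello", prompt_language="en")

    # Exact-name query first, then prefer-filter by language.
    found = pm.prefer_query("hello", prefer_prompt_language="en")
    assert found and found[0].prompt_name == "hello"

    pm.delete("hello", prompt_language="en")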

View File

@ -1,11 +1,15 @@
"""The interface for serializing."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, Type
from typing import Dict, Optional, Type
class Serializable(ABC):
serializer: "Serializer" = None
"""The serializable abstract class."""
serializer: Optional["Serializer"] = None
@abstractmethod
def to_dict(self) -> Dict:

View File

@ -1,5 +1,7 @@
"""The storage interface for storing and loading data."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
@ -55,17 +57,22 @@ class StorageItem(Serializable, ABC):
"""
ID = TypeVar("ID", bound=ResourceIdentifier)
T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")
class StorageItemAdapter(Generic[T, TDataRepresentation]):
"""The storage item adapter for converting storage items to and from the storage format.
"""Storage item adapter.
The storage item adapter for converting storage items to and from the storage
format.
Sometimes, the storage item is not the same as the storage format,
so we need to convert the storage item to the storage format and vice versa.
In database storage, the storage format is database model, but the StorageItem is the user-defined object.
In database storage, the storage format is database model, but the StorageItem is
the user-defined object.
"""
@abstractmethod
@ -110,20 +117,44 @@ class StorageItemAdapter(Generic[T, TDataRepresentation]):
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
"""The default storage item adapter for converting storage items to and from the storage format.
"""Default storage item adapter.
The default storage item adapter for converting storage items to and from the
storage format.
The storage item is the same as the storage format, so no conversion is required.
"""
def to_storage_format(self, item: T) -> T:
"""Convert the storage item to the storage format.
Returns the storage item itself.
Args:
item (T): The storage item
Returns:
T: The data in the storage format
"""
return item
def from_storage_format(self, data: T) -> T:
"""Convert the storage format to the storage item.
Returns the storage format itself.
Args:
data (T): The data in the storage format
Returns:
T: The storage item
"""
return data
def get_query_for_identifier(
self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
self, storage_format: Type[T], resource_id: ID, **kwargs
) -> bool:
"""Return the query for the resource identifier."""
return True
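When the storage format really differs from the item (the ORM-model case mentioned above), a custom adapter does the mapping. A hypothetical dict-backed sketch; the item-from-dict constructor call is an assumption for illustration:

.. code-block:: python

    from dbgpt.core.interface.prompt import StoragePromptTemplate
    from dbgpt.core.interface.storage import StorageItemAdapter

    class DictAdapter(StorageItemAdapter[StoragePromptTemplate, dict]):
        """Hypothetical adapter: items are stored as plain dicts."""

        def to_storage_format(self, item: StoragePromptTemplate) -> dict:
            return item.to_dict()

        def from_storage_format(self, data: dict) -> StoragePromptTemplate:
            # Assumes to_dict() keys line up with the dataclass init fields.
            return StoragePromptTemplate(**data)

        def get_query_for_identifier(self, storage_format, resource_id, **kwargs):
            # Dict storage has no query language; match on the string id.
            return resource_id.str_identifier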
@ -132,6 +163,7 @@ class StorageError(Exception):
"""The base exception class for storage errors."""
def __init__(self, message: str):
"""Create a new StorageError."""
super().__init__(message)
@ -146,8 +178,9 @@ class QuerySpec:
"""
def __init__(
self, conditions: Dict[str, Any], limit: int = None, offset: int = 0
self, conditions: Dict[str, Any], limit: Optional[int] = None, offset: int = 0
) -> None:
"""Create a new QuerySpec."""
self.conditions = conditions
self.limit = limit
self.offset = offset
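QuerySpec is a plain value object: equality conditions plus optional paging, consumed by query, count, and paginate_query below. For example:

.. code-block:: python

    from dbgpt.core.interface.storage import QuerySpec

    # Matches items whose attributes equal these values; limit/offset
    # are applied after filtering.
    spec = QuerySpec(
        conditions={"prompt_language": "en", "sys_code": "dbgpt"},
        limit=10,
    )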
@ -162,6 +195,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
serializer: Optional[Serializer] = None,
adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
):
"""Create a new StorageInterface."""
self._serializer = serializer or JsonSerializer()
self._storage_item_adapter = adapter or DefaultStorageItemAdapter()
@ -238,7 +272,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
self.save_or_update(d)
@abstractmethod
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
@ -247,14 +281,14 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
so we suggest using load if possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
resource_id (ID): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
Optional[T]: The loaded data
"""
def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
def load_list(self, resource_id: List[ID], cls: Type[T]) -> List[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
@ -263,7 +297,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
so we suggest using load if possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
resource_id (ID): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
@ -277,18 +311,18 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
return result
@abstractmethod
def delete(self, resource_id: ResourceIdentifier) -> None:
def delete(self, resource_id: ID) -> None:
"""Delete the data from the storage.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
resource_id (ID): The resource identifier of the data
"""
def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
def delete_list(self, resource_id: List[ID]) -> None:
"""Delete the data from the storage.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
resource_id (ID): The resource identifier of the data
"""
for r in resource_id:
self.delete(r)
@ -297,7 +331,8 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Query data with resource_id will be faster than query data with conditions, so please use load if possible.
Querying data with resource_id will be faster than querying with conditions,
so please use load if possible.
Args:
spec (QuerySpec): The query specification
@ -328,7 +363,8 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
page (int): The page number
page_size (int): The number of items per page
cls (Type[T]): The type of the data
spec (Optional[QuerySpec], optional): The query specification. Defaults to None.
spec (Optional[QuerySpec], optional): The query specification.
Defaults to None.
Returns:
PaginationResult[T]: The pagination result
@ -356,10 +392,17 @@ class InMemoryStorage(StorageInterface[T, T]):
self,
serializer: Optional[Serializer] = None,
):
"""Create a new InMemoryStorage."""
super().__init__(serializer)
self._data = {} # Key: ResourceIdentifier, Value: Serialized data
# Key: ResourceIdentifier, Value: Serialized data
self._data: Dict[str, bytes] = {}
def save(self, data: T) -> None:
"""Save the data to the storage.
Args:
data (T): The data to save
"""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
@ -372,6 +415,7 @@ class InMemoryStorage(StorageInterface[T, T]):
self._data[data.identifier.str_identifier] = data.serialize()
def update(self, data: T) -> None:
"""Update the data to the storage."""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
@ -379,22 +423,34 @@ class InMemoryStorage(StorageInterface[T, T]):
self._data[data.identifier.str_identifier] = data.serialize()
def save_or_update(self, data: T) -> None:
"""Save or update the data to the storage."""
self.update(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]:
"""Load the data from the storage."""
serialized_data = self._data.get(resource_id.str_identifier)
if serialized_data is None:
return None
return self.serializer.deserialize(serialized_data, cls)
return cast(T, self.serializer.deserialize(serialized_data, cls))
def delete(self, resource_id: ResourceIdentifier) -> None:
def delete(self, resource_id: ID) -> None:
"""Delete the data from the storage."""
if resource_id.str_identifier in self._data:
del self._data[resource_id.str_identifier]
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
List[T]: The queried data
"""
result = []
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
data = cast(T, self._serializer.deserialize(serialized_data, cls))
if all(
getattr(data, key) == value for key, value in spec.conditions.items()
):
@ -408,6 +464,15 @@ class InMemoryStorage(StorageInterface[T, T]):
return result
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
"""Count the number of data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
int: The number of data
"""
count = 0
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
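A hedged round-trip sketch of InMemoryStorage with the concrete StoragePromptTemplate item from this changeset; save is assumed to attach the storage's serializer when the item has none:

.. code-block:: python

    from dbgpt.core.interface.prompt import StoragePromptTemplate
    from dbgpt.core.interface.storage import InMemoryStorage, QuerySpec

    storage = InMemoryStorage()
    item = StoragePromptTemplate(prompt_name="hello", content="Hello {name}!")
    storage.save(item)

    loaded = storage.load(item.identifier, StoragePromptTemplate)
    assert loaded is not None and loaded.content == "Hello {name}!"

    # query() matches conditions attribute-by-attribute.
    results = storage.query(
        QuerySpec(conditions={"prompt_name": "hello"}), StoragePromptTemplate
    )
    assert len(results) == 1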

View File

@ -1,22 +1,24 @@
from dbgpt.core.interface.operator.composer_operator import (
"""All core operators."""
from dbgpt.core.interface.operator.composer_operator import ( # noqa: F401
ChatComposerInput,
ChatHistoryPromptComposerOperator,
)
from dbgpt.core.interface.operator.llm_operator import (
from dbgpt.core.interface.operator.llm_operator import ( # noqa: F401
BaseLLM,
BaseLLMOperator,
BaseStreamingLLMOperator,
LLMBranchOperator,
RequestBuilderOperator,
)
from dbgpt.core.interface.operator.message_operator import (
from dbgpt.core.interface.operator.message_operator import ( # noqa: F401
BaseConversationOperator,
BufferedConversationMapperOperator,
ConversationMapperOperator,
PreChatHistoryLoadOperator,
TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operator.prompt_operator import (
from dbgpt.core.interface.operator.prompt_operator import ( # noqa: F401
DynamicPromptBuilderOperator,
HistoryDynamicPromptBuilderOperator,
HistoryPromptBuilderOperator,

View File

@ -8,6 +8,9 @@ import threading
from functools import cache
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
try:
from fastchat.conversation import (
Conversation,
@ -20,8 +23,6 @@ except ImportError as exc:
"Please install fastchat by command `pip install fschat` "
) from exc
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter

View File

@ -196,7 +196,14 @@ class DefaultModelWorker(ModelWorker):
return _try_to_count_token(prompt, self.tokenizer, self.model)
async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
# TODO: if the model is deployed with vllm this can't work; we should run
# the transformers-based _try_to_count_token asynchronously
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
return await self.model.proxy_llm_client.count_token(
self.model.proxy_llm_client.default_model, prompt
)
raise NotImplementedError
def get_model_metadata(self, params: Dict) -> ModelMetadata:

View File

@ -118,9 +118,6 @@ def replace_llama_attn_with_non_inplace_operations():
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
import transformers
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@ -196,6 +196,15 @@ class ProxyLLMClient(LLMClient):
"""
return self._models()
@property
def default_model(self) -> str:
"""Get default model name
Returns:
str: default model name
"""
return self.model_names[0]
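This property replaces the self.default_model = self._model instance assignments removed from the Gemini, Spark, Tongyi, Wenxin, and Zhipu clients below: the base class keeps the model_names[0] fallback, and each subclass overrides the property with its configured model. The pattern in miniature (names hypothetical):

.. code-block:: python

    class BaseClient:
        def __init__(self, model_names):
            self.model_names = model_names

        @property
        def default_model(self) -> str:
            # Base fallback: the first registered model name.
            return self.model_names[0]

    class VendorClient(BaseClient):
        def __init__(self, model: str):
            super().__init__(model_names=[model])
            self._model = model

        @property
        def default_model(self) -> str:
            # Subclass override: the explicitly configured model.
            return self._model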
@cache
def _models(self) -> List[ModelMetadata]:
results = []
@ -237,6 +246,7 @@ class ProxyLLMClient(LLMClient):
Returns:
int: token count, -1 if failed
"""
return await blocking_func_to_async(
counts = await blocking_func_to_async(
self.executor, self.proxy_tokenizer.count_token, model, [prompt]
)[0]
)
return counts[0]
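The old one-liner was a real bug: subscription binds tighter than await, so `await blocking_func_to_async(...)[0]` tried to index the coroutine object itself and would raise TypeError at runtime. The fix awaits first, then indexes the returned list. A tiny illustration:

.. code-block:: python

    import asyncio

    async def get_counts():
        return [42]

    async def main():
        # Wrong: this parses as await (get_counts()[0]) and indexes the
        # coroutine object, raising TypeError.
        # count = await get_counts()[0]

        # Right: await first, then index the returned list.
        count = (await get_counts())[0]
        print(count)  # 42

    asyncio.run(main())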

View File

@ -86,6 +86,11 @@ class OpenAILLMClient(ProxyLLMClient):
self._openai_kwargs = openai_kwargs or {}
super().__init__(model_names=[model_alias], context_length=context_length)
if self._openai_less_then_v1:
from dbgpt.model.utils.chatgpt_utils import _initialize_openai
_initialize_openai(self._init_params)
@classmethod
def new_client(
cls,

View File

@ -114,7 +114,6 @@ class GeminiLLMClient(ProxyLLMClient):
self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY")
self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE")
self._model = model
self.default_model = self._model
if not self._api_key:
raise RuntimeError("api_key can't be empty")
@ -148,6 +147,10 @@ class GeminiLLMClient(ProxyLLMClient):
executor=default_executor,
)
@property
def default_model(self) -> str:
return self._model
def sync_generate_stream(
self,
request: ModelRequest,

View File

@ -8,9 +8,6 @@ from datetime import datetime
from time import mktime
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
from websockets.sync.client import connect
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
@ -56,6 +53,8 @@ def spark_generate_stream(
def get_response(request_url, data):
from websockets.sync.client import connect
with connect(request_url) as ws:
ws.send(json.dumps(data, ensure_ascii=False))
result = ""
@ -87,6 +86,8 @@ class SparkAPI:
self.spark_url = spark_url
def gen_url(self):
from wsgiref.handlers import format_date_time
# Generate an RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
@ -145,7 +146,6 @@ class SparkLLMClient(ProxyLLMClient):
if not api_domain:
api_domain = domain
self._model = model
self.default_model = self._model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
@ -183,6 +183,10 @@ class SparkLLMClient(ProxyLLMClient):
executor=default_executor,
)
@property
def default_model(self) -> str:
return self._model
def sync_generate_stream(
self,
request: ModelRequest,

View File

@ -51,7 +51,6 @@ class TongyiLLMClient(ProxyLLMClient):
if api_region:
dashscope.api_region = api_region
self._model = model
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
@ -73,6 +72,10 @@ class TongyiLLMClient(ProxyLLMClient):
executor=default_executor,
)
@property
def default_model(self) -> str:
return self._model
def sync_generate_stream(
self,
request: ModelRequest,

View File

@ -121,7 +121,6 @@ class WenxinLLMClient(ProxyLLMClient):
self._api_key = api_key
self._api_secret = api_secret
self._model_version = model_version
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
@ -145,6 +144,10 @@ class WenxinLLMClient(ProxyLLMClient):
executor=default_executor,
)
@property
def default_model(self) -> str:
return self._model
def sync_generate_stream(
self,
request: ModelRequest,

View File

@ -54,7 +54,6 @@ class ZhipuLLMClient(ProxyLLMClient):
if api_key:
zhipuai.api_key = api_key
self._model = model
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
@ -76,6 +75,10 @@ class ZhipuLLMClient(ProxyLLMClient):
executor=default_executor,
)
@property
def default_model(self) -> str:
return self._model
def sync_generate_stream(
self,
request: ModelRequest,

View File

@ -88,6 +88,42 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
return openai_params, api_type, api_version
def _initialize_openai(params: OpenAIParameters):
try:
import openai
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai` "
) from exc
api_type = params.api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
api_base = params.api_base or os.getenv(
"OPENAI_API_TYPE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = params.api_version or os.getenv("OPENAI_API_VERSION")
if not api_base and params.full_url:
# Adapt previous proxy_server_url configuration
api_base = params.full_url.split("/chat/completions")[0]
if api_type:
openai.api_type = api_type
if api_base:
openai.api_base = api_base
if api_key:
openai.api_key = api_key
if api_version:
openai.api_version = api_version
if params.proxies:
openai.proxy = params.proxies
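For openai<1.0 there is no client object; configuration lives in module-level attributes, which is what this helper writes. A hedged sketch of what a caller does afterwards (pre-1.0 API only; needs a valid key and network access):

.. code-block:: python

    import openai  # openai<1.0

    openai.api_key = "sk-..."  # normally set from params/env by the helper
    resp = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp["choices"][0]["message"]["content"])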
def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
import httpx
@ -112,9 +148,7 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType
class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
"""Transform ModelOutput to openai stream format."""
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
async def model_caller() -> str:
"""Read model name from share data.
In streaming mode, this transform_stream function will be executed

View File

@ -1,6 +1,6 @@
from typing import Any
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

View File

@ -1,7 +1,7 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector

View File

@ -2,7 +2,7 @@ from functools import reduce
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite

View File

@ -1,3 +1,4 @@
import asyncio
import json
import logging
import uuid
@ -30,7 +31,6 @@ from .dbgpts import DbGptsInstance
CFG = Config()
import asyncio
router = APIRouter()
logger = logging.getLogger(__name__)

View File

@ -6,13 +6,13 @@ from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from .base import MemoryStoreType
# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory
# TODO remove global variable
CFG = Config()
logger = logging.getLogger(__name__)
# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory
class ChatHistory:
def __init__(self):

View File

@ -5,6 +5,8 @@ from sqlalchemy.orm.session import Session
from dbgpt.util.pagination_utils import PaginationResult
from .db_manager import BaseQuery, DatabaseManager, db
# The entity type
T = TypeVar("T")
# The request schema type
@ -12,7 +14,6 @@ REQ = TypeVar("REQ")
# The response schema type
RES = TypeVar("RES")
from .db_manager import BaseQuery, DatabaseManager, db
QUERY_SPEC = Union[REQ, Dict[str, Any]]

View File

@ -1,3 +1,6 @@
from typing import Optional
def PublicAPI(*args, **kwargs):
"""Decorator to mark a function or class as a public API.
@ -64,7 +67,7 @@ def DeveloperAPI(*args, **kwargs):
return decorator
def _modify_docstring(obj, message: str = None):
def _modify_docstring(obj, message: Optional[str] = None):
if not message:
return
if not obj.__doc__:
@ -81,6 +84,7 @@ def _modify_docstring(obj, message: str = None):
if min_indent == float("inf"):
min_indent = 0
min_indent = int(min_indent)
indented_message = message.rstrip() + "\n" + (" " * min_indent)
obj.__doc__ = indented_message + original_doc

View File

@ -1,6 +1,6 @@
import os
from functools import cache
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast
class AppConfig:
@ -46,7 +46,7 @@ class AppConfig:
"""
env_lang = (
"zh"
if os.getenv("LANG") and os.getenv("LANG").startswith("zh")
if os.getenv("LANG") and cast(str, os.getenv("LANG")).startswith("zh")
else default
)
return self.get("dbgpt.app.global.language", env_lang)

View File

@ -1,7 +1,7 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
from typing import Any, List, Mapping, Sequence, Set, Union
class StrictFormatter(Formatter):
@ -9,7 +9,7 @@ class StrictFormatter(Formatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
used_args: Set[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
@ -39,7 +39,7 @@ class StrictFormatter(Formatter):
class NoStrictFormatter(StrictFormatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
used_args: Set[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
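check_unused_args receives the set of field names the template actually consumed, so a strict formatter can reject surplus kwargs; this is what template_is_strict toggles in PromptTemplate. A minimal re-implementation sketch of that behavior (not the library's exact code):

.. code-block:: python

    from string import Formatter
    from typing import Any, Mapping, Sequence, Set, Union

    class MyStrictFormatter(Formatter):
        def check_unused_args(
            self,
            used_args: Set[Union[int, str]],
            args: Sequence,
            kwargs: Mapping[str, Any],
        ) -> None:
            # Reject any kwarg the template never referenced.
            extra = set(kwargs) - set(used_args)
            if extra:
                raise KeyError(f"Unused template arguments: {extra}")

    fmt = MyStrictFormatter()
    fmt.format("Hello {name}", name="DB-GPT")        # fine
    # fmt.format("Hello {name}", name="x", extra=1)  # raises KeyError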

View File

@ -12,14 +12,14 @@ MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
@dataclass
class ParameterDescription:
param_class: str
param_name: str
param_type: str
default_value: Optional[Any]
description: str
required: Optional[bool]
valid_values: Optional[List[Any]]
ext_metadata: Dict
required: bool = False
param_class: Optional[str] = None
param_name: Optional[str] = None
param_type: Optional[str] = None
description: Optional[str] = None
default_value: Optional[Any] = None
valid_values: Optional[List[Any]] = None
ext_metadata: Optional[Dict[str, Any]] = None
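The reordering here is forced by dataclass rules, not style: once required: bool = False leads, every later field needs a default, hence the blanket Optional[...] = None pattern. For instance:

.. code-block:: python

    from dataclasses import dataclass
    from typing import Optional

    # This would raise TypeError: non-default argument follows default argument.
    # @dataclass
    # class Bad:
    #     required: bool = False
    #     param_name: str

    @dataclass
    class Good:
        required: bool = False
        param_name: Optional[str] = None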
@dataclass
@ -186,7 +186,9 @@ def _get_simple_privacy_field_value(obj, field_info):
return "******"
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
def _genenv_ignoring_key_case(
env_key: str, env_prefix: Optional[str] = None, default_value: Optional[str] = None
):
"""Get the value from the environment variable, ignoring the case of the key"""
if env_prefix:
env_key = env_prefix + env_key
@ -196,7 +198,9 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
def _genenv_ignoring_key_case_with_prefixes(
env_key: str, env_prefixes: List[str] = None, default_value=None
env_key: str,
env_prefixes: Optional[List[str]] = None,
default_value: Optional[str] = None,
) -> str:
if env_prefixes:
for env_prefix in env_prefixes:
@ -208,7 +212,7 @@ def _genenv_ignoring_key_case_with_prefixes(
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
def get_env_prefix(env_key: str) -> Optional[str]:
if not env_key:
return None
env_key = env_key.replace("-", "_")
@ -217,14 +221,14 @@ class EnvArgumentParser:
def parse_args_into_dataclass(
self,
dataclass_type: Type,
env_prefixes: List[str] = None,
command_args: List[str] = None,
env_prefixes: Optional[List[str]] = None,
command_args: Optional[List[str]] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case_with_prefixes(
env_var_value: Any = _genenv_ignoring_key_case_with_prefixes(
field.name, env_prefixes
)
if env_var_value:
@ -313,7 +317,8 @@ class EnvArgumentParser:
@staticmethod
def create_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
*dataclass_types: Type,
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
):
import functools
from collections import OrderedDict
@ -322,8 +327,9 @@ class EnvArgumentParser:
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
dataclass_types = list(_types) # type: ignore
for dataclass_type in dataclass_types:
# type: ignore
for field in fields(dataclass_type):
if field.name not in combined_fields:
combined_fields[field.name] = field
@ -345,7 +351,8 @@ class EnvArgumentParser:
@staticmethod
def _create_raw_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
*dataclass_types: Type,
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
):
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
@ -362,7 +369,8 @@ class EnvArgumentParser:
@staticmethod
def create_argparse_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
*dataclass_types: Type,
_dynamic_factory: Optional[Callable[[], List[Type]]] = None,
) -> argparse.ArgumentParser:
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
@ -429,7 +437,7 @@ class EnvArgumentParser:
return "str"
@staticmethod
def _is_require_type(field_type: Type) -> str:
def _is_require_type(field_type: Type) -> bool:
return field_type not in [Optional[int], Optional[float], Optional[bool]]
@staticmethod
@ -455,13 +463,13 @@ class EnvArgumentParser:
def _merge_dataclass_types(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
*dataclass_types: Type, _dynamic_factory: Optional[Callable[[], List[Type]]] = None
) -> OrderedDict:
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
dataclass_types = list(_types) # type: ignore
for dataclass_type in dataclass_types:
for field in fields(dataclass_type):
if field.name not in combined_fields:
@ -511,11 +519,12 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
if not desc:
raise ValueError("Parameter descriptions cant be empty")
param_class_str = desc[0].param_class
class_name = None
if param_class_str:
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
module_name, _, class_name = param_class_str.rpartition(".")
module_name, _, class_name = param_class_str.rpartition(".")
fields_dict = {} # This will store field names and their default values or field()
annotations = {} # This will store the type annotations for the fields
@ -526,25 +535,30 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
metadata["valid_values"] = d.valid_values
annotations[d.param_name] = _type_str_to_python_type(
d.param_type
d.param_type # type: ignore
) # Set type annotation
fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
# Create the new class. Note the setting of __annotations__ for type hints
new_class = type(
class_name, (object,), {**fields_dict, "__annotations__": annotations}
class_name, # type: ignore
(object,),
{**fields_dict, "__annotations__": annotations}, # type: ignore
)
result_class = dataclass(new_class) # Make it a dataclass
# Make it a dataclass
result_class = dataclass(new_class) # type: ignore
return result_class
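The function above rebuilds a parameter dataclass at runtime from its descriptions via type() plus the dataclass decorator. The same trick in a self-contained sketch (names hypothetical):

.. code-block:: python

    from dataclasses import dataclass, field, fields

    annotations = {"host": str, "port": int}
    namespace = {
        "host": field(default="localhost"),
        "port": field(default=8000),
        "__annotations__": annotations,
    }
    # type() builds the class; dataclass() turns the annotated attributes
    # and field() defaults into a proper dataclass.
    DynamicConfig = dataclass(type("DynamicConfig", (object,), namespace))

    cfg = DynamicConfig()
    print(cfg.host, cfg.port)             # localhost 8000
    print([f.name for f in fields(cfg)])  # ['host', 'port']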
def _extract_parameter_details(
parser: argparse.ArgumentParser,
param_class: str = None,
skip_names: List[str] = None,
overwrite_default_values: Dict = {},
param_class: Optional[str] = None,
skip_names: Optional[List[str]] = None,
overwrite_default_values: Optional[Dict[str, Any]] = None,
) -> List[ParameterDescription]:
if overwrite_default_values is None:
overwrite_default_values = {}
descriptions = []
for action in parser._actions:
@ -575,7 +589,9 @@ def _extract_parameter_details(
if param_name in overwrite_default_values:
default_value = overwrite_default_values[param_name]
arg_type = (
action.type if not callable(action.type) else str(action.type.__name__)
action.type
if not callable(action.type)
else str(action.type.__name__) # type: ignore
)
description = action.help
@ -583,10 +599,10 @@ def _extract_parameter_details(
required = action.required
# extract valid values for choices, if provided
valid_values = action.choices if action.choices is not None else None
valid_values = list(action.choices) if action.choices is not None else None
# set ext_metadata as an empty dict for now, can be updated later if needed
ext_metadata = {}
ext_metadata: Dict[str, Any] = {}
descriptions.append(
ParameterDescription(
@ -621,7 +637,7 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
from dbgpt._private import pydantic
version = int(pydantic.VERSION.split(".")[0])
version = int(pydantic.VERSION.split(".")[0]) # type: ignore
schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
required_fields = set(schema.get("required", []))
param_descs = []
@ -661,7 +677,7 @@ def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescri
ext_metadata = (
field.field_info.extra if hasattr(field.field_info, "extra") else None
)
param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
param_class = f"{model_cls.__module__}.{model_cls.__name__}"
param_desc = ParameterDescription(
param_class=param_class,
param_name=field_name,

View File

@ -5,7 +5,7 @@ import asyncio
import logging
import logging.handlers
import os
from typing import Any, List
from typing import Any, List, Optional, cast
from dbgpt.configs.model_config import LOGDIR
@ -28,19 +28,25 @@ def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging_level(logging_level=None, logger_name: str = None):
def setup_logging_level(
logging_level: Optional[str] = None, logger_name: Optional[str] = None
):
if not logging_level:
logging_level = _get_logging_level()
if type(logging_level) is str:
logging_level = logging.getLevelName(logging_level.upper())
if logger_name:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
logger.setLevel(cast(str, logging_level))
else:
logging.basicConfig(level=logging_level, encoding="utf-8")
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
def setup_logging(
logger_name: str,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
):
if not logging_level:
logging_level = _get_logging_level()
logger = _build_logger(logger_name, logging_level, logger_filename)
@ -74,7 +80,11 @@ def get_gpu_memory(max_gpus=None):
return gpu_memory
def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
def _build_logger(
logger_name,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
):
global handler
formatter = logging.Formatter(
@ -111,14 +121,14 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
try:
loop = asyncio.get_event_loop()
assert loop is not None
return loop
return cast(asyncio.BaseEventLoop, loop)
except RuntimeError as e:
if not "no running event loop" in str(e) and not "no current event loop" in str(
e
):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
return asyncio.get_event_loop_policy().new_event_loop()
return cast(asyncio.BaseEventLoop, asyncio.get_event_loop_policy().new_event_loop())
def logging_str_to_uvicorn_level(log_level_str):
@ -152,7 +162,7 @@ class EndpointFilter(logging.Filter):
return record.getMessage().find(self._path) == -1
def setup_http_service_logging(exclude_paths: List[str] = None):
def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
"""Setup http service logging
Now just disable some logs

View File

@ -53,7 +53,7 @@ class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
return params
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
with DAG("dbgpt_awel_simple_rag_summary_example") as dag:
trigger = HttpTrigger(
"/examples/rag/summary", methods="POST", request_body=TriggerReqBody
)

View File

@ -13,4 +13,4 @@ aioresponses
# for git hooks
pre-commit
# Type checking
mypy==0.991
mypy==1.7.0