feat(core): Multiple ways to run dbgpts (#1734)

This commit is contained in:
Fangyin Cheng
2024-07-18 17:50:40 +08:00
committed by GitHub
parent d389fddc2f
commit f889fa3775
19 changed files with 1410 additions and 304 deletions

View File

@@ -2,6 +2,7 @@
DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import logging
@@ -613,10 +614,14 @@ class DAG:
"""
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
self,
dag_id: str,
resource_group: Optional[ResourceGroup] = None,
tags: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self._tags: Dict[str, str] = tags or {}
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = []
@@ -651,6 +656,22 @@ class DAG:
"""Return the dag id of current DAG."""
return self._dag_id
@property
def tags(self) -> Dict[str, str]:
    """Expose the tag mapping attached to the current DAG.

    Returns:
        Dict[str, str]: The ``_tags`` dictionary stored on this DAG. The stored
            object itself is returned, not a copy.
    """
    dag_tags = self._tags
    return dag_tags
@property
def dev_mode(self) -> bool:
    """Report whether the current DAG is in dev mode.

    The decision is delegated to ``_dev_mode`` from the operators base module,
    which is imported inside the property body rather than at module level.

    Returns:
        bool: Whether the current DAG is in dev mode
    """
    from ..operators.base import _dev_mode as _is_dev_mode

    return _is_dev_mode()
def _build(self) -> None:
from ..operators.common_operator import TriggerOperator

View File

@@ -3,18 +3,49 @@
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
to TriggerManager.
"""
import logging
import threading
from typing import Dict, List, Optional
from collections import defaultdict
from typing import Dict, List, Optional, Set
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .. import BaseOperator
from ..trigger.base import TriggerMetadata
from .base import DAG
from .loader import LocalFileDAGLoader
logger = logging.getLogger(__name__)
class DAGMetadata(BaseModel):
    """Metadata for the DAG."""

    # Metadata entries for the DAG's triggers; starts as an empty list.
    triggers: List[TriggerMetadata] = Field(
        description="The trigger metadata", default_factory=list
    )
    # Server-sent-event output flag; False unless set explicitly.
    sse_output: bool = Field(
        description="Whether the DAG is a server-sent event output", default=False
    )
    # Streaming output flag; False unless set explicitly.
    streaming_output: bool = Field(
        description="Whether the DAG is a streaming output", default=False
    )
    # Optional key/value tags of the DAG; None when no tags were provided.
    tags: Optional[Dict[str, str]] = Field(
        description="The tags of the DAG", default=None
    )

    def to_dict(self):
        """Convert the metadata to dict.

        The ``triggers`` field is excluded from the model-level conversion and
        each trigger is serialized through its own ``dict()`` method; the
        resulting list is stored under the ``"triggers"`` key.

        Returns:
            dict: The metadata fields plus the serialized triggers.
        """
        serialized_triggers = [trigger.dict() for trigger in self.triggers]
        result = model_to_dict(self, exclude={"triggers"})
        result["triggers"] = serialized_triggers
        return result
class DAGManager(BaseComponent):
"""The component of DAGManager."""
@@ -35,6 +66,8 @@ class DAGManager(BaseComponent):
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
self.dag_alias_map: Dict[str, str] = {}
self._dag_metadata_map: Dict[str, DAGMetadata] = {}
self._tags_to_dag_ids: Dict[str, Dict[str, Set[str]]] = {}
self._trigger_manager: Optional["DefaultTriggerManager"] = None
def init_app(self, system_app: SystemApp):
@@ -73,12 +106,26 @@ class DAGManager(BaseComponent):
if alias_name:
self.dag_alias_map[alias_name] = dag_id
trigger_metadata: List["TriggerMetadata"] = []
dag_metadata = _parse_metadata(dag)
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.register_trigger(trigger, self.system_app)
tm = self._trigger_manager.register_trigger(
trigger, self.system_app
)
if tm:
trigger_metadata.append(tm)
self._trigger_manager.after_register()
else:
logger.warning("No trigger manager, not register dag trigger")
dag_metadata.triggers = trigger_metadata
self._dag_metadata_map[dag_id] = dag_metadata
tags = dag_metadata.tags
if tags:
for tag_key, tag_value in tags.items():
if tag_key not in self._tags_to_dag_ids:
self._tags_to_dag_ids[tag_key] = defaultdict(set)
self._tags_to_dag_ids[tag_key][tag_value].add(dag_id)
def unregister_dag(self, dag_id: str):
"""Unregister a DAG."""
@@ -104,7 +151,13 @@ class DAGManager(BaseComponent):
for trigger in dag.trigger_nodes:
self._trigger_manager.unregister_trigger(trigger, self.system_app)
# Finally remove the DAG from the map
metadata = self._dag_metadata_map[dag_id]
del self.dag_map[dag_id]
del self._dag_metadata_map[dag_id]
if metadata.tags:
for tag_key, tag_value in metadata.tags.items():
if tag_key in self._tags_to_dag_ids:
self._tags_to_dag_ids[tag_key][tag_value].remove(dag_id)
def get_dag(
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
@@ -116,3 +169,33 @@ class DAGManager(BaseComponent):
if alias_name in self.dag_alias_map:
return self.dag_map.get(self.dag_alias_map[alias_name])
return None
def get_dags_by_tag(self, tag_key: str, tag_value) -> List[DAG]:
    """Get all DAGs with the given tag.

    Args:
        tag_key (str): The tag key to look up.
        tag_value: The tag value that must be registered under ``tag_key``.

    Returns:
        List[DAG]: The DAGs from ``dag_map`` whose ids are indexed under the
            given key/value pair; empty when the key or the value is unknown.
    """
    with self.lock:
        value_index = self._tags_to_dag_ids.get(tag_key, {})
        matched: List[DAG] = []
        for dag_id in value_index.get(tag_value, set()):
            matched.append(self.dag_map[dag_id])
        return matched
def get_dag_metadata(
    self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
) -> Optional[DAGMetadata]:
    """Get a DAGMetadata by dag_id or alias_name.

    The DAG is resolved through ``get_dag`` first; the metadata is then looked
    up by the resolved DAG's own ``dag_id``.

    Args:
        dag_id (Optional[str]): The id of the DAG.
        alias_name (Optional[str]): The alias name of the DAG.

    Returns:
        Optional[DAGMetadata]: The stored metadata, or None when no DAG is
            found or no metadata is stored for it.
    """
    found = self.get_dag(dag_id, alias_name)
    return self._dag_metadata_map.get(found.dag_id) if found else None
def _parse_metadata(dag: DAG):
    """Build a DAGMetadata from a DAG.

    The tags are always copied over by reference. The output flags are only
    filled in when the DAG has leaf nodes and the first leaf node is a
    ``BaseOperator``; otherwise they keep the model defaults.

    Args:
        dag (DAG): The DAG to inspect.

    Returns:
        DAGMetadata: The metadata. Its ``triggers`` field is not set here.
    """
    from ..util.chat_util import _is_sse_output

    parsed = DAGMetadata()
    parsed.tags = dag.tags
    leaves = dag.leaf_nodes
    if leaves:
        # Only the first leaf node is consulted for the output flags.
        last_node = leaves[0]
        if isinstance(last_node, BaseOperator):
            parsed.sse_output = _is_sse_output(last_node)
            parsed.streaming_output = last_node.streaming_operator
    return parsed

View File

@@ -18,6 +18,7 @@ from dbgpt._private.pydantic import (
model_validator,
)
from dbgpt.core.awel.dag.base import DAG, DAGNode
from dbgpt.core.awel.dag.dag_manager import DAGMetadata
from .base import (
OperatorType,
@@ -352,6 +353,9 @@ class FlowPanel(BaseModel):
description="The flow panel modified time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
default=None, description="The metadata of the flow"
)
@model_validator(mode="before")
@classmethod

View File

@@ -69,6 +69,15 @@ class WorkflowRunner(ABC, Generic[T]):
default_runner: Optional[WorkflowRunner] = None
def _dev_mode() -> bool:
    """Check if the operator is in dev mode.

    Dev mode is defined as the module-level ``default_runner`` being None.
    In production mode, the default runner is not None, and the operator will
    run in the same process with the DB-GPT webserver.

    Returns:
        bool: True when no default runner has been set.
    """
    runner_missing = default_runner is None
    return runner_missing
class BaseOperatorMeta(ABCMeta):
"""Metaclass of BaseOperator."""
@@ -86,7 +95,9 @@ class BaseOperatorMeta(ABCMeta):
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
ComponentType.EXECUTOR_DEFAULT,
DefaultExecutorFactory,
default_component=DefaultExecutorFactory(),
).create() # type: ignore
else:
executor = DefaultExecutorFactory().create()
@@ -173,13 +184,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
def dev_mode(self) -> bool:
"""Whether the operator is in dev mode.
In production mode, the default runner is not None.
In production mode, the default runner is not None, and the operator will run in
the same process with the DB-GPT webserver.
Returns:
bool: Whether the operator is in dev mode. True if the
default runner is None.
"""
return default_runner is None
return _dev_mode()
async def _run(self, dag_ctx: DAGContext, task_log_id: str) -> TaskOutput[OUT]:
if not self.node_id:

View File

@@ -1,13 +1,22 @@
"""Base class for all trigger classes."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic
from typing import Any, Generic, Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..operators.common_operator import TriggerOperator
from ..task.base import OUT
class TriggerMetadata(BaseModel):
    """Metadata for the trigger."""

    # Optional type name of the trigger; None unless set explicitly.
    trigger_type: Optional[str] = Field(
        description="The type of the trigger", default=None
    )
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
"""Base class for all trigger classes.

View File

@@ -43,7 +43,7 @@ from ..operators.base import BaseOperator
from ..operators.common_operator import MapOperator
from ..util._typing_util import _parse_bool
from ..util.http_util import join_paths
from .base import Trigger
from .base import Trigger, TriggerMetadata
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
@@ -82,6 +82,17 @@ def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
return _parse_bool(streaming)
class HttpTriggerMetadata(TriggerMetadata):
    """Trigger metadata."""

    # Required: the path of the trigger.
    path: str = Field(..., description="The path of the trigger")
    # Required: the methods of the trigger.
    methods: List[str] = Field(..., description="The methods of the trigger")
    # Redeclared here so the default becomes "http" instead of None.
    trigger_type: Optional[str] = Field(
        description="The type of the trigger", default="http"
    )
class BaseHttpBody(BaseModel):
"""Http body.
@@ -444,7 +455,7 @@ class HttpTrigger(Trigger):
def mount_to_router(
self, router: "APIRouter", global_prefix: Optional[str] = None
) -> None:
) -> HttpTriggerMetadata:
"""Mount the trigger to a router.
Args:
@@ -466,8 +477,11 @@ class HttpTrigger(Trigger):
)(dynamic_route_function)
logger.info(f"Mount http trigger success, path: {path}")
return HttpTriggerMetadata(path=path, methods=self._methods)
def mount_to_app(self, app: "FastAPI", global_prefix: Optional[str] = None) -> None:
def mount_to_app(
self, app: "FastAPI", global_prefix: Optional[str] = None
) -> HttpTriggerMetadata:
"""Mount the trigger to a FastAPI app.
TODO: The performance of this method is not good, need to be optimized.
@@ -498,6 +512,7 @@ class HttpTrigger(Trigger):
app.openapi_schema = None
app.middleware_stack = None
logger.info(f"Mount http trigger success, path: {path}")
return HttpTriggerMetadata(path=path, methods=self._methods)
def remove_from_app(
self, app: "FastAPI", global_prefix: Optional[str] = None

View File

@@ -2,15 +2,16 @@
After DB-GPT started, the trigger manager will be initialized and register all triggers
"""
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from ..util.http_util import join_paths
from .base import Trigger
from .base import Trigger, TriggerMetadata
if TYPE_CHECKING:
from fastapi import APIRouter
@@ -23,7 +24,9 @@ class TriggerManager(ABC):
"""Base class for trigger manager."""
@abstractmethod
def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
def register_trigger(
self, trigger: Any, system_app: SystemApp
) -> Optional[TriggerMetadata]:
"""Register a trigger to current manager."""
@abstractmethod
@@ -65,10 +68,12 @@ class HttpTriggerManager(TriggerManager):
self._inited = False
self._router_prefix = router_prefix
self._router = router
self._trigger_map: Dict[str, Trigger] = {}
self._trigger_map: Dict[str, Tuple[Trigger, TriggerMetadata]] = {}
self._router_tables: Dict[str, Set[str]] = defaultdict(set)
def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
def register_trigger(
self, trigger: Any, system_app: SystemApp
) -> Optional[TriggerMetadata]:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
@@ -86,13 +91,17 @@ class HttpTriggerManager(TriggerManager):
if not app:
raise ValueError("System app not initialized")
# Mount to app, support dynamic route.
trigger.mount_to_app(app, self._router_prefix)
trigger_metadata = trigger.mount_to_app(app, self._router_prefix)
else:
trigger.mount_to_router(self._router, self._router_prefix)
self._trigger_map[trigger_id] = trigger
trigger_metadata = trigger.mount_to_router(
self._router, self._router_prefix
)
self._trigger_map[trigger_id] = (trigger, trigger_metadata)
return trigger_metadata
except Exception as e:
self._unregister_route_tables(path, methods)
raise e
return None
def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
"""Unregister a trigger to current manager."""
@@ -183,7 +192,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
if system_app and self.system_app.app:
self._http_trigger = HttpTriggerManager()
def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
def register_trigger(
self, trigger: Any, system_app: SystemApp
) -> Optional[TriggerMetadata]:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
@@ -191,7 +202,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
logger.info(f"Register trigger {trigger}")
if not self._http_trigger:
raise ValueError("Http trigger manager not initialized")
self._http_trigger.register_trigger(trigger, system_app)
return self._http_trigger.register_trigger(trigger, system_app)
else:
return None
def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
"""Unregister a trigger to current manager."""

View File

@@ -0,0 +1,323 @@
"""The utility functions for chatting with the DAG task."""
import json
import traceback
from typing import Any, AsyncIterator, Dict, Optional
from ...interface.llm import ModelInferenceMetrics, ModelOutput
from ...schema.api import ChatCompletionResponseStreamChoice
from ..operators.base import BaseOperator
from ..trigger.http_trigger import CommonLLMHttpResponseBody
def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
    """Check whether the output object is a chat flow type.

    Args:
        output_obj (Any): An instance, or a class when ``is_class`` is True.
        is_class (bool, optional): Whether ``output_obj`` is a class rather
            than an instance. Defaults to False.

    Returns:
        bool: For a class, whether it is one of ``str``,
            ``CommonLLMHttpResponseBody`` or ``ModelOutput``. For an instance,
            whether it is a ``str`` or a ``CommonLLMHttpResponseBody``.
    """
    if not is_class:
        # Instance check: ModelOutput is not part of the accepted tuple here.
        return isinstance(output_obj, (str, CommonLLMHttpResponseBody))
    return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
async def safe_chat_with_dag_task(
    task: BaseOperator, request: Any, covert_to_str: bool = False
) -> ModelOutput:
    """Chat with the DAG task and return a single, final output.

    The task is driven through ``safe_chat_stream_with_dag_task`` in
    non-incremental mode and only the fields of the last yielded output are
    kept.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Returns:
        ModelOutput: The model output, the result is not incremental. When the
            stream yields nothing, an empty successful output is returned; when
            an exception escapes, an output with ``error_code=1`` carrying the
            exception text is returned.
    """
    try:
        # Defaults used when the stream does not yield any output.
        summary: Dict[str, Any] = {
            "error_code": 0,
            "text": "",
            "metrics": None,
            "usage": None,
            "finish_reason": None,
        }
        stream = safe_chat_stream_with_dag_task(
            task, request, False, covert_to_str=covert_to_str
        )
        async for chunk in stream:
            # Each chunk replaces the previous one; the last chunk wins.
            summary = {
                "error_code": chunk.error_code,
                "text": chunk.text,
                "metrics": chunk.metrics,
                "usage": chunk.usage,
                "finish_reason": chunk.finish_reason,
            }
        return ModelOutput(**summary)
    except Exception as exc:
        return ModelOutput(error_code=1, text=str(exc), incremental=False)
async def safe_chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task, turning failures into error outputs.

    This function is similar to `chat_stream_with_dag_task`, but it will catch the
    exception and return the error message.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    try:
        stream = chat_stream_with_dag_task(
            task, request, incremental, covert_to_str=covert_to_str
        )
        async for chunk in stream:
            yield chunk
    except Exception as exc:
        # Fall back to the full traceback when the exception has no message.
        message = str(exc) or traceback.format_exc()
        yield ModelOutput(error_code=1, text=message, incremental=incremental)
    finally:
        # For streaming operators attached to a DAG, the DAG end hook is
        # awaited here regardless of how the stream finished.
        if task.streaming_operator and task.dag:
            await task.dag._after_dag_end(task.current_event_loop_task_id)
def _is_sse_output(task: BaseOperator) -> bool:
    """Check whether the DAG task is a server-sent event output.

    The task's ``output_format`` is compared case-insensitively with "SSE".

    Args:
        task (BaseOperator): The DAG task.

    Returns:
        bool: Whether the DAG task is a server-sent event output. False when
            the task has no output format.
    """
    declared_format = task.output_format
    if declared_format is None:
        return False
    return declared_format.upper() == "SSE"
async def chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task.

    Three paths are handled:

    * Non-streaming task: ``task.call`` is awaited once and a single output is
      yielded. Exceptions are converted into an output with ``error_code=1``.
    * Streaming task that is an ``OpenAIStreamingOutputOperator``: every chunk
      is parsed with ``parse_openai_output``.
    * Any other streaming task: every chunk is parsed with
      ``parse_single_output``.

    In both streaming paths the loop stops after the first output whose
    ``success`` attribute is falsy. Exceptions raised in the streaming paths
    are not caught here.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    is_sse = _is_sse_output(task)
    if not task.streaming_operator:
        try:
            result = await task.call(request)
            model_output = parse_single_output(
                result, is_sse, covert_to_str=covert_to_str
            )
            model_output.incremental = incremental
            yield model_output
        except Exception as e:
            simple_error_msg = str(e)
            if not simple_error_msg:
                # The exception carries no message; use the traceback instead.
                simple_error_msg = traceback.format_exc()
            yield ModelOutput(
                error_code=1, text=simple_error_msg, incremental=incremental
            )
    else:
        # Imported inside the function rather than at module level
        # (presumably to avoid a hard/circular dependency on dbgpt.model —
        # TODO confirm).
        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator

        if OpenAIStreamingOutputOperator and isinstance(
            task, OpenAIStreamingOutputOperator
        ):
            full_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_openai_output(output)
                # The output of the OpenAI streaming API is incremental
                full_text += model_output.text
                model_output.incremental = incremental
                # Incremental callers get the chunk text as-is; others get the
                # text accumulated so far.
                model_output.text = model_output.text if incremental else full_text
                yield model_output
                if not model_output.success:
                    break
        else:
            full_text = ""
            previous_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_single_output(
                    output, is_sse, covert_to_str=covert_to_str
                )
                model_output.incremental = incremental
                if task.incremental_output:
                    # Output is incremental, append the text
                    full_text += model_output.text
                else:
                    # Output is not incremental, last output is the full text
                    full_text = model_output.text
                if not incremental:
                    # Return the full text
                    model_output.text = full_text
                else:
                    # Return the incremental text
                    # The delta is whatever lies beyond the text already sent.
                    delta_text = full_text[len(previous_text) :]
                    # previous_text only advances when full_text got longer.
                    previous_text = (
                        full_text
                        if len(full_text) > len(previous_text)
                        else previous_text
                    )
                    model_output.text = delta_text
                yield model_output
                if not model_output.success:
                    break
def parse_single_output(
    output: Any, is_sse: bool, covert_to_str: bool = False
) -> ModelOutput:
    """Normalize one task output into a ModelOutput.

    Handled inputs, checked in this order: ``None`` (error), ``str`` (SSE data
    part when ``is_sse`` is set, otherwise the string itself), ``ModelOutput``,
    ``CommonLLMHttpResponseBody``, ``dict`` (dumped as JSON), and finally any
    other object, which is stringified only when ``covert_to_str`` is set and
    is reported as an error otherwise.

    Args:
        output (Any): The output to parse.
        is_sse (bool): Whether the output is in SSE format.
        covert_to_str (bool, optional): Whether to convert the output to string.
            Defaults to False.

    Returns:
        ModelOutput: The parsed output. ``finish_reason``, ``usage`` and
            ``metrics`` are only carried over from a ``ModelOutput`` input.
    """
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, Any]] = None
    metrics: Optional[ModelInferenceMetrics] = None
    # Success is assumed; the failing branches below flip it to 1.
    error_code = 0

    if output is None:
        error_code, text = 1, "The output is None!"
    elif isinstance(output, str):
        text = output
        if is_sse:
            sse_text = parse_sse_data(output)
            if sse_text is None:
                error_code, text = 1, "The output is not a SSE format"
            else:
                text = sse_text
    elif isinstance(output, ModelOutput):
        error_code, text = output.error_code, output.text
        finish_reason = output.finish_reason
        usage = output.usage
        metrics = output.metrics
    elif isinstance(output, CommonLLMHttpResponseBody):
        error_code, text = output.error_code, output.text
    elif isinstance(output, dict):
        text = json.dumps(output, ensure_ascii=False)
    elif covert_to_str:
        text = str(output)
    else:
        error_code, text = 1, f"The output is not a valid format({type(output)})"

    return ModelOutput(
        error_code=error_code,
        text=text,
        finish_reason=finish_reason,
        usage=usage,
        metrics=metrics,
    )
def parse_openai_output(output: Any) -> ModelOutput:
    """Parse the OpenAI output.

    The input is expected to be one ``data:`` line of a stream whose payload
    is a JSON object with a ``choices`` list. Malformed input is reported via
    a ``ModelOutput`` with ``error_code=1`` instead of an exception.

    Args:
        output (Any): The output to parse. It must be a stream format.

    Returns:
        ModelOutput: The parsed output. The text is the ``delta.content`` of
            the first choice (empty when absent), and ``finish_reason`` is
            taken from that choice. The ``[DONE]`` marker yields an empty
            successful output.
    """
    if not isinstance(output, str):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )
    if output.strip() in ("data: [DONE]", "data:[DONE]"):
        return ModelOutput(error_code=0, text="")
    if not output.startswith("data:"):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )

    sse_output = parse_sse_data(output)
    if sse_output is None:
        # Kept as a defensive check; the prefix test above already rules out
        # the only case in which parse_sse_data returns None.
        return ModelOutput(error_code=1, text="The output is not a SSE format")
    json_data = sse_output.strip()
    try:
        dict_data = json.loads(json_data)
    except Exception as e:
        return ModelOutput(
            error_code=1,
            text=f"Invalid JSON data: {json_data}, {e}",
        )
    if not isinstance(dict_data, dict):
        # Valid JSON that is not an object (list, string, number, null) would
        # otherwise raise on the membership test or on ``.get`` below.
        return ModelOutput(
            error_code=1,
            text=f"Invalid JSON data: {json_data}, expected a JSON object",
        )
    if "choices" not in dict_data:
        return ModelOutput(
            error_code=1,
            text=dict_data.get("text", "Unknown error"),
        )

    text = ""
    finish_reason: Optional[str] = None
    choices = dict_data["choices"]
    if choices:
        # Only the first choice is consulted.
        delta_data = ChatCompletionResponseStreamChoice(**choices[0])
        if delta_data.delta.content:
            text = delta_data.delta.content
        finish_reason = delta_data.finish_reason
    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
def parse_sse_data(output: str) -> Optional[str]:
    r"""Extract the data part of an SSE line.

    The input must start with ``data:`` (no leading whitespace). Surrounding
    whitespace is stripped, then the ``data:`` prefix and at most one space
    following it are removed.

    Examples:
        .. code-block:: python

            from dbgpt.core.awel.util.chat_util import parse_sse_data

            assert parse_sse_data("data: [DONE]") == "[DONE]"
            assert parse_sse_data("data:[DONE]") == "[DONE]"
            assert parse_sse_data("data: Hello") == "Hello"
            assert parse_sse_data("data: Hello\n") == "Hello"
            assert parse_sse_data("data: Hello\r\n") == "Hello"
            assert parse_sse_data("data: Hi, what's up?") == "Hi, what's up?"

    Args:
        output (str): The output.

    Returns:
        Optional[str]: The parsed data, or None when the input does not start
            with ``data:``.
    """
    if not output.startswith("data:"):
        return None
    stripped = output.strip()
    # Skip "data: " (6 chars) when a space follows the colon, else "data:" (5).
    prefix_len = 6 if stripped.startswith("data: ") else 5
    return stripped[prefix_len:]