feat(core): Add debug and export/import for AWEL flow

This commit is contained in:
Fangyin Cheng
2024-08-18 07:41:46 +08:00
parent 25f6d8aab5
commit 4f2c56d821
9 changed files with 446 additions and 75 deletions

View File

@@ -197,7 +197,7 @@ class DAGManager(BaseComponent):
return self._dag_metadata_map.get(dag.dag_id)
def _parse_metadata(dag: DAG):
def _parse_metadata(dag: DAG) -> DAGMetadata:
from ..util.chat_util import _is_sse_output
metadata = DAGMetadata()

View File

@@ -4,7 +4,7 @@ import logging
import uuid
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
from typing_extensions import Annotated
@@ -166,6 +166,59 @@ class FlowData(BaseModel):
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
class VariablesRequest(BaseModel):
    """Request payload describing a single variable to create in DB-GPT.

    ``value_type`` defaults to ``"str"``, ``enabled`` defaults to ``True`` and
    ``user_name`` / ``sys_code`` default to ``None``. Every other field must be
    supplied by the caller (``scope_key`` must be supplied but may be ``None``).
    """

    key: str = Field(
        default=...,
        examples=["dbgpt.model.openai.api_key"],
        description="The key of the variable to create",
    )
    name: str = Field(
        default=...,
        examples=["my_first_openai_key"],
        description="The name of the variable to create",
    )
    label: str = Field(
        default=...,
        examples=["My First OpenAI Key"],
        description="The label of the variable to create",
    )
    value: Any = Field(
        default=...,
        examples=["1234567890"],
        description="The value of the variable to create",
    )
    value_type: Literal["str", "int", "float", "bool"] = Field(
        default="str",
        examples=["str", "int", "float", "bool"],
        description="The type of the value of the variable to create",
    )
    category: Literal["common", "secret"] = Field(
        default=...,
        examples=["common"],
        description="The category of the variable to create",
    )
    scope: str = Field(
        default=...,
        examples=["global"],
        description="The scope of the variable to create",
    )
    scope_key: Optional[str] = Field(
        default=...,
        examples=["dbgpt"],
        description="The scope key of the variable to create",
    )
    enabled: Optional[bool] = Field(
        default=True,
        examples=[True],
        description="Whether the variable is enabled",
    )
    user_name: Optional[str] = Field(default=None, description="User name")
    sys_code: Optional[str] = Field(default=None, description="System code")
class State(str, Enum):
"""State of a flow panel."""
@@ -356,6 +409,12 @@ class FlowPanel(BaseModel):
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
default=None, description="The metadata of the flow"
)
variables: Optional[List[VariablesRequest]] = Field(
default=None, description="The variables of the flow"
)
authors: Optional[List[str]] = Field(
default=None, description="The authors of the flow"
)
@model_validator(mode="before")
@classmethod

View File

@@ -334,7 +334,8 @@ class InputOperator(BaseOperator, Generic[OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
new_task_output: TaskOutput[OUT] = await task_output.map(self.map)
curr_task_ctx.set_task_output(new_task_output)
return task_output
@classmethod
@@ -342,6 +343,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
"""Create a dummy InputOperator with a given input value."""
return cls(input_source=InputSource.from_data(dummy_data), **kwargs)
async def map(self, input_data: OUT) -> OUT:
    """Transform the value read from the input source.

    This implementation is the identity: the value is handed back untouched.

    Args:
        input_data (OUT): The value read from the input source.

    Returns:
        OUT: The very same object that was passed in.
    """
    unchanged = input_data
    return unchanged
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""

View File

@@ -87,7 +87,9 @@ class HttpTriggerMetadata(TriggerMetadata):
path: str = Field(..., description="The path of the trigger")
methods: List[str] = Field(..., description="The methods of the trigger")
trigger_mode: str = Field(
default="command", description="The mode of the trigger, command or chat"
)
trigger_type: Optional[str] = Field(
default="http", description="The type of the trigger"
)
@@ -477,7 +479,9 @@ class HttpTrigger(Trigger):
)(dynamic_route_function)
logger.info(f"Mount http trigger success, path: {path}")
return HttpTriggerMetadata(path=path, methods=self._methods)
return HttpTriggerMetadata(
path=path, methods=self._methods, trigger_mode=self._trigger_mode()
)
def mount_to_app(
self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -512,7 +516,9 @@ class HttpTrigger(Trigger):
app.openapi_schema = None
app.middleware_stack = None
logger.info(f"Mount http trigger success, path: {path}")
return HttpTriggerMetadata(path=path, methods=self._methods)
return HttpTriggerMetadata(
path=path, methods=self._methods, trigger_mode=self._trigger_mode()
)
def remove_from_app(
self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -537,6 +543,36 @@ class HttpTrigger(Trigger):
# TODO, remove with path and methods
del app_router.routes[i]
def _trigger_mode(self) -> str:
if (
self._req_body
and isinstance(self._req_body, type)
and issubclass(self._req_body, CommonLLMHttpRequestBody)
):
return "chat"
return "command"
async def map(self, input_data: Any) -> Any:
    """Convert a raw dict payload into the configured request-body model.

    The conversion only happens when a request body is configured, the input
    is non-empty, the request body is a ``BaseModel`` subclass and the input
    is a ``dict``. In every other case the parent implementation decides.

    Args:
        input_data (Any): The input data from caller.

    Returns:
        Any: An instance of the request-body model built from ``input_data``,
            or whatever the parent ``map`` returns.
    """
    body_cls = self._req_body
    buildable = (
        bool(body_cls)
        and bool(input_data)
        and isinstance(body_cls, type)
        and issubclass(body_cls, BaseModel)
        and isinstance(input_data, dict)
    )
    if buildable:
        return body_cls(**input_data)
    return await super().map(input_data)
def _create_route_func(self):
from inspect import Parameter, Signature
from typing import get_type_hints

View File

@@ -1,9 +1,11 @@
import io
import json
from functools import cache
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import JSONResponse, StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
@@ -15,6 +17,7 @@ from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from ..service.variables_service import VariablesService
from .schemas import (
FlowDebugRequest,
RefreshNodeRequest,
ServeRequest,
ServerResponse,
@@ -352,10 +355,116 @@ async def update_variables(
return Result.succ(res)
@router.post("/flow/debug")
async def debug():
"""Debug the flow."""
# TODO: Implement the debug endpoint
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
async def debug_flow(
    flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
):
    """Run the submitted flow in debug mode and stream its output.

    Args:
        flow_debug_request (FlowDebugRequest): The flow and the request to run.
        service (Service): The flow service.

    Returns:
        StreamingResponse: A ``text/event-stream`` response fed by the service.
    """
    sse_headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # The debug endpoint always asks for the non-incremental stream.
    model_stream = service.debug_flow(flow_debug_request, default_incremental=False)
    event_stream = service._wrapper_chat_stream_flow_str(model_stream)
    return StreamingResponse(
        event_stream,
        headers=sse_headers,
        media_type="text/event-stream",
    )
def _attachment_disposition(file_name: str) -> str:
    """Build a ``Content-Disposition`` attachment value that is header-safe."""
    from urllib.parse import quote

    encoded = quote(file_name, safe="")
    if encoded == file_name:
        # Plain token-safe ASCII name: keep the historical header format.
        return f"attachment;filename={file_name}"
    # Non-ASCII or special characters cannot be sent raw in a header (header
    # values must be latin-1 and ';', quotes or spaces would break parsing),
    # so use the RFC 5987 extended notation instead.
    return f"attachment;filename*=UTF-8''{encoded}"


@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
async def export_flow(
    uid: str,
    export_type: Literal["json", "dbgpts"] = Query(
        "json", description="export type(json or dbgpts)"
    ),
    format: Literal["file", "json"] = Query(
        "file", description="response format(file or json)"
    ),
    file_name: Optional[str] = Query(default=None, description="file name to export"),
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    service: Service = Depends(get_service),
):
    """Export the flow to a file.

    Args:
        uid (str): The uid of the flow to export.
        export_type: ``"json"`` for the raw flow definition, ``"dbgpts"`` for a
            zipped dbgpts package.
        format: ``"file"`` to download an attachment, ``"json"`` to get the
            definition inline (only valid together with ``export_type="json"``).
        file_name (Optional[str]): Download name without extension; defaults to
            the flow name with underscores replaced by dashes.
        user_name (Optional[str]): User name used to look up the flow.
        sys_code (Optional[str]): System code used to look up the flow.
        service (Service): The flow service.

    Raises:
        HTTPException: 404 if the flow does not exist, 400 if a json response
            is requested for a dbgpts export.
    """
    flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
    if not flow:
        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
    package_name = flow.name.replace("_", "-")
    file_name = file_name or package_name
    if export_type == "json":
        flow_dict = {"flow": flow.to_dict()}
        if format == "json":
            return JSONResponse(content=flow_dict)
        # Return the json file
        return StreamingResponse(
            io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")),
            media_type="application/file",
            headers={
                "Content-Disposition": _attachment_disposition(f"{file_name}.json")
            },
        )
    elif export_type == "dbgpts":
        from ..service.share_utils import _generate_dbgpts_zip

        if format == "json":
            raise HTTPException(
                status_code=400, detail="json response is not supported for dbgpts"
            )
        # Building the archive is blocking work, so run it off the event loop.
        zip_buffer = await blocking_func_to_async(
            global_system_app, _generate_dbgpts_zip, package_name, flow
        )
        return StreamingResponse(
            zip_buffer,
            media_type="application/x-zip-compressed",
            headers={
                "Content-Disposition": _attachment_disposition(f"{file_name}.zip")
            },
        )
@router.post(
    "/flow/import",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def import_flow(
    file: UploadFile = File(...),
    save_flow: bool = Query(
        False, description="Whether to save the flow after importing"
    ),
    service: Service = Depends(get_service),
):
    """Import the flow from a file.

    Args:
        file (UploadFile): A ``.json`` export (an object with a ``flow`` key)
            or a ``.zip`` dbgpts package.
        save_flow (bool): Whether to create and save the flow after parsing it.
        service (Service): The flow service.

    Returns:
        Result[ServerResponse]: The saved flow when ``save_flow`` is true,
            otherwise the parsed flow.

    Raises:
        HTTPException: 400 if the extension is not supported, the JSON cannot
            be parsed, or the JSON document has no ``flow`` key.
    """
    # The upload may arrive without a file name; treat that like an unknown
    # extension instead of failing on ``None``.
    filename = file.filename or ""
    file_extension = filename.split(".")[-1].lower()
    if file_extension == "json":
        # Handle json file
        json_content = await file.read()
        try:
            json_dict = json.loads(json_content)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # A malformed upload is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail=f"invalid json file, {str(e)}"
            ) from e
        if not isinstance(json_dict, dict) or "flow" not in json_dict:
            raise HTTPException(
                status_code=400, detail="invalid json file, missing 'flow' key"
            )
        flow = ServeRequest.parse_obj(json_dict["flow"])
    elif file_extension == "zip":
        from ..service.share_utils import _parse_flow_from_zip_file

        # Handle zip file
        flow = await _parse_flow_from_zip_file(file, global_system_app)
    else:
        raise HTTPException(
            status_code=400, detail=f"invalid file extension {file_extension}"
        )
    if save_flow:
        return Result.succ(service.create_and_save_dag(flow))
    else:
        return Result.succ(flow)
def init_endpoints(system_app: SystemApp) -> None:

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core.awel import CommonLLMHttpRequestBody
from dbgpt.core.awel.flow.flow_factory import FlowPanel
from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
from ..config import SERVE_APP_NAME_HUMP
@@ -18,59 +18,6 @@ class ServerResponse(FlowPanel):
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
class VariablesRequest(BaseModel):
"""Variable request model.
For creating a new variable in the DB-GPT.
"""
key: str = Field(
...,
description="The key of the variable to create",
examples=["dbgpt.model.openai.api_key"],
)
name: str = Field(
...,
description="The name of the variable to create",
examples=["my_first_openai_key"],
)
label: str = Field(
...,
description="The label of the variable to create",
examples=["My First OpenAI Key"],
)
value: Any = Field(
..., description="The value of the variable to create", examples=["1234567890"]
)
value_type: Literal["str", "int", "float", "bool"] = Field(
"str",
description="The type of the value of the variable to create",
examples=["str", "int", "float", "bool"],
)
category: Literal["common", "secret"] = Field(
...,
description="The category of the variable to create",
examples=["common"],
)
scope: str = Field(
...,
description="The scope of the variable to create",
examples=["global"],
)
scope_key: Optional[str] = Field(
...,
description="The scope key of the variable to create",
examples=["dbgpt"],
)
enabled: Optional[bool] = Field(
True,
description="Whether the variable is enabled",
examples=[True],
)
user_name: Optional[str] = Field(None, description="User name")
sys_code: Optional[str] = Field(None, description="System code")
class VariablesResponse(VariablesRequest):
"""Variable response model."""

View File

@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
res = self.dao.create(request)
self.dao.create(request)
# Query from database
res = self.get({"uid": request.uid})
state = request.state
try:
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON
async def debug_flow(
    self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
    """Build the flow from the request and stream the output of its leaf task.

    Args:
        request (FlowDebugRequest): The flow to build and the request to run.
        default_incremental (Optional[bool]): When not ``None``, overrides the
            ``incremental`` setting carried by the request.

    Returns:
        AsyncIterator[ModelOutput]: The streamed outputs. A failure while
            streaming is reported as a final output with ``error_code=1``.

    Raises:
        ValueError: If the built DAG does not have exactly one leaf node, or
            the request payload is neither a ``CommonLLMHttpRequestBody`` nor
            a ``dict``.
    """
    from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata

    flow_dag = self._flow_factory.build(request.flow)
    leaves = flow_dag.leaf_nodes
    if len(leaves) != 1:
        raise ValueError("Chat Flow just support one leaf node in dag")
    leaf_task = cast(BaseOperator, leaves[0])
    dag_metadata = _parse_metadata(flow_dag)
    # TODO: Run task with variables
    variables = request.variables
    payload = request.request
    if isinstance(payload, CommonLLMHttpRequestBody):
        incremental = payload.incremental
    elif isinstance(payload, dict):
        incremental = payload.get("incremental", False)
    else:
        raise ValueError("Invalid request type")
    # An explicit default from the caller wins over the request's own setting.
    if default_incremental is not None:
        incremental = default_incremental

    def _as_error(message) -> ModelOutput:
        return ModelOutput(error_code=1, text=message, incremental=incremental)

    try:
        async for chunk in safe_chat_stream_with_dag_task(
            leaf_task, payload, incremental
        ):
            yield chunk
    except HTTPException as http_err:
        yield _as_error(http_err.detail)
    except Exception as err:
        yield _as_error(str(err))
async def _wrapper_chat_stream_flow_str(
self, stream_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async for output in stream_iter:
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"

View File

@@ -0,0 +1,121 @@
import io
import json
import os
import tempfile
import zipfile
import aiofiles
import tomlkit
from fastapi import UploadFile
from dbgpt.component import SystemApp
from dbgpt.serve.core import blocking_func_to_async
from ..api.schemas import ServeRequest
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
    """Pack a flow into an in-memory dbgpts package (zip archive).

    The archive contains, under ``{package_name}/``: ``MANIFEST.in``,
    ``README.md``, ``dbgpts.toml``, ``pyproject.toml`` and a
    ``{flow.name}/`` python package holding
    ``definition/flow_definition.json``.

    Args:
        package_name (str): The name of the package (top-level directory of
            the archive and the poetry project name).
        flow (ServeRequest): The flow to export; its ``name``, ``label``,
            ``description``, ``flow_data`` and (if present) ``authors`` are
            written to the package.

    Returns:
        io.BytesIO: The zip archive, positioned at the start of the buffer.
    """
    zip_buffer = io.BytesIO()
    flow_name = flow.name
    flow_label = flow.label
    flow_description = flow.description
    # ``authors`` may be missing or ``None`` on the flow; export an empty list
    # in that case so the TOML stays valid.
    flow_authors = list(getattr(flow, "authors", None) or [])
    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
        readme = f"# {flow_label}\n\n{flow_description}"
        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
        zip_file.writestr(f"{package_name}/README.md", readme)
        zip_file.writestr(
            f"{package_name}/{flow_name}/__init__.py",
            "",
        )
        zip_file.writestr(
            f"{package_name}/{flow_name}/definition/flow_definition.json",
            dag_json,
        )
        dbgpts_toml = tomlkit.document()
        # Add flow information. The label and name must describe the exported
        # flow itself (not a fixed template value), otherwise every exported
        # package would claim the same identity when it is loaded again.
        dbgpts_flow_toml = tomlkit.document()
        dbgpts_flow_toml.add("label", flow_label)
        name_with_comment = tomlkit.string(flow_name)
        name_with_comment.comment("A unique name for all dbgpts")
        dbgpts_flow_toml.add("name", name_with_comment)

        dbgpts_flow_toml.add("version", "0.1.0")
        # TOML has no null value; fall back to an empty string.
        # NOTE(review): assumes ``description`` may be None - confirm.
        dbgpts_flow_toml.add(
            "description",
            flow_description or "",
        )
        dbgpts_flow_toml.add("authors", flow_authors)

        definition_type_with_comment = tomlkit.string("json")
        definition_type_with_comment.comment("How to define the flow, python or json")
        dbgpts_flow_toml.add("definition_type", definition_type_with_comment)

        dbgpts_toml.add("flow", dbgpts_flow_toml)

        # Add python and json config
        python_config = tomlkit.table()
        dbgpts_toml.add("python_config", python_config)

        json_config = tomlkit.table()
        json_config.add("file_path", "definition/flow_definition.json")
        json_config.comment("Json config")

        dbgpts_toml.add("json_config", json_config)

        # Transform to string
        toml_string = tomlkit.dumps(dbgpts_toml)
        zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)

        pyproject_toml = tomlkit.document()

        # Add [tool.poetry] section
        tool_poetry_toml = tomlkit.table()
        tool_poetry_toml.add("name", package_name)
        tool_poetry_toml.add("version", "0.1.0")
        tool_poetry_toml.add("description", "A dbgpts package")
        tool_poetry_toml.add("authors", [])
        tool_poetry_toml.add("readme", "README.md")
        pyproject_toml["tool"] = tomlkit.table()
        pyproject_toml["tool"]["poetry"] = tool_poetry_toml

        # Add [tool.poetry.dependencies] section
        dependencies = tomlkit.table()
        dependencies.add("python", "^3.10")
        pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies

        # Add [build-system] section
        build_system = tomlkit.table()
        build_system.add("requires", ["poetry-core"])
        build_system.add("build-backend", "poetry.core.masonry.api")
        pyproject_toml["build-system"] = build_system

        # Transform to string
        pyproject_toml_string = tomlkit.dumps(pyproject_toml)
        zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
    # Rewind so the caller can stream the archive from its first byte.
    zip_buffer.seek(0)
    return zip_buffer
async def _parse_flow_from_zip_file(
    file: UploadFile, sys_app: SystemApp
) -> ServeRequest:
    """Store an uploaded dbgpts zip in a temporary directory and parse it.

    Args:
        file (UploadFile): The uploaded zip archive.
        sys_app (SystemApp): The system app used to run the blocking loader
            off the event loop.

    Returns:
        ServeRequest: The flow parsed from the package in the archive.

    Raises:
        ValueError: If the uploaded file name does not end with ``.zip``.
    """
    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path

    # The file name is chosen by the client: it can be missing and it can
    # carry directory components ("../x.zip", "/etc/x.zip"). Keep only the
    # final component so the upload can never be written outside ``temp_dir``.
    filename = os.path.basename(file.filename or "")
    # Compare case-insensitively, matching the endpoint's extension check.
    if not filename.lower().endswith(".zip"):
        raise ValueError("Uploaded file must be a ZIP file")

    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = os.path.join(temp_dir, filename)

        # Save uploaded file to temporary directory
        async with aiofiles.open(zip_path, "wb") as out_file:
            while content := await file.read(1024 * 64):  # Read in chunks of 64KB
                await out_file.write(content)
        # Parsing touches the file system, so keep it off the event loop; it
        # must finish before the temporary directory is removed.
        flow = await blocking_func_to_async(
            sys_app, _load_flow_package_from_zip_path, zip_path
        )
        return flow

View File

@@ -328,14 +328,19 @@ def _load_package_from_path(path: str):
return parsed_packages
def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage:
def _load_flow_package_from_path(
name: str, path: str = INSTALL_DIR, filter_by_name: bool = True
) -> FlowPackage:
raw_packages = _load_installed_package(path)
new_name = name.replace("_", "-")
packages = [p for p in raw_packages if p.package == name or p.name == name]
if not packages:
packages = [
p for p in raw_packages if p.package == new_name or p.name == new_name
]
if filter_by_name:
packages = [p for p in raw_packages if p.package == name or p.name == name]
if not packages:
packages = [
p for p in raw_packages if p.package == new_name or p.name == new_name
]
else:
packages = raw_packages
if not packages:
raise ValueError(f"Can't find the package {name} or {new_name}")
flow_package = _parse_package_metadata(packages[0])
@@ -344,6 +349,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack
return cast(FlowPackage, flow_package)
def _load_flow_package_from_zip_path(zip_path: str) -> FlowPanel:
    """Load the single flow package contained in a zip archive.

    The archive is extracted to a temporary directory, an install-metadata
    file is written into the extracted package, and the package is then
    loaded from that directory and converted to a flow panel.

    Args:
        zip_path (str): The path of the zip archive.

    Returns:
        FlowPanel: The flow panel built from the extracted package.

    Raises:
        ValueError: If the archive has no top-level entry or more than one.
    """
    import tempfile
    import zipfile

    with tempfile.TemporaryDirectory() as extract_dir:
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extract_dir)
        entries = os.listdir(extract_dir)
        if not entries:
            raise ValueError("No package found in the zip file")
        if len(entries) > 1:
            raise ValueError("Only support one package in the zip file")
        package_name = entries[0]
        metadata_path = Path(extract_dir) / package_name / INSTALL_METADATA_FILE
        with open(metadata_path, mode="w+") as metadata_file:
            # Write the metadata
            import tomlkit

            tomlkit.dump(
                {"name": package_name, "repo": "local/dbgpts"}, metadata_file
            )

        # The name filter is disabled, so the empty name is not used to select
        # the package.
        loaded = _load_flow_package_from_path(
            "", path=extract_dir, filter_by_name=False
        )
        return _flow_package_to_flow_panel(loaded)
def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
dict_value = {
"name": package.name,
@@ -353,6 +387,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
"description": package.description,
"source": package.repo,
"define_type": "json",
"authors": package.authors,
}
if isinstance(package, FlowJsonPackage):
dict_value["flow_data"] = package.read_definition_json()