mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
feat(core): Add debug and export/import for AWEL flow
This commit is contained in:
@@ -197,7 +197,7 @@ class DAGManager(BaseComponent):
|
|||||||
return self._dag_metadata_map.get(dag.dag_id)
|
return self._dag_metadata_map.get(dag.dag_id)
|
||||||
|
|
||||||
|
|
||||||
def _parse_metadata(dag: DAG):
|
def _parse_metadata(dag: DAG) -> DAGMetadata:
|
||||||
from ..util.chat_util import _is_sse_output
|
from ..util.chat_util import _is_sse_output
|
||||||
|
|
||||||
metadata = DAGMetadata()
|
metadata = DAGMetadata()
|
||||||
|
@@ -4,7 +4,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
|
||||||
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
@@ -166,6 +166,59 @@ class FlowData(BaseModel):
|
|||||||
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
|
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
|
||||||
|
|
||||||
|
|
||||||
|
class VariablesRequest(BaseModel):
|
||||||
|
"""Variable request model.
|
||||||
|
|
||||||
|
For creating a new variable in the DB-GPT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str = Field(
|
||||||
|
...,
|
||||||
|
description="The key of the variable to create",
|
||||||
|
examples=["dbgpt.model.openai.api_key"],
|
||||||
|
)
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description="The name of the variable to create",
|
||||||
|
examples=["my_first_openai_key"],
|
||||||
|
)
|
||||||
|
label: str = Field(
|
||||||
|
...,
|
||||||
|
description="The label of the variable to create",
|
||||||
|
examples=["My First OpenAI Key"],
|
||||||
|
)
|
||||||
|
value: Any = Field(
|
||||||
|
..., description="The value of the variable to create", examples=["1234567890"]
|
||||||
|
)
|
||||||
|
value_type: Literal["str", "int", "float", "bool"] = Field(
|
||||||
|
"str",
|
||||||
|
description="The type of the value of the variable to create",
|
||||||
|
examples=["str", "int", "float", "bool"],
|
||||||
|
)
|
||||||
|
category: Literal["common", "secret"] = Field(
|
||||||
|
...,
|
||||||
|
description="The category of the variable to create",
|
||||||
|
examples=["common"],
|
||||||
|
)
|
||||||
|
scope: str = Field(
|
||||||
|
...,
|
||||||
|
description="The scope of the variable to create",
|
||||||
|
examples=["global"],
|
||||||
|
)
|
||||||
|
scope_key: Optional[str] = Field(
|
||||||
|
...,
|
||||||
|
description="The scope key of the variable to create",
|
||||||
|
examples=["dbgpt"],
|
||||||
|
)
|
||||||
|
enabled: Optional[bool] = Field(
|
||||||
|
True,
|
||||||
|
description="Whether the variable is enabled",
|
||||||
|
examples=[True],
|
||||||
|
)
|
||||||
|
user_name: Optional[str] = Field(None, description="User name")
|
||||||
|
sys_code: Optional[str] = Field(None, description="System code")
|
||||||
|
|
||||||
|
|
||||||
class State(str, Enum):
|
class State(str, Enum):
|
||||||
"""State of a flow panel."""
|
"""State of a flow panel."""
|
||||||
|
|
||||||
@@ -356,6 +409,12 @@ class FlowPanel(BaseModel):
|
|||||||
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
|
metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
|
||||||
default=None, description="The metadata of the flow"
|
default=None, description="The metadata of the flow"
|
||||||
)
|
)
|
||||||
|
variables: Optional[List[VariablesRequest]] = Field(
|
||||||
|
default=None, description="The variables of the flow"
|
||||||
|
)
|
||||||
|
authors: Optional[List[str]] = Field(
|
||||||
|
default=None, description="The authors of the flow"
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@@ -334,7 +334,8 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
|||||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||||
task_output = await self._input_source.read(curr_task_ctx)
|
task_output = await self._input_source.read(curr_task_ctx)
|
||||||
curr_task_ctx.set_task_output(task_output)
|
new_task_output: TaskOutput[OUT] = await task_output.map(self.map)
|
||||||
|
curr_task_ctx.set_task_output(new_task_output)
|
||||||
return task_output
|
return task_output
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -342,6 +343,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
|||||||
"""Create a dummy InputOperator with a given input value."""
|
"""Create a dummy InputOperator with a given input value."""
|
||||||
return cls(input_source=InputSource.from_data(dummy_data), **kwargs)
|
return cls(input_source=InputSource.from_data(dummy_data), **kwargs)
|
||||||
|
|
||||||
|
async def map(self, input_data: OUT) -> OUT:
|
||||||
|
"""Map the input data to a new value."""
|
||||||
|
return input_data
|
||||||
|
|
||||||
|
|
||||||
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
|
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
|
||||||
"""Operator node that triggers the DAG to run."""
|
"""Operator node that triggers the DAG to run."""
|
||||||
|
@@ -87,7 +87,9 @@ class HttpTriggerMetadata(TriggerMetadata):
|
|||||||
|
|
||||||
path: str = Field(..., description="The path of the trigger")
|
path: str = Field(..., description="The path of the trigger")
|
||||||
methods: List[str] = Field(..., description="The methods of the trigger")
|
methods: List[str] = Field(..., description="The methods of the trigger")
|
||||||
|
trigger_mode: str = Field(
|
||||||
|
default="command", description="The mode of the trigger, command or chat"
|
||||||
|
)
|
||||||
trigger_type: Optional[str] = Field(
|
trigger_type: Optional[str] = Field(
|
||||||
default="http", description="The type of the trigger"
|
default="http", description="The type of the trigger"
|
||||||
)
|
)
|
||||||
@@ -477,7 +479,9 @@ class HttpTrigger(Trigger):
|
|||||||
)(dynamic_route_function)
|
)(dynamic_route_function)
|
||||||
|
|
||||||
logger.info(f"Mount http trigger success, path: {path}")
|
logger.info(f"Mount http trigger success, path: {path}")
|
||||||
return HttpTriggerMetadata(path=path, methods=self._methods)
|
return HttpTriggerMetadata(
|
||||||
|
path=path, methods=self._methods, trigger_mode=self._trigger_mode()
|
||||||
|
)
|
||||||
|
|
||||||
def mount_to_app(
|
def mount_to_app(
|
||||||
self, app: "FastAPI", global_prefix: Optional[str] = None
|
self, app: "FastAPI", global_prefix: Optional[str] = None
|
||||||
@@ -512,7 +516,9 @@ class HttpTrigger(Trigger):
|
|||||||
app.openapi_schema = None
|
app.openapi_schema = None
|
||||||
app.middleware_stack = None
|
app.middleware_stack = None
|
||||||
logger.info(f"Mount http trigger success, path: {path}")
|
logger.info(f"Mount http trigger success, path: {path}")
|
||||||
return HttpTriggerMetadata(path=path, methods=self._methods)
|
return HttpTriggerMetadata(
|
||||||
|
path=path, methods=self._methods, trigger_mode=self._trigger_mode()
|
||||||
|
)
|
||||||
|
|
||||||
def remove_from_app(
|
def remove_from_app(
|
||||||
self, app: "FastAPI", global_prefix: Optional[str] = None
|
self, app: "FastAPI", global_prefix: Optional[str] = None
|
||||||
@@ -537,6 +543,36 @@ class HttpTrigger(Trigger):
|
|||||||
# TODO, remove with path and methods
|
# TODO, remove with path and methods
|
||||||
del app_router.routes[i]
|
del app_router.routes[i]
|
||||||
|
|
||||||
|
def _trigger_mode(self) -> str:
|
||||||
|
if (
|
||||||
|
self._req_body
|
||||||
|
and isinstance(self._req_body, type)
|
||||||
|
and issubclass(self._req_body, CommonLLMHttpRequestBody)
|
||||||
|
):
|
||||||
|
return "chat"
|
||||||
|
return "command"
|
||||||
|
|
||||||
|
async def map(self, input_data: Any) -> Any:
|
||||||
|
"""Map the input data.
|
||||||
|
|
||||||
|
Do some transformation for the input data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data (Any): The input data from caller.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The mapped data.
|
||||||
|
"""
|
||||||
|
if not self._req_body or not input_data:
|
||||||
|
return await super().map(input_data)
|
||||||
|
if (
|
||||||
|
isinstance(self._req_body, type)
|
||||||
|
and issubclass(self._req_body, BaseModel)
|
||||||
|
and isinstance(input_data, dict)
|
||||||
|
):
|
||||||
|
return self._req_body(**input_data)
|
||||||
|
return await super().map(input_data)
|
||||||
|
|
||||||
def _create_route_func(self):
|
def _create_route_func(self):
|
||||||
from inspect import Parameter, Signature
|
from inspect import Parameter, Signature
|
||||||
from typing import get_type_hints
|
from typing import get_type_hints
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
|
import io
|
||||||
import json
|
import json
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
|
||||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from starlette.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||||
@@ -14,6 +16,7 @@ from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
|||||||
from ..service.service import Service
|
from ..service.service import Service
|
||||||
from ..service.variables_service import VariablesService
|
from ..service.variables_service import VariablesService
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
|
FlowDebugRequest,
|
||||||
RefreshNodeRequest,
|
RefreshNodeRequest,
|
||||||
ServeRequest,
|
ServeRequest,
|
||||||
ServerResponse,
|
ServerResponse,
|
||||||
@@ -322,10 +325,116 @@ async def update_variables(
|
|||||||
return Result.succ(res)
|
return Result.succ(res)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/flow/debug")
|
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
|
||||||
async def debug():
|
async def debug_flow(
|
||||||
"""Debug the flow."""
|
flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
|
||||||
# TODO: Implement the debug endpoint
|
):
|
||||||
|
"""Run the flow in debug mode."""
|
||||||
|
# Return the no-incremental stream by default
|
||||||
|
stream_iter = service.debug_flow(flow_debug_request, default_incremental=False)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
}
|
||||||
|
return StreamingResponse(
|
||||||
|
service._wrapper_chat_stream_flow_str(stream_iter),
|
||||||
|
headers=headers,
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
|
||||||
|
async def export_flow(
|
||||||
|
uid: str,
|
||||||
|
export_type: Literal["json", "dbgpts"] = Query(
|
||||||
|
"json", description="export type(json or dbgpts)"
|
||||||
|
),
|
||||||
|
format: Literal["file", "json"] = Query(
|
||||||
|
"file", description="response format(file or json)"
|
||||||
|
),
|
||||||
|
file_name: Optional[str] = Query(default=None, description="file name to export"),
|
||||||
|
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||||
|
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||||
|
service: Service = Depends(get_service),
|
||||||
|
):
|
||||||
|
"""Export the flow to a file."""
|
||||||
|
flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
|
||||||
|
if not flow:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
|
||||||
|
package_name = flow.name.replace("_", "-")
|
||||||
|
file_name = file_name or package_name
|
||||||
|
if export_type == "json":
|
||||||
|
flow_dict = {"flow": flow.to_dict()}
|
||||||
|
if format == "json":
|
||||||
|
return JSONResponse(content=flow_dict)
|
||||||
|
else:
|
||||||
|
# Return the json file
|
||||||
|
return StreamingResponse(
|
||||||
|
io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")),
|
||||||
|
media_type="application/file",
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment;filename={file_name}.json"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif export_type == "dbgpts":
|
||||||
|
from ..service.share_utils import _generate_dbgpts_zip
|
||||||
|
|
||||||
|
if format == "json":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="json response is not supported for dbgpts"
|
||||||
|
)
|
||||||
|
|
||||||
|
zip_buffer = await blocking_func_to_async(
|
||||||
|
global_system_app, _generate_dbgpts_zip, package_name, flow
|
||||||
|
)
|
||||||
|
return StreamingResponse(
|
||||||
|
zip_buffer,
|
||||||
|
media_type="application/x-zip-compressed",
|
||||||
|
headers={"Content-Disposition": f"attachment;filename={file_name}.zip"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/flow/import",
|
||||||
|
response_model=Result[ServerResponse],
|
||||||
|
dependencies=[Depends(check_api_key)],
|
||||||
|
)
|
||||||
|
async def import_flow(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
save_flow: bool = Query(
|
||||||
|
False, description="Whether to save the flow after importing"
|
||||||
|
),
|
||||||
|
service: Service = Depends(get_service),
|
||||||
|
):
|
||||||
|
"""Import the flow from a file."""
|
||||||
|
filename = file.filename
|
||||||
|
file_extension = filename.split(".")[-1].lower()
|
||||||
|
if file_extension == "json":
|
||||||
|
# Handle json file
|
||||||
|
json_content = await file.read()
|
||||||
|
json_dict = json.loads(json_content)
|
||||||
|
if "flow" not in json_dict:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="invalid json file, missing 'flow' key"
|
||||||
|
)
|
||||||
|
flow = ServeRequest.parse_obj(json_dict["flow"])
|
||||||
|
elif file_extension == "zip":
|
||||||
|
from ..service.share_utils import _parse_flow_from_zip_file
|
||||||
|
|
||||||
|
# Handle zip file
|
||||||
|
flow = await _parse_flow_from_zip_file(file, global_system_app)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"invalid file extension {file_extension}"
|
||||||
|
)
|
||||||
|
if save_flow:
|
||||||
|
return Result.succ(service.create_and_save_dag(flow))
|
||||||
|
else:
|
||||||
|
return Result.succ(flow)
|
||||||
|
|
||||||
|
|
||||||
def init_endpoints(system_app: SystemApp) -> None:
|
def init_endpoints(system_app: SystemApp) -> None:
|
||||||
|
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
|
||||||
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
|
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
|
||||||
|
|
||||||
from ..config import SERVE_APP_NAME_HUMP
|
from ..config import SERVE_APP_NAME_HUMP
|
||||||
@@ -18,59 +18,6 @@ class ServerResponse(FlowPanel):
|
|||||||
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
|
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
|
||||||
|
|
||||||
|
|
||||||
class VariablesRequest(BaseModel):
|
|
||||||
"""Variable request model.
|
|
||||||
|
|
||||||
For creating a new variable in the DB-GPT.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str = Field(
|
|
||||||
...,
|
|
||||||
description="The key of the variable to create",
|
|
||||||
examples=["dbgpt.model.openai.api_key"],
|
|
||||||
)
|
|
||||||
name: str = Field(
|
|
||||||
...,
|
|
||||||
description="The name of the variable to create",
|
|
||||||
examples=["my_first_openai_key"],
|
|
||||||
)
|
|
||||||
label: str = Field(
|
|
||||||
...,
|
|
||||||
description="The label of the variable to create",
|
|
||||||
examples=["My First OpenAI Key"],
|
|
||||||
)
|
|
||||||
value: Any = Field(
|
|
||||||
..., description="The value of the variable to create", examples=["1234567890"]
|
|
||||||
)
|
|
||||||
value_type: Literal["str", "int", "float", "bool"] = Field(
|
|
||||||
"str",
|
|
||||||
description="The type of the value of the variable to create",
|
|
||||||
examples=["str", "int", "float", "bool"],
|
|
||||||
)
|
|
||||||
category: Literal["common", "secret"] = Field(
|
|
||||||
...,
|
|
||||||
description="The category of the variable to create",
|
|
||||||
examples=["common"],
|
|
||||||
)
|
|
||||||
scope: str = Field(
|
|
||||||
...,
|
|
||||||
description="The scope of the variable to create",
|
|
||||||
examples=["global"],
|
|
||||||
)
|
|
||||||
scope_key: Optional[str] = Field(
|
|
||||||
...,
|
|
||||||
description="The scope key of the variable to create",
|
|
||||||
examples=["dbgpt"],
|
|
||||||
)
|
|
||||||
enabled: Optional[bool] = Field(
|
|
||||||
True,
|
|
||||||
description="Whether the variable is enabled",
|
|
||||||
examples=[True],
|
|
||||||
)
|
|
||||||
user_name: Optional[str] = Field(None, description="User name")
|
|
||||||
sys_code: Optional[str] = Field(None, description="System code")
|
|
||||||
|
|
||||||
|
|
||||||
class VariablesResponse(VariablesRequest):
|
class VariablesResponse(VariablesRequest):
|
||||||
"""Variable response model."""
|
"""Variable response model."""
|
||||||
|
|
||||||
|
@@ -8,7 +8,6 @@ from fastapi import HTTPException
|
|||||||
from dbgpt._private.pydantic import model_to_json
|
from dbgpt._private.pydantic import model_to_json
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
||||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
|
||||||
from dbgpt.core.awel.flow.flow_factory import (
|
from dbgpt.core.awel.flow.flow_factory import (
|
||||||
FlowCategory,
|
FlowCategory,
|
||||||
FlowFactory,
|
FlowFactory,
|
||||||
@@ -33,7 +32,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
|||||||
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
||||||
from dbgpt.util.pagination_utils import PaginationResult
|
from dbgpt.util.pagination_utils import PaginationResult
|
||||||
|
|
||||||
from ..api.schemas import ServeRequest, ServerResponse
|
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
|
||||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||||
from ..models.models import ServeDao, ServeEntity
|
from ..models.models import ServeDao, ServeEntity
|
||||||
|
|
||||||
@@ -146,7 +145,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
res = self.dao.create(request)
|
self.dao.create(request)
|
||||||
|
# Query from database
|
||||||
|
res = self.get({"uid": request.uid})
|
||||||
|
|
||||||
state = request.state
|
state = request.state
|
||||||
try:
|
try:
|
||||||
@@ -563,3 +564,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
return FlowCategory.CHAT_FLOW
|
return FlowCategory.CHAT_FLOW
|
||||||
except Exception:
|
except Exception:
|
||||||
return FlowCategory.COMMON
|
return FlowCategory.COMMON
|
||||||
|
|
||||||
|
async def debug_flow(
|
||||||
|
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
|
||||||
|
) -> AsyncIterator[ModelOutput]:
|
||||||
|
"""Debug the flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (FlowDebugRequest): The request
|
||||||
|
default_incremental (Optional[bool]): The default incremental configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncIterator[ModelOutput]: The output
|
||||||
|
"""
|
||||||
|
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
|
||||||
|
|
||||||
|
dag = self._flow_factory.build(request.flow)
|
||||||
|
leaf_nodes = dag.leaf_nodes
|
||||||
|
if len(leaf_nodes) != 1:
|
||||||
|
raise ValueError("Chat Flow just support one leaf node in dag")
|
||||||
|
task = cast(BaseOperator, leaf_nodes[0])
|
||||||
|
dag_metadata = _parse_metadata(dag)
|
||||||
|
# TODO: Run task with variables
|
||||||
|
variables = request.variables
|
||||||
|
dag_request = request.request
|
||||||
|
|
||||||
|
if isinstance(request.request, CommonLLMHttpRequestBody):
|
||||||
|
incremental = request.request.incremental
|
||||||
|
elif isinstance(request.request, dict):
|
||||||
|
incremental = request.request.get("incremental", False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid request type")
|
||||||
|
|
||||||
|
if default_incremental is not None:
|
||||||
|
incremental = default_incremental
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for output in safe_chat_stream_with_dag_task(
|
||||||
|
task, dag_request, incremental
|
||||||
|
):
|
||||||
|
yield output
|
||||||
|
except HTTPException as e:
|
||||||
|
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
|
||||||
|
except Exception as e:
|
||||||
|
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
||||||
|
|
||||||
|
async def _wrapper_chat_stream_flow_str(
|
||||||
|
self, stream_iter: AsyncIterator[ModelOutput]
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
|
||||||
|
async for output in stream_iter:
|
||||||
|
text = output.text
|
||||||
|
if text:
|
||||||
|
text = text.replace("\n", "\\n")
|
||||||
|
if output.error_code != 0:
|
||||||
|
yield f"data:[SERVER_ERROR]{text}\n\n"
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
yield f"data:{text}\n\n"
|
||||||
|
121
dbgpt/serve/flow/service/share_utils.py
Normal file
121
dbgpt/serve/flow/service/share_utils.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import tomlkit
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
from dbgpt.component import SystemApp
|
||||||
|
from dbgpt.serve.core import blocking_func_to_async
|
||||||
|
|
||||||
|
from ..api.schemas import ServeRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
|
||||||
|
|
||||||
|
zip_buffer = io.BytesIO()
|
||||||
|
flow_name = flow.name
|
||||||
|
flow_label = flow.label
|
||||||
|
flow_description = flow.description
|
||||||
|
dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
|
||||||
|
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
|
||||||
|
manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
|
||||||
|
readme = f"# {flow_label}\n\n{flow_description}"
|
||||||
|
zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
|
||||||
|
zip_file.writestr(f"{package_name}/README.md", readme)
|
||||||
|
zip_file.writestr(
|
||||||
|
f"{package_name}/{flow_name}/__init__.py",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
zip_file.writestr(
|
||||||
|
f"{package_name}/{flow_name}/definition/flow_definition.json",
|
||||||
|
dag_json,
|
||||||
|
)
|
||||||
|
dbgpts_toml = tomlkit.document()
|
||||||
|
# Add flow information
|
||||||
|
dbgpts_flow_toml = tomlkit.document()
|
||||||
|
dbgpts_flow_toml.add("label", "Simple Streaming Chat")
|
||||||
|
name_with_comment = tomlkit.string("awel_flow_simple_streaming_chat")
|
||||||
|
name_with_comment.comment("A unique name for all dbgpts")
|
||||||
|
dbgpts_flow_toml.add("name", name_with_comment)
|
||||||
|
|
||||||
|
dbgpts_flow_toml.add("version", "0.1.0")
|
||||||
|
dbgpts_flow_toml.add(
|
||||||
|
"description",
|
||||||
|
flow_description,
|
||||||
|
)
|
||||||
|
dbgpts_flow_toml.add("authors", [])
|
||||||
|
|
||||||
|
definition_type_with_comment = tomlkit.string("json")
|
||||||
|
definition_type_with_comment.comment("How to define the flow, python or json")
|
||||||
|
dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
|
||||||
|
|
||||||
|
dbgpts_toml.add("flow", dbgpts_flow_toml)
|
||||||
|
|
||||||
|
# Add python and json config
|
||||||
|
python_config = tomlkit.table()
|
||||||
|
dbgpts_toml.add("python_config", python_config)
|
||||||
|
|
||||||
|
json_config = tomlkit.table()
|
||||||
|
json_config.add("file_path", "definition/flow_definition.json")
|
||||||
|
json_config.comment("Json config")
|
||||||
|
|
||||||
|
dbgpts_toml.add("json_config", json_config)
|
||||||
|
|
||||||
|
# Transform to string
|
||||||
|
toml_string = tomlkit.dumps(dbgpts_toml)
|
||||||
|
zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
|
||||||
|
|
||||||
|
pyproject_toml = tomlkit.document()
|
||||||
|
|
||||||
|
# Add [tool.poetry] section
|
||||||
|
tool_poetry_toml = tomlkit.table()
|
||||||
|
tool_poetry_toml.add("name", package_name)
|
||||||
|
tool_poetry_toml.add("version", "0.1.0")
|
||||||
|
tool_poetry_toml.add("description", "A dbgpts package")
|
||||||
|
tool_poetry_toml.add("authors", [])
|
||||||
|
tool_poetry_toml.add("readme", "README.md")
|
||||||
|
pyproject_toml["tool"] = tomlkit.table()
|
||||||
|
pyproject_toml["tool"]["poetry"] = tool_poetry_toml
|
||||||
|
|
||||||
|
# Add [tool.poetry.dependencies] section
|
||||||
|
dependencies = tomlkit.table()
|
||||||
|
dependencies.add("python", "^3.10")
|
||||||
|
pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
|
||||||
|
|
||||||
|
# Add [build-system] section
|
||||||
|
build_system = tomlkit.table()
|
||||||
|
build_system.add("requires", ["poetry-core"])
|
||||||
|
build_system.add("build-backend", "poetry.core.masonry.api")
|
||||||
|
pyproject_toml["build-system"] = build_system
|
||||||
|
|
||||||
|
# Transform to string
|
||||||
|
pyproject_toml_string = tomlkit.dumps(pyproject_toml)
|
||||||
|
zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
|
||||||
|
zip_buffer.seek(0)
|
||||||
|
return zip_buffer
|
||||||
|
|
||||||
|
|
||||||
|
async def _parse_flow_from_zip_file(
|
||||||
|
file: UploadFile, sys_app: SystemApp
|
||||||
|
) -> ServeRequest:
|
||||||
|
from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
|
||||||
|
|
||||||
|
filename = file.filename
|
||||||
|
if not filename.endswith(".zip"):
|
||||||
|
raise ValueError("Uploaded file must be a ZIP file")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
zip_path = os.path.join(temp_dir, filename)
|
||||||
|
|
||||||
|
# Save uploaded file to temporary directory
|
||||||
|
async with aiofiles.open(zip_path, "wb") as out_file:
|
||||||
|
while content := await file.read(1024 * 64): # Read in chunks of 64KB
|
||||||
|
await out_file.write(content)
|
||||||
|
flow = await blocking_func_to_async(
|
||||||
|
sys_app, _load_flow_package_from_zip_path, zip_path
|
||||||
|
)
|
||||||
|
return flow
|
@@ -320,14 +320,19 @@ def _load_package_from_path(path: str):
|
|||||||
return parsed_packages
|
return parsed_packages
|
||||||
|
|
||||||
|
|
||||||
def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage:
|
def _load_flow_package_from_path(
|
||||||
|
name: str, path: str = INSTALL_DIR, filter_by_name: bool = True
|
||||||
|
) -> FlowPackage:
|
||||||
raw_packages = _load_installed_package(path)
|
raw_packages = _load_installed_package(path)
|
||||||
new_name = name.replace("_", "-")
|
new_name = name.replace("_", "-")
|
||||||
packages = [p for p in raw_packages if p.package == name or p.name == name]
|
if filter_by_name:
|
||||||
if not packages:
|
packages = [p for p in raw_packages if p.package == name or p.name == name]
|
||||||
packages = [
|
if not packages:
|
||||||
p for p in raw_packages if p.package == new_name or p.name == new_name
|
packages = [
|
||||||
]
|
p for p in raw_packages if p.package == new_name or p.name == new_name
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
packages = raw_packages
|
||||||
if not packages:
|
if not packages:
|
||||||
raise ValueError(f"Can't find the package {name} or {new_name}")
|
raise ValueError(f"Can't find the package {name} or {new_name}")
|
||||||
flow_package = _parse_package_metadata(packages[0])
|
flow_package = _parse_package_metadata(packages[0])
|
||||||
@@ -336,6 +341,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack
|
|||||||
return cast(FlowPackage, flow_package)
|
return cast(FlowPackage, flow_package)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_flow_package_from_zip_path(zip_path: str) -> FlowPanel:
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(temp_dir)
|
||||||
|
package_names = os.listdir(temp_dir)
|
||||||
|
if not package_names:
|
||||||
|
raise ValueError("No package found in the zip file")
|
||||||
|
if len(package_names) > 1:
|
||||||
|
raise ValueError("Only support one package in the zip file")
|
||||||
|
package_name = package_names[0]
|
||||||
|
with open(
|
||||||
|
Path(temp_dir) / package_name / INSTALL_METADATA_FILE, mode="w+"
|
||||||
|
) as f:
|
||||||
|
# Write the metadata
|
||||||
|
import tomlkit
|
||||||
|
|
||||||
|
install_metadata = {
|
||||||
|
"name": package_name,
|
||||||
|
"repo": "local/dbgpts",
|
||||||
|
}
|
||||||
|
tomlkit.dump(install_metadata, f)
|
||||||
|
|
||||||
|
package = _load_flow_package_from_path("", path=temp_dir, filter_by_name=False)
|
||||||
|
return _flow_package_to_flow_panel(package)
|
||||||
|
|
||||||
|
|
||||||
def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
|
def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
|
||||||
dict_value = {
|
dict_value = {
|
||||||
"name": package.name,
|
"name": package.name,
|
||||||
@@ -345,6 +379,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
|
|||||||
"description": package.description,
|
"description": package.description,
|
||||||
"source": package.repo,
|
"source": package.repo,
|
||||||
"define_type": "json",
|
"define_type": "json",
|
||||||
|
"authors": package.authors,
|
||||||
}
|
}
|
||||||
if isinstance(package, FlowJsonPackage):
|
if isinstance(package, FlowJsonPackage):
|
||||||
dict_value["flow_data"] = package.read_definition_json()
|
dict_value["flow_data"] = package.read_definition_json()
|
||||||
|
Reference in New Issue
Block a user