Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-11 13:58:58 +00:00)
feat(core): Add debug and export/import for AWEL flow
@@ -197,7 +197,7 @@ class DAGManager(BaseComponent):
         return self._dag_metadata_map.get(dag.dag_id)


-def _parse_metadata(dag: DAG):
+def _parse_metadata(dag: DAG) -> DAGMetadata:
     from ..util.chat_util import _is_sse_output

     metadata = DAGMetadata()
@@ -4,7 +4,7 @@ import logging
 import uuid
 from contextlib import suppress
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast

 from typing_extensions import Annotated

@@ -166,6 +166,59 @@ class FlowData(BaseModel):
     viewport: FlowPositionData = Field(..., description="Viewport of the flow")


+class VariablesRequest(BaseModel):
+    """Variable request model.
+
+    For creating a new variable in the DB-GPT.
+    """
+
+    key: str = Field(
+        ...,
+        description="The key of the variable to create",
+        examples=["dbgpt.model.openai.api_key"],
+    )
+    name: str = Field(
+        ...,
+        description="The name of the variable to create",
+        examples=["my_first_openai_key"],
+    )
+    label: str = Field(
+        ...,
+        description="The label of the variable to create",
+        examples=["My First OpenAI Key"],
+    )
+    value: Any = Field(
+        ..., description="The value of the variable to create", examples=["1234567890"]
+    )
+    value_type: Literal["str", "int", "float", "bool"] = Field(
+        "str",
+        description="The type of the value of the variable to create",
+        examples=["str", "int", "float", "bool"],
+    )
+    category: Literal["common", "secret"] = Field(
+        ...,
+        description="The category of the variable to create",
+        examples=["common"],
+    )
+    scope: str = Field(
+        ...,
+        description="The scope of the variable to create",
+        examples=["global"],
+    )
+    scope_key: Optional[str] = Field(
+        ...,
+        description="The scope key of the variable to create",
+        examples=["dbgpt"],
+    )
+    enabled: Optional[bool] = Field(
+        True,
+        description="Whether the variable is enabled",
+        examples=[True],
+    )
+    user_name: Optional[str] = Field(None, description="User name")
+    sys_code: Optional[str] = Field(None, description="System code")
+
+
 class State(str, Enum):
     """State of a flow panel."""

@@ -356,6 +409,12 @@ class FlowPanel(BaseModel):
     metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
         default=None, description="The metadata of the flow"
     )
+    variables: Optional[List[VariablesRequest]] = Field(
+        default=None, description="The variables of the flow"
+    )
+    authors: Optional[List[str]] = Field(
+        default=None, description="The authors of the flow"
+    )

     @model_validator(mode="before")
     @classmethod
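For orientation, a minimal sketch of constructing the new VariablesRequest model; the values are illustrative, echoing the examples= hints declared above:

openai_key = VariablesRequest(
    key="dbgpt.model.openai.api_key",
    name="my_first_openai_key",
    label="My First OpenAI Key",
    value="sk-xxx",  # illustrative secret value
    value_type="str",
    category="secret",
    scope="global",
    scope_key="dbgpt",
)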
@@ -334,7 +334,8 @@ class InputOperator(BaseOperator, Generic[OUT]):
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         task_output = await self._input_source.read(curr_task_ctx)
-        curr_task_ctx.set_task_output(task_output)
+        new_task_output: TaskOutput[OUT] = await task_output.map(self.map)
+        curr_task_ctx.set_task_output(new_task_output)
         return task_output

     @classmethod
@@ -342,6 +343,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
         """Create a dummy InputOperator with a given input value."""
         return cls(input_source=InputSource.from_data(dummy_data), **kwargs)

+    async def map(self, input_data: OUT) -> OUT:
+        """Map the input data to a new value."""
+        return input_data
+

 class TriggerOperator(InputOperator[OUT], Generic[OUT]):
     """Operator node that triggers the DAG to run."""
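The new map hook gives InputOperator subclasses a place to transform what the input source produced before it becomes the task output. A minimal hypothetical subclass (illustrative, not part of the commit):

class UppercaseInputOperator(InputOperator[str]):
    """Hypothetical operator that upper-cases whatever the input source read."""

    async def map(self, input_data: str) -> str:
        return input_data.upper()


# e.g. UppercaseInputOperator(input_source=InputSource.from_data("hello"))
# sets "HELLO" as its task output when the DAG runs.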
@@ -87,7 +87,9 @@ class HttpTriggerMetadata(TriggerMetadata):

     path: str = Field(..., description="The path of the trigger")
     methods: List[str] = Field(..., description="The methods of the trigger")
-
+    trigger_mode: str = Field(
+        default="command", description="The mode of the trigger, command or chat"
+    )
     trigger_type: Optional[str] = Field(
         default="http", description="The type of the trigger"
     )
@@ -477,7 +479,9 @@ class HttpTrigger(Trigger):
         )(dynamic_route_function)

         logger.info(f"Mount http trigger success, path: {path}")
-        return HttpTriggerMetadata(path=path, methods=self._methods)
+        return HttpTriggerMetadata(
+            path=path, methods=self._methods, trigger_mode=self._trigger_mode()
+        )

     def mount_to_app(
         self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -512,7 +516,9 @@ class HttpTrigger(Trigger):
         app.openapi_schema = None
         app.middleware_stack = None
         logger.info(f"Mount http trigger success, path: {path}")
-        return HttpTriggerMetadata(path=path, methods=self._methods)
+        return HttpTriggerMetadata(
+            path=path, methods=self._methods, trigger_mode=self._trigger_mode()
+        )

     def remove_from_app(
         self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -537,6 +543,36 @@ class HttpTrigger(Trigger):
                 # TODO, remove with path and methods
                 del app_router.routes[i]

+    def _trigger_mode(self) -> str:
+        if (
+            self._req_body
+            and isinstance(self._req_body, type)
+            and issubclass(self._req_body, CommonLLMHttpRequestBody)
+        ):
+            return "chat"
+        return "command"
+
+    async def map(self, input_data: Any) -> Any:
+        """Map the input data.
+
+        Do some transformation for the input data.
+
+        Args:
+            input_data (Any): The input data from caller.
+
+        Returns:
+            Any: The mapped data.
+        """
+        if not self._req_body or not input_data:
+            return await super().map(input_data)
+        if (
+            isinstance(self._req_body, type)
+            and issubclass(self._req_body, BaseModel)
+            and isinstance(input_data, dict)
+        ):
+            return self._req_body(**input_data)
+        return await super().map(input_data)
+
     def _create_route_func(self):
         from inspect import Parameter, Signature
         from typing import get_type_hints
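Taken together: a trigger whose declared request body is (a subclass of) CommonLLMHttpRequestBody is now reported as a "chat" trigger, and raw dict payloads are coerced into the declared body model before entering the DAG. A standalone sketch of the same coercion pattern, written synchronously for brevity and using plain pydantic rather than HttpTrigger internals (ChatBody and its fields are assumptions):

from typing import Any, Optional, Type

from pydantic import BaseModel


class ChatBody(BaseModel):
    """Stand-in for CommonLLMHttpRequestBody; the real field names may differ."""

    messages: str
    stream: bool = False


def coerce_body(req_body: Optional[Type[BaseModel]], input_data: Any) -> Any:
    # Same shape as HttpTrigger.map above: raw dicts become the declared model.
    if not req_body or not input_data:
        return input_data
    if (
        isinstance(req_body, type)
        and issubclass(req_body, BaseModel)
        and isinstance(input_data, dict)
    ):
        return req_body(**input_data)
    return input_data


print(coerce_body(ChatBody, {"messages": "hi"}))  # messages='hi' stream=False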
@@ -1,9 +1,11 @@
+import io
+import json
 from functools import cache
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

-from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 from starlette.responses import JSONResponse, StreamingResponse

 from dbgpt.component import SystemApp
 from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
@@ -15,6 +17,7 @@ from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
 from ..service.service import Service
 from ..service.variables_service import VariablesService
 from .schemas import (
+    FlowDebugRequest,
     RefreshNodeRequest,
     ServeRequest,
     ServerResponse,
@@ -352,10 +355,116 @@ async def update_variables(
     return Result.succ(res)


-@router.post("/flow/debug")
-async def debug():
-    """Debug the flow."""
-    # TODO: Implement the debug endpoint
+@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
+async def debug_flow(
+    flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
+):
+    """Run the flow in debug mode."""
+    # Return the non-incremental stream by default
+    stream_iter = service.debug_flow(flow_debug_request, default_incremental=False)
+
+    headers = {
+        "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "Transfer-Encoding": "chunked",
+    }
+    return StreamingResponse(
+        service._wrapper_chat_stream_flow_str(stream_iter),
+        headers=headers,
+        media_type="text/event-stream",
+    )
+
+
+@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
+async def export_flow(
+    uid: str,
+    export_type: Literal["json", "dbgpts"] = Query(
+        "json", description="export type(json or dbgpts)"
+    ),
+    format: Literal["file", "json"] = Query(
+        "file", description="response format(file or json)"
+    ),
+    file_name: Optional[str] = Query(default=None, description="file name to export"),
+    user_name: Optional[str] = Query(default=None, description="user name"),
+    sys_code: Optional[str] = Query(default=None, description="system code"),
+    service: Service = Depends(get_service),
+):
+    """Export the flow to a file."""
+    flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
+    if not flow:
+        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
+    package_name = flow.name.replace("_", "-")
+    file_name = file_name or package_name
+    if export_type == "json":
+        flow_dict = {"flow": flow.to_dict()}
+        if format == "json":
+            return JSONResponse(content=flow_dict)
+        else:
+            # Return the json file
+            return StreamingResponse(
+                io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")),
+                media_type="application/file",
+                headers={
+                    "Content-Disposition": f"attachment;filename={file_name}.json"
+                },
+            )
+    elif export_type == "dbgpts":
+        from ..service.share_utils import _generate_dbgpts_zip
+
+        if format == "json":
+            raise HTTPException(
+                status_code=400, detail="json response is not supported for dbgpts"
+            )
+
+        zip_buffer = await blocking_func_to_async(
+            global_system_app, _generate_dbgpts_zip, package_name, flow
+        )
+        return StreamingResponse(
+            zip_buffer,
+            media_type="application/x-zip-compressed",
+            headers={"Content-Disposition": f"attachment;filename={file_name}.zip"},
+        )
+
+
+@router.post(
+    "/flow/import",
+    response_model=Result[ServerResponse],
+    dependencies=[Depends(check_api_key)],
+)
+async def import_flow(
+    file: UploadFile = File(...),
+    save_flow: bool = Query(
+        False, description="Whether to save the flow after importing"
+    ),
+    service: Service = Depends(get_service),
+):
+    """Import the flow from a file."""
+    filename = file.filename
+    file_extension = filename.split(".")[-1].lower()
+    if file_extension == "json":
+        # Handle json file
+        json_content = await file.read()
+        json_dict = json.loads(json_content)
+        if "flow" not in json_dict:
+            raise HTTPException(
+                status_code=400, detail="invalid json file, missing 'flow' key"
+            )
+        flow = ServeRequest.parse_obj(json_dict["flow"])
+    elif file_extension == "zip":
+        from ..service.share_utils import _parse_flow_from_zip_file
+
+        # Handle zip file
+        flow = await _parse_flow_from_zip_file(file, global_system_app)
+    else:
+        raise HTTPException(
+            status_code=400, detail=f"invalid file extension {file_extension}"
+        )
+    if save_flow:
+        return Result.succ(service.create_and_save_dag(flow))
+    else:
+        return Result.succ(flow)


 def init_endpoints(system_app: SystemApp) -> None:
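A rough client-side sketch of the three new endpoints. The host, port, and route prefix are assumptions (they depend on where the serve app is mounted), API_KEY stands in for whatever token check_api_key validates, and the full FlowDebugRequest schema is not shown in this diff; only its flow and request fields are used below, mirroring Service.debug_flow:

import requests

BASE = "http://localhost:5670/api/v2/serve/awel"  # assumed mount prefix
HEADERS = {"Authorization": "Bearer API_KEY"}  # check_api_key guards all three routes

# Export a flow as a dbgpts zip (format=json is rejected for export_type=dbgpts).
resp = requests.get(
    f"{BASE}/flow/export/some-flow-uid",
    params={"export_type": "dbgpts", "format": "file"},
    headers=HEADERS,
)
with open("my-flow.zip", "wb") as f:
    f.write(resp.content)

# Re-import it; with save_flow=False the flow is parsed and returned, not saved.
with open("my-flow.zip", "rb") as f:
    resp = requests.post(
        f"{BASE}/flow/import",
        params={"save_flow": "false"},
        files={"file": ("my-flow.zip", f, "application/zip")},
        headers=HEADERS,
    )

# Debug-run a flow definition; the response is an SSE stream of "data:" frames.
flow_definition = {}  # the serialized flow panel dict, elided here
with requests.post(
    f"{BASE}/flow/debug",
    json={"flow": flow_definition, "request": {"messages": "hi", "stream": True}},
    headers=HEADERS,
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)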
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union

 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
 from dbgpt.core.awel import CommonLLMHttpRequestBody
-from dbgpt.core.awel.flow.flow_factory import FlowPanel
+from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
 from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest

 from ..config import SERVE_APP_NAME_HUMP
@@ -18,59 +18,6 @@ class ServerResponse(FlowPanel):
     model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")


-class VariablesRequest(BaseModel):
-    """Variable request model.
-
-    For creating a new variable in the DB-GPT.
-    """
-
-    key: str = Field(
-        ...,
-        description="The key of the variable to create",
-        examples=["dbgpt.model.openai.api_key"],
-    )
-    name: str = Field(
-        ...,
-        description="The name of the variable to create",
-        examples=["my_first_openai_key"],
-    )
-    label: str = Field(
-        ...,
-        description="The label of the variable to create",
-        examples=["My First OpenAI Key"],
-    )
-    value: Any = Field(
-        ..., description="The value of the variable to create", examples=["1234567890"]
-    )
-    value_type: Literal["str", "int", "float", "bool"] = Field(
-        "str",
-        description="The type of the value of the variable to create",
-        examples=["str", "int", "float", "bool"],
-    )
-    category: Literal["common", "secret"] = Field(
-        ...,
-        description="The category of the variable to create",
-        examples=["common"],
-    )
-    scope: str = Field(
-        ...,
-        description="The scope of the variable to create",
-        examples=["global"],
-    )
-    scope_key: Optional[str] = Field(
-        ...,
-        description="The scope key of the variable to create",
-        examples=["dbgpt"],
-    )
-    enabled: Optional[bool] = Field(
-        True,
-        description="Whether the variable is enabled",
-        examples=[True],
-    )
-    user_name: Optional[str] = Field(None, description="User name")
-    sys_code: Optional[str] = Field(None, description="System code")
-
-
 class VariablesResponse(VariablesRequest):
     """Variable response model."""
@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
 from dbgpt.agent import AgentDummyTrigger
 from dbgpt.component import SystemApp
 from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
-from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import (
     FlowCategory,
     FlowFactory,
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
 from dbgpt.util.dbgpts.loader import DBGPTsLoader
 from dbgpt.util.pagination_utils import PaginationResult

-from ..api.schemas import ServeRequest, ServerResponse
+from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
 from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
 from ..models.models import ServeDao, ServeEntity
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
             raise ValueError(
                 f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
             ) from e
-        res = self.dao.create(request)
+        self.dao.create(request)
+        # Query from database
+        res = self.get({"uid": request.uid})

         state = request.state
         try:
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
             return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+    async def debug_flow(
+        self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
+    ) -> AsyncIterator[ModelOutput]:
+        """Debug the flow.
+
+        Args:
+            request (FlowDebugRequest): The request
+            default_incremental (Optional[bool]): The default incremental configuration
+
+        Returns:
+            AsyncIterator[ModelOutput]: The output
+        """
+        from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
+
+        dag = self._flow_factory.build(request.flow)
+        leaf_nodes = dag.leaf_nodes
+        if len(leaf_nodes) != 1:
+            raise ValueError("Chat Flow just support one leaf node in dag")
+        task = cast(BaseOperator, leaf_nodes[0])
+        dag_metadata = _parse_metadata(dag)
+        # TODO: Run task with variables
+        variables = request.variables
+        dag_request = request.request
+
+        if isinstance(request.request, CommonLLMHttpRequestBody):
+            incremental = request.request.incremental
+        elif isinstance(request.request, dict):
+            incremental = request.request.get("incremental", False)
+        else:
+            raise ValueError("Invalid request type")
+
+        if default_incremental is not None:
+            incremental = default_incremental
+
+        try:
+            async for output in safe_chat_stream_with_dag_task(
+                task, dag_request, incremental
+            ):
+                yield output
+        except HTTPException as e:
+            yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
+        except Exception as e:
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
+
+    async def _wrapper_chat_stream_flow_str(
+        self, stream_iter: AsyncIterator[ModelOutput]
+    ) -> AsyncIterator[str]:
+
+        async for output in stream_iter:
+            text = output.text
+            if text:
+                text = text.replace("\n", "\\n")
+            if output.error_code != 0:
+                yield f"data:[SERVER_ERROR]{text}\n\n"
+                break
+            else:
+                yield f"data:{text}\n\n"
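One note on the framing above: each ModelOutput becomes a single "data:" frame, literal newlines are escaped to \n so a frame never spans lines, and an error output is prefixed with [SERVER_ERROR] and terminates the stream. A sketch of a matching decoder (a hypothetical helper, not part of the commit):

def parse_sse_frame(frame: str) -> str:
    """Decode one frame produced by _wrapper_chat_stream_flow_str.

    Raises on the [SERVER_ERROR] prefix, otherwise restores escaped newlines.
    """
    assert frame.startswith("data:") and frame.endswith("\n\n")
    payload = frame[len("data:"):-2]
    if payload.startswith("[SERVER_ERROR]"):
        raise RuntimeError(payload[len("[SERVER_ERROR]"):])
    return payload.replace("\\n", "\n")


print(parse_sse_frame("data:Hello\\nWorld\n\n"))  # prints "Hello" and "World"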
dbgpt/serve/flow/service/share_utils.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+import io
+import json
+import os
+import tempfile
+import zipfile
+
+import aiofiles
+import tomlkit
+from fastapi import UploadFile
+
+from dbgpt.component import SystemApp
+from dbgpt.serve.core import blocking_func_to_async
+
+from ..api.schemas import ServeRequest
+
+
+def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
+
+    zip_buffer = io.BytesIO()
+    flow_name = flow.name
+    flow_label = flow.label
+    flow_description = flow.description
+    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
+    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
+        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
+        readme = f"# {flow_label}\n\n{flow_description}"
+        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
+        zip_file.writestr(f"{package_name}/README.md", readme)
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/__init__.py",
+            "",
+        )
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/definition/flow_definition.json",
+            dag_json,
+        )
+        dbgpts_toml = tomlkit.document()
+        # Add flow information
+        dbgpts_flow_toml = tomlkit.document()
+        dbgpts_flow_toml.add("label", flow_label)
+        name_with_comment = tomlkit.string(flow_name)
+        name_with_comment.comment("A unique name for all dbgpts")
+        dbgpts_flow_toml.add("name", name_with_comment)
+
+        dbgpts_flow_toml.add("version", "0.1.0")
+        dbgpts_flow_toml.add(
+            "description",
+            flow_description,
+        )
+        dbgpts_flow_toml.add("authors", [])
+
+        definition_type_with_comment = tomlkit.string("json")
+        definition_type_with_comment.comment("How to define the flow, python or json")
+        dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
+
+        dbgpts_toml.add("flow", dbgpts_flow_toml)
+
+        # Add python and json config
+        python_config = tomlkit.table()
+        dbgpts_toml.add("python_config", python_config)
+
+        json_config = tomlkit.table()
+        json_config.add("file_path", "definition/flow_definition.json")
+        json_config.comment("Json config")
+
+        dbgpts_toml.add("json_config", json_config)
+
+        # Transform to string
+        toml_string = tomlkit.dumps(dbgpts_toml)
+        zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
+
+        pyproject_toml = tomlkit.document()
+
+        # Add [tool.poetry] section
+        tool_poetry_toml = tomlkit.table()
+        tool_poetry_toml.add("name", package_name)
+        tool_poetry_toml.add("version", "0.1.0")
+        tool_poetry_toml.add("description", "A dbgpts package")
+        tool_poetry_toml.add("authors", [])
+        tool_poetry_toml.add("readme", "README.md")
+        pyproject_toml["tool"] = tomlkit.table()
+        pyproject_toml["tool"]["poetry"] = tool_poetry_toml
+
+        # Add [tool.poetry.dependencies] section
+        dependencies = tomlkit.table()
+        dependencies.add("python", "^3.10")
+        pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
+
+        # Add [build-system] section
+        build_system = tomlkit.table()
+        build_system.add("requires", ["poetry-core"])
+        build_system.add("build-backend", "poetry.core.masonry.api")
+        pyproject_toml["build-system"] = build_system
+
+        # Transform to string
+        pyproject_toml_string = tomlkit.dumps(pyproject_toml)
+        zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
+    zip_buffer.seek(0)
+    return zip_buffer
+
+
+async def _parse_flow_from_zip_file(
+    file: UploadFile, sys_app: SystemApp
+) -> ServeRequest:
+    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
+
+    filename = file.filename
+    if not filename.endswith(".zip"):
+        raise ValueError("Uploaded file must be a ZIP file")
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, filename)
+
+        # Save uploaded file to temporary directory
+        async with aiofiles.open(zip_path, "wb") as out_file:
+            while content := await file.read(1024 * 64):  # Read in chunks of 64KB
+                await out_file.write(content)
+        flow = await blocking_func_to_async(
+            sys_app, _load_flow_package_from_zip_path, zip_path
+        )
+        return flow
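For reference, a sketch of inspecting the archive _generate_dbgpts_zip builds; here `flow` is assumed to be an existing ServeRequest whose name is "my_flow":

import zipfile

zip_buffer = _generate_dbgpts_zip("my-flow", flow)
with zipfile.ZipFile(zip_buffer) as zf:
    print(zf.namelist())
# Expected entries, per the writestr calls above:
#   my-flow/MANIFEST.in
#   my-flow/README.md
#   my-flow/my_flow/__init__.py
#   my-flow/my_flow/definition/flow_definition.json
#   my-flow/dbgpts.toml
#   my-flow/pyproject.toml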
@@ -328,14 +328,19 @@ def _load_package_from_path(path: str):
     return parsed_packages


-def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage:
+def _load_flow_package_from_path(
+    name: str, path: str = INSTALL_DIR, filter_by_name: bool = True
+) -> FlowPackage:
     raw_packages = _load_installed_package(path)
     new_name = name.replace("_", "-")
-    packages = [p for p in raw_packages if p.package == name or p.name == name]
-    if not packages:
-        packages = [
-            p for p in raw_packages if p.package == new_name or p.name == new_name
-        ]
+    if filter_by_name:
+        packages = [p for p in raw_packages if p.package == name or p.name == name]
+        if not packages:
+            packages = [
+                p for p in raw_packages if p.package == new_name or p.name == new_name
+            ]
+    else:
+        packages = raw_packages
     if not packages:
         raise ValueError(f"Can't find the package {name} or {new_name}")
     flow_package = _parse_package_metadata(packages[0])
@@ -344,6 +349,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage:
     return cast(FlowPackage, flow_package)


+def _load_flow_package_from_zip_path(zip_path: str) -> FlowPanel:
+    import tempfile
+    import zipfile
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
+            zip_ref.extractall(temp_dir)
+        package_names = os.listdir(temp_dir)
+        if not package_names:
+            raise ValueError("No package found in the zip file")
+        if len(package_names) > 1:
+            raise ValueError("Only support one package in the zip file")
+        package_name = package_names[0]
+        with open(
+            Path(temp_dir) / package_name / INSTALL_METADATA_FILE, mode="w+"
+        ) as f:
+            # Write the metadata
+            import tomlkit
+
+            install_metadata = {
+                "name": package_name,
+                "repo": "local/dbgpts",
+            }
+            tomlkit.dump(install_metadata, f)
+
+        package = _load_flow_package_from_path("", path=temp_dir, filter_by_name=False)
+        return _flow_package_to_flow_panel(package)
+
+
 def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
     dict_value = {
         "name": package.name,
@@ -353,6 +387,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
         "description": package.description,
         "source": package.repo,
         "define_type": "json",
+        "authors": package.authors,
     }
     if isinstance(package, FlowJsonPackage):
         dict_value["flow_data"] = package.read_definition_json()
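A usage sketch for the new zip loader, e.g. against an archive produced by the export endpoint; the path is illustrative:

from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path

# Extracts the archive into a temp dir, synthesizes the install metadata file,
# then loads the single contained package without name filtering.
panel = _load_flow_package_from_zip_path("/tmp/my-flow.zip")
print(panel.name, panel.define_type)  # define_type is "json" for dbgpts packages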