mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-24 19:08:15 +00:00
feat(core): AWEL flow 2.0 backend code (#1879)
Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
|
||||
from dbgpt.agent import AgentDummyTrigger
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import (
|
||||
FlowCategory,
|
||||
FlowFactory,
|
||||
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise ValueError(
|
||||
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
||||
) from e
|
||||
res = self.dao.create(request)
|
||||
self.dao.create(request)
|
||||
# Query from database
|
||||
res = self.get({"uid": request.uid})
|
||||
|
||||
state = request.state
|
||||
try:
|
||||
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
return FlowCategory.CHAT_FLOW
|
||||
except Exception:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
async def debug_flow(
    self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
    """Debug the flow.

    Builds the DAG from the request, resolves the streaming mode, then runs
    the single leaf task and yields its outputs. Errors raised while running
    are converted into error `ModelOutput` chunks rather than propagated.

    Args:
        request (FlowDebugRequest): The request
        default_incremental (Optional[bool]): The default incremental
            configuration; when provided it overrides the value carried by
            the request body.

    Returns:
        AsyncIterator[ModelOutput]: The output
    """
    from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata

    built_dag = self._flow_factory.build(request.flow)
    leaves = built_dag.leaf_nodes
    # Streaming debug only makes sense with a single terminal operator.
    if len(leaves) != 1:
        raise ValueError("Chat Flow just support one leaf node in dag")
    end_task = cast(BaseOperator, leaves[0])
    # NOTE(review): the parsed metadata is currently unused; kept in case
    # _parse_metadata has registration side effects — confirm before removing.
    dag_metadata = _parse_metadata(built_dag)
    # TODO: Run task with variables
    variables = request.variables
    dag_request = request.request

    body = request.request
    if isinstance(body, CommonLLMHttpRequestBody):
        incremental = body.incremental
    elif isinstance(body, dict):
        incremental = body.get("incremental", False)
    else:
        raise ValueError("Invalid request type")

    # An explicit default wins over whatever the request body carried.
    if default_incremental is not None:
        incremental = default_incremental

    try:
        async for chunk in safe_chat_stream_with_dag_task(
            end_task, dag_request, incremental
        ):
            yield chunk
    except HTTPException as e:
        yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
    except Exception as e:
        yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
||||
|
||||
async def _wrapper_chat_stream_flow_str(
|
||||
self, stream_iter: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[str]:
|
||||
|
||||
async for output in stream_iter:
|
||||
text = output.text
|
||||
if text:
|
||||
text = text.replace("\n", "\\n")
|
||||
if output.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{text}\n\n"
|
||||
break
|
||||
else:
|
||||
yield f"data:{text}\n\n"
|
||||
|
121
dbgpt/serve/flow/service/share_utils.py
Normal file
121
dbgpt/serve/flow/service/share_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
import aiofiles
|
||||
import tomlkit
|
||||
from fastapi import UploadFile
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import blocking_func_to_async
|
||||
|
||||
from ..api.schemas import ServeRequest
|
||||
|
||||
|
||||
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
    """Generate an in-memory dbgpts package (zip archive) for a flow.

    The archive contains the package layout expected by the dbgpts loader:
    MANIFEST.in, README.md, the flow definition JSON, dbgpts.toml and
    pyproject.toml, all under a top-level ``package_name`` directory.

    Args:
        package_name (str): Top-level directory name inside the zip.
        flow (ServeRequest): The flow to package; its name, label,
            description and flow_data are written into the package files.

    Returns:
        io.BytesIO: The zip content, rewound to the start of the buffer.
    """
    zip_buffer = io.BytesIO()
    flow_name = flow.name
    flow_label = flow.label
    flow_description = flow.description
    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
        readme = f"# {flow_label}\n\n{flow_description}"
        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
        zip_file.writestr(f"{package_name}/README.md", readme)
        zip_file.writestr(
            f"{package_name}/{flow_name}/__init__.py",
            "",
        )
        zip_file.writestr(
            f"{package_name}/{flow_name}/definition/flow_definition.json",
            dag_json,
        )
        dbgpts_toml = tomlkit.document()
        # Add flow information
        dbgpts_flow_toml = tomlkit.document()
        # BUG FIX: use the packaged flow's own label/name instead of the
        # hard-coded "Simple Streaming Chat" example values, so the generated
        # dbgpts.toml actually describes this flow.
        dbgpts_flow_toml.add("label", flow_label)
        name_with_comment = tomlkit.string(flow_name)
        name_with_comment.comment("A unique name for all dbgpts")
        dbgpts_flow_toml.add("name", name_with_comment)

        dbgpts_flow_toml.add("version", "0.1.0")
        dbgpts_flow_toml.add(
            "description",
            flow_description,
        )
        dbgpts_flow_toml.add("authors", [])

        definition_type_with_comment = tomlkit.string("json")
        definition_type_with_comment.comment("How to define the flow, python or json")
        dbgpts_flow_toml.add("definition_type", definition_type_with_comment)

        dbgpts_toml.add("flow", dbgpts_flow_toml)

        # Add python and json config
        python_config = tomlkit.table()
        dbgpts_toml.add("python_config", python_config)

        json_config = tomlkit.table()
        json_config.add("file_path", "definition/flow_definition.json")
        json_config.comment("Json config")

        dbgpts_toml.add("json_config", json_config)

        # Transform to string
        toml_string = tomlkit.dumps(dbgpts_toml)
        zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)

        pyproject_toml = tomlkit.document()

        # Add [tool.poetry] section
        tool_poetry_toml = tomlkit.table()
        tool_poetry_toml.add("name", package_name)
        tool_poetry_toml.add("version", "0.1.0")
        tool_poetry_toml.add("description", "A dbgpts package")
        tool_poetry_toml.add("authors", [])
        tool_poetry_toml.add("readme", "README.md")
        pyproject_toml["tool"] = tomlkit.table()
        pyproject_toml["tool"]["poetry"] = tool_poetry_toml

        # Add [tool.poetry.dependencies] section
        dependencies = tomlkit.table()
        dependencies.add("python", "^3.10")
        pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies

        # Add [build-system] section
        build_system = tomlkit.table()
        build_system.add("requires", ["poetry-core"])
        build_system.add("build-backend", "poetry.core.masonry.api")
        pyproject_toml["build-system"] = build_system

        # Transform to string
        pyproject_toml_string = tomlkit.dumps(pyproject_toml)
        zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
    # Rewind so callers can stream the buffer from the beginning.
    zip_buffer.seek(0)
    return zip_buffer
|
||||
|
||||
|
||||
async def _parse_flow_from_zip_file(
    file: UploadFile, sys_app: SystemApp
) -> ServeRequest:
    """Parse a flow definition from an uploaded dbgpts zip package.

    The upload is streamed to a temporary directory and handed to the dbgpts
    loader in a worker thread (the loader is blocking).

    Args:
        file (UploadFile): The uploaded zip package.
        sys_app (SystemApp): The system app, used to run the blocking loader.

    Returns:
        ServeRequest: The parsed flow.

    Raises:
        ValueError: If the upload has no filename or is not a ``.zip`` file.
    """
    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path

    filename = file.filename
    # UploadFile.filename is optional — a missing name previously crashed on
    # .endswith(); treat it the same as a non-zip upload.
    if not filename or not filename.endswith(".zip"):
        raise ValueError("Uploaded file must be a ZIP file")
    # The filename is client-controlled: strip any directory components so a
    # crafted name (e.g. "../../x.zip") cannot escape the temp directory.
    filename = os.path.basename(filename)

    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = os.path.join(temp_dir, filename)

        # Save uploaded file to temporary directory, streamed in 64KB chunks
        # to avoid loading the whole archive into memory.
        async with aiofiles.open(zip_path, "wb") as out_file:
            while content := await file.read(1024 * 64):
                await out_file.write(content)
        flow = await blocking_func_to_async(
            sys_app, _load_flow_package_from_zip_path, zip_path
        )
    return flow
|
152
dbgpt/serve/flow/service/variables_service.py
Normal file
152
dbgpt/serve/flow/service/variables_service.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
|
||||
from dbgpt.serve.core import BaseService
|
||||
|
||||
from ..api.schemas import VariablesRequest, VariablesResponse
|
||||
from ..config import (
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
|
||||
ServeConfig,
|
||||
)
|
||||
from ..models.models import VariablesDao, VariablesEntity
|
||||
|
||||
|
||||
class VariablesService(
    BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
):
    """Variables service.

    CRUD operations for flow variables, persisted through the
    VariablesProvider and read back through the DAO.
    """

    name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME

    def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
        self._system_app = None
        # Populated in init_app; annotated Optional to match the None default.
        self._serve_config: Optional[ServeConfig] = None
        self._dao: Optional[VariablesDao] = dao

        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Initialize the service

        Args:
            system_app (SystemApp): The system app
        """
        super().init_app(system_app)

        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        # Keep an injected DAO (tests) if one was provided.
        self._dao = self._dao or VariablesDao(self._serve_config)
        self._system_app = system_app

    @property
    def dao(self) -> VariablesDao:
        """Returns the internal DAO."""
        return self._dao

    @property
    def variables_provider(self) -> VariablesProvider:
        """Returns the internal VariablesProvider.

        Returns:
            VariablesProvider: The internal VariablesProvider
        """
        variables_provider = VariablesProvider.get_instance(
            self._system_app, default_component=None
        )
        if variables_provider:
            return variables_provider
        # Fall back to the provider owned by the Serve component, and register
        # it so later lookups hit the fast path above.
        from ..serve import Serve

        variables_provider = Serve.get_instance(self._system_app).variables_provider
        self._system_app.register_instance(variables_provider)
        return variables_provider

    @property
    def config(self) -> ServeConfig:
        """Returns the internal ServeConfig."""
        return self._serve_config

    @staticmethod
    def _to_storage_variables(request: VariablesRequest) -> StorageVariables:
        """Build a StorageVariables record from a request.

        Shared by create() and update() so the field mapping cannot drift
        between the two.
        """
        return StorageVariables(
            key=request.key,
            name=request.name,
            label=request.label,
            value=request.value,
            value_type=request.value_type,
            category=request.category,
            scope=request.scope,
            scope_key=request.scope_key,
            user_name=request.user_name,
            sys_code=request.sys_code,
            # The entity stores enabled as an int flag.
            enabled=1 if request.enabled else 0,
            description=request.description,
        )

    @staticmethod
    def _build_query(request: VariablesRequest) -> dict:
        """Build the DAO query identifying the request's variable."""
        return {
            "key": request.key,
            "name": request.name,
            "scope": request.scope,
            "scope_key": request.scope_key,
            "sys_code": request.sys_code,
            "user_name": request.user_name,
            "enabled": request.enabled,
        }

    def create(self, request: VariablesRequest) -> VariablesResponse:
        """Create a new entity

        Args:
            request (VariablesRequest): The request

        Returns:
            VariablesResponse: The response
        """
        self.variables_provider.save(self._to_storage_variables(request))
        # Read back from storage so the response reflects the persisted state.
        return self.dao.get_one(self._build_query(request))

    def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
        """Update variables.

        Args:
            request (VariablesRequest): The request

        Returns:
            VariablesResponse: The response

        Raises:
            ValueError: If the variable does not already exist.
        """
        variables = self._to_storage_variables(request)
        exist_value = self.variables_provider.get(
            variables.identifier.str_identifier, None
        )
        if exist_value is None:
            raise ValueError(
                f"Variable {variables.identifier.str_identifier} not found"
            )
        self.variables_provider.save(variables)
        return self.dao.get_one(self._build_query(request))

    def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
        """List all enabled variables in the given category."""
        return self.dao.get_list({"enabled": True, "category": category})
|
Reference in New Issue
Block a user