feat(core): AWEL flow 2.0 backend code (#1879)

Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
Fangyin Cheng
2024-08-23 14:57:54 +08:00
committed by GitHub
parent 3a32344380
commit 9502251c08
67 changed files with 8289 additions and 190 deletions

View File

@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
res = self.dao.create(request)
self.dao.create(request)
# Query from database
res = self.get({"uid": request.uid})
state = request.state
try:
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON
async def debug_flow(
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
"""Debug the flow.
Args:
request (FlowDebugRequest): The request
default_incremental (Optional[bool]): The default incremental configuration
Returns:
AsyncIterator[ModelOutput]: The output
"""
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
dag = self._flow_factory.build(request.flow)
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
task = cast(BaseOperator, leaf_nodes[0])
dag_metadata = _parse_metadata(dag)
# TODO: Run task with variables
variables = request.variables
dag_request = request.request
if isinstance(request.request, CommonLLMHttpRequestBody):
incremental = request.request.incremental
elif isinstance(request.request, dict):
incremental = request.request.get("incremental", False)
else:
raise ValueError("Invalid request type")
if default_incremental is not None:
incremental = default_incremental
try:
async for output in safe_chat_stream_with_dag_task(
task, dag_request, incremental
):
yield output
except HTTPException as e:
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
async def _wrapper_chat_stream_flow_str(
self, stream_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async for output in stream_iter:
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"

View File

@@ -0,0 +1,121 @@
import io
import json
import os
import tempfile
import zipfile
import aiofiles
import tomlkit
from fastapi import UploadFile
from dbgpt.component import SystemApp
from dbgpt.serve.core import blocking_func_to_async
from ..api.schemas import ServeRequest
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
zip_buffer = io.BytesIO()
flow_name = flow.name
flow_label = flow.label
flow_description = flow.description
dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
readme = f"# {flow_label}\n\n{flow_description}"
zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
zip_file.writestr(f"{package_name}/README.md", readme)
zip_file.writestr(
f"{package_name}/{flow_name}/__init__.py",
"",
)
zip_file.writestr(
f"{package_name}/{flow_name}/definition/flow_definition.json",
dag_json,
)
dbgpts_toml = tomlkit.document()
# Add flow information
dbgpts_flow_toml = tomlkit.document()
dbgpts_flow_toml.add("label", "Simple Streaming Chat")
name_with_comment = tomlkit.string("awel_flow_simple_streaming_chat")
name_with_comment.comment("A unique name for all dbgpts")
dbgpts_flow_toml.add("name", name_with_comment)
dbgpts_flow_toml.add("version", "0.1.0")
dbgpts_flow_toml.add(
"description",
flow_description,
)
dbgpts_flow_toml.add("authors", [])
definition_type_with_comment = tomlkit.string("json")
definition_type_with_comment.comment("How to define the flow, python or json")
dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
dbgpts_toml.add("flow", dbgpts_flow_toml)
# Add python and json config
python_config = tomlkit.table()
dbgpts_toml.add("python_config", python_config)
json_config = tomlkit.table()
json_config.add("file_path", "definition/flow_definition.json")
json_config.comment("Json config")
dbgpts_toml.add("json_config", json_config)
# Transform to string
toml_string = tomlkit.dumps(dbgpts_toml)
zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
pyproject_toml = tomlkit.document()
# Add [tool.poetry] section
tool_poetry_toml = tomlkit.table()
tool_poetry_toml.add("name", package_name)
tool_poetry_toml.add("version", "0.1.0")
tool_poetry_toml.add("description", "A dbgpts package")
tool_poetry_toml.add("authors", [])
tool_poetry_toml.add("readme", "README.md")
pyproject_toml["tool"] = tomlkit.table()
pyproject_toml["tool"]["poetry"] = tool_poetry_toml
# Add [tool.poetry.dependencies] section
dependencies = tomlkit.table()
dependencies.add("python", "^3.10")
pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
# Add [build-system] section
build_system = tomlkit.table()
build_system.add("requires", ["poetry-core"])
build_system.add("build-backend", "poetry.core.masonry.api")
pyproject_toml["build-system"] = build_system
# Transform to string
pyproject_toml_string = tomlkit.dumps(pyproject_toml)
zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
zip_buffer.seek(0)
return zip_buffer
async def _parse_flow_from_zip_file(
file: UploadFile, sys_app: SystemApp
) -> ServeRequest:
from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
filename = file.filename
if not filename.endswith(".zip"):
raise ValueError("Uploaded file must be a ZIP file")
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, filename)
# Save uploaded file to temporary directory
async with aiofiles.open(zip_path, "wb") as out_file:
while content := await file.read(1024 * 64): # Read in chunks of 64KB
await out_file.write(content)
flow = await blocking_func_to_async(
sys_app, _load_flow_package_from_zip_path, zip_path
)
return flow

View File

@@ -0,0 +1,152 @@
from typing import List, Optional
from dbgpt import SystemApp
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
from dbgpt.serve.core import BaseService
from ..api.schemas import VariablesRequest, VariablesResponse
from ..config import (
SERVE_CONFIG_KEY_PREFIX,
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
ServeConfig,
)
from ..models.models import VariablesDao, VariablesEntity
class VariablesService(
BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
):
"""Variables service"""
name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: VariablesDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or VariablesDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> VariablesDao:
"""Returns the internal DAO."""
return self._dao
@property
def variables_provider(self) -> VariablesProvider:
"""Returns the internal VariablesProvider.
Returns:
VariablesProvider: The internal VariablesProvider
"""
variables_provider = VariablesProvider.get_instance(
self._system_app, default_component=None
)
if variables_provider:
return variables_provider
else:
from ..serve import Serve
variables_provider = Serve.get_instance(self._system_app).variables_provider
self._system_app.register_instance(variables_provider)
return variables_provider
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: VariablesRequest) -> VariablesResponse:
"""Create a new entity
Args:
request (VariablesRequest): The request
Returns:
VariablesResponse: The response
"""
variables = StorageVariables(
key=request.key,
name=request.name,
label=request.label,
value=request.value,
value_type=request.value_type,
category=request.category,
scope=request.scope,
scope_key=request.scope_key,
user_name=request.user_name,
sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
)
self.variables_provider.save(variables)
query = {
"key": request.key,
"name": request.name,
"scope": request.scope,
"scope_key": request.scope_key,
"sys_code": request.sys_code,
"user_name": request.user_name,
"enabled": request.enabled,
}
return self.dao.get_one(query)
def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
"""Update variables.
Args:
request (VariablesRequest): The request
Returns:
VariablesResponse: The response
"""
variables = StorageVariables(
key=request.key,
name=request.name,
label=request.label,
value=request.value,
value_type=request.value_type,
category=request.category,
scope=request.scope,
scope_key=request.scope_key,
user_name=request.user_name,
sys_code=request.sys_code,
enabled=1 if request.enabled else 0,
description=request.description,
)
exist_value = self.variables_provider.get(
variables.identifier.str_identifier, None
)
if exist_value is None:
raise ValueError(
f"Variable {variables.identifier.str_identifier} not found"
)
self.variables_provider.save(variables)
query = {
"key": request.key,
"name": request.name,
"scope": request.scope,
"scope_key": request.scope_key,
"sys_code": request.sys_code,
"user_name": request.user_name,
"enabled": request.enabled,
}
return self.dao.get_one(query)
def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
"""List all variables."""
return self.dao.get_list({"enabled": True, "category": category})