# DB-GPT/dbgpt/serve/flow/service/service.py
import json
import logging
import os
from typing import AsyncIterator, List, Optional, cast
import schedule
from fastapi import HTTPException
from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
State,
fill_flow_panel,
)
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
from dbgpt.core.awel.util.chat_util import (
is_chat_flow_type,
safe_chat_stream_with_dag_task,
safe_chat_with_dag_task,
)
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.schema.api import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from dbgpt.serve.core import BaseService, blocking_func_to_async
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for Flow"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
        self._system_app = None
        self._serve_config: Optional[ServeConfig] = None
        self._dao: Optional[ServeDao] = dao
self._flow_factory: FlowFactory = FlowFactory()
self._dbgpts_loader: Optional[DBGPTsLoader] = None
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
self._dbgpts_loader = system_app.get_component(
DBGPTsLoader.name,
DBGPTsLoader,
or_register_component=DBGPTsLoader,
load_dbgpts_interval=self._serve_config.load_dbgpts_interval,
)
def before_start(self):
"""Execute before the application starts"""
super().before_start()
self._pre_load_dag_from_db()
self._pre_load_dag_from_dbgpts()
def after_start(self):
"""Execute after the application starts"""
self.load_dag_from_db()
self.load_dag_from_dbgpts(is_first_load=True)
schedule.every(self._serve_config.load_dbgpts_interval).seconds.do(
self.load_dag_from_dbgpts
)
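    # Note: the `schedule` library only queues the job above; something must call
    # `schedule.run_pending()` periodically for `load_dag_from_dbgpts` to re-run.
    # A minimal sketch of such a runner (illustrative only; the actual scheduler is
    # expected to live elsewhere in DB-GPT and is not part of this file):
    #
    #     import threading
    #     import time
    #
    #     def _run_scheduler():
    #         while True:
    #             schedule.run_pending()
    #             time.sleep(1)
    #
    #     threading.Thread(target=_run_scheduler, daemon=True).start()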
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def dbgpts_loader(self) -> DBGPTsLoader:
"""Returns the internal DBGPTsLoader."""
if self._dbgpts_loader is None:
raise ValueError("DBGPTsLoader is not initialized")
return self._dbgpts_loader
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: ServeRequest) -> ServerResponse:
"""Create a new Flow entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
def create_and_save_dag(
self, request: ServeRequest, save_failed_flow: bool = False
) -> ServerResponse:
"""Create a new Flow entity and save the DAG
Args:
request (ServeRequest): The request
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
try:
# Build DAG from request
if request.define_type == "json":
dag = self._flow_factory.build(request)
else:
dag = request.flow_dag
request.dag_id = dag.dag_id
# Save DAG to storage
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
request.dag_id = ""
return self.dao.create(request)
else:
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
self.dao.create(request)
# Query from database
res = self.get({"uid": request.uid})
state = request.state
try:
if state == State.DEPLOYED:
# Register the DAG
self.dag_manager.register_dag(dag, request.uid)
# Update state to RUNNING
request.state = State.RUNNING
request.error_message = ""
self.dao.update({"uid": request.uid}, request)
else:
logger.info(f"Flow state is {state}, skip register DAG")
except Exception as e:
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = f"Register DAG error: {str(e)}"
request.dag_id = ""
self.dao.update({"uid": request.uid}, request)
else:
# Rollback
self.delete(request.uid)
raise e
return res
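    # Lifecycle implemented above: build the DAG (for "json" definitions), persist
    # the flow, then register the DAG with the DAG manager when the requested state
    # is DEPLOYED, finally flipping the stored state to RUNNING. A hedged usage
    # sketch (field values are placeholders, not a complete ServeRequest):
    #
    #     # req = ServeRequest(name="my_flow", define_type="json", state=State.DEPLOYED, ...)
    #     # resp = service.create_and_save_dag(req, save_failed_flow=True)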
def _pre_load_dag_from_db(self):
"""Pre load DAG from db"""
entities = self.dao.get_list({})
for entity in entities:
try:
self._flow_factory.pre_load_requirements(entity)
except Exception as e:
logger.warning(
f"Pre load requirements for DAG({entity.name}, {entity.dag_id}) "
f"from db error: {str(e)}"
)
def load_dag_from_db(self):
"""Load DAG from db"""
entities = self.dao.get_list({})
for entity in entities:
try:
if entity.define_type != "json":
continue
dag = self._flow_factory.build(entity)
if entity.state in [State.DEPLOYED, State.RUNNING] or (
entity.version == "0.1.0" and entity.state == State.INITIALIZING
):
# Register the DAG
self.dag_manager.register_dag(dag, entity.uid)
# Update state to RUNNING
entity.state = State.RUNNING
entity.error_message = ""
self.dao.update({"uid": entity.uid}, entity)
except Exception as e:
logger.warning(
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
)
def _pre_load_dag_from_dbgpts(self):
"""Pre load DAG from dbgpts"""
flows = self.dbgpts_loader.get_flows()
for flow in flows:
try:
if flow.define_type == "json":
self._flow_factory.pre_load_requirements(flow)
except Exception as e:
logger.warning(
f"Pre load requirements for DAG({flow.name}) from "
f"dbgpts error: {str(e)}"
)
def load_dag_from_dbgpts(self, is_first_load: bool = False):
"""Load DAG from dbgpts"""
flows = self.dbgpts_loader.get_flows()
for flow in flows:
try:
if flow.define_type == "python" and flow.flow_dag is None:
continue
# Set state to DEPLOYED
flow.state = State.DEPLOYED
exist_inst = self.dao.get_one({"name": flow.name})
if not exist_inst:
self.create_and_save_dag(flow, save_failed_flow=True)
elif is_first_load or exist_inst.state != State.RUNNING:
# TODO check version, must be greater than the exist one
flow.uid = exist_inst.uid
self.update_flow(flow, check_editable=False, save_failed_flow=True)
except Exception as e:
import traceback
message = traceback.format_exc()
logger.warning(
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
)
def update_flow(
self,
request: ServeRequest,
check_editable: bool = True,
save_failed_flow: bool = False,
) -> ServerResponse:
"""Update a Flow entity
Args:
request (ServeRequest): The request
check_editable (bool): Whether to check the editable
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
new_state = State.DEPLOYED
try:
# Try to build the dag from the request
if request.define_type == "json":
dag = self._flow_factory.build(request)
else:
dag = request.flow_dag
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
request.dag_id = ""
return self.dao.update({"uid": request.uid}, request)
else:
raise e
# Build the query request from the request
query_request = {"uid": request.uid}
inst = self.get(query_request)
if not inst:
raise HTTPException(status_code=404, detail=f"Flow {request.uid} not found")
if check_editable and not inst.editable:
raise HTTPException(
status_code=403, detail=f"Flow {request.uid} is not editable"
)
old_state = inst.state
if not State.can_change_state(old_state, new_state):
raise HTTPException(
status_code=400,
detail=f"Flow {request.uid} state can't change from {old_state} to "
f"{new_state}",
)
old_data: Optional[ServerResponse] = None
try:
update_obj = self.dao.update(query_request, update_request=request)
old_data = self.delete(request.uid)
            if not old_data:
                raise HTTPException(
                    status_code=404, detail=f"Flow detail {request.uid} not found"
                )
            old_data.state = old_state
update_obj.flow_dag = request.flow_dag
return self.create_and_save_dag(update_obj)
except Exception as e:
if old_data and old_data.state == State.RUNNING:
# Old flow is running, try to recover it
# first set the state to DEPLOYED
old_data.state = State.DEPLOYED
self.create_and_save_dag(old_data)
raise e
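    # Update strategy used above: persist the new definition, delete the old flow
    # (which also unregisters its DAG), then re-create and re-register it via
    # `create_and_save_dag`. If re-creation fails and the old flow was RUNNING,
    # the old definition is re-deployed as a best-effort rollback.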
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]:
"""Get a Flow entity
Args:
            request (QUERY_SPEC): The query request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
flow = self.dao.get_one(query_request)
if flow:
fill_flow_panel(flow)
metadata = self.dag_manager.get_dag_metadata(
flow.dag_id, alias_name=flow.uid
)
if metadata:
flow.metadata = metadata.to_dict()
return flow
def delete(self, uid: str) -> Optional[ServerResponse]:
"""Delete a Flow entity
Args:
uid (str): The uid
Returns:
ServerResponse: The data after deletion
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {"uid": uid}
inst = self.get(query_request)
if inst is None:
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
if inst.state == State.RUNNING and not inst.dag_id:
raise HTTPException(
status_code=404, detail=f"Running flow {uid}'s dag id not found"
)
try:
if inst.dag_id:
self.dag_manager.unregister_dag(inst.dag_id)
except Exception as e:
logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}")
self.dao.delete(query_request)
return inst
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Flow entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of Flow entities by page
Args:
            request (QUERY_SPEC): The query request
            page (int): The page number
            page_size (int): The page size
        Returns:
            PaginationResult[ServerResponse]: The paginated response
"""
page_result = self.dao.get_list_page(
request, page, page_size, desc_order_column=ServeEntity.gmt_modified.name
)
for item in page_result.items:
metadata = self.dag_manager.get_dag_metadata(
item.dag_id, alias_name=item.uid
)
if metadata:
item.metadata = metadata.to_dict()
return page_result
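    # Hedged usage sketch (the filter key is a placeholder; any QUERY_SPEC the DAO
    # accepts will work):
    #
    #     # page_result = service.get_list_by_page({"user_name": "admin"}, page=1, page_size=20)
    #     # for flow in page_result.items:
    #     #     print(flow.uid, flow.state)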
def get_flow_templates(
self,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
page: int = 1,
page_size: int = 20,
) -> PaginationResult[ServerResponse]:
"""Get a list of Flow templates
Args:
user_name (Optional[str]): The user name
sys_code (Optional[str]): The system code
page (int): The page number
page_size (int): The page size
Returns:
            PaginationResult[ServerResponse]: The paginated response
"""
local_file_templates = self._get_flow_templates_from_files()
return PaginationResult.build_from_all(local_file_templates, page, page_size)
def _get_flow_templates_from_files(self) -> List[ServerResponse]:
"""Get a list of Flow templates from files"""
user_lang = self._system_app.config.get_current_lang(default="en")
# List files in current directory
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
template_dir = os.path.join(parent_dir, "templates", user_lang)
default_template_dir = os.path.join(parent_dir, "templates", "en")
if not os.path.exists(template_dir):
template_dir = default_template_dir
templates = []
for root, _, files in os.walk(template_dir):
for file in files:
if file.endswith(".json"):
try:
with open(os.path.join(root, file), "r") as f:
data = json.load(f)
templates.append(_parse_flow_template_from_json(data))
except Exception as e:
logger.warning(f"Load template {file} error: {str(e)}")
return templates
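    # Templates are loaded from `<serve.flow package>/templates/<lang>/*.json`,
    # falling back to the English directory when the current language has no
    # template folder. Each file is converted to a ServerResponse by
    # `_parse_flow_template_from_json` at the bottom of this module.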
async def chat_stream_flow_str(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> AsyncIterator[str]:
"""Stream chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
"""
# Must be non-incremental
request.incremental = False
async for output in self.safe_chat_stream_flow(flow_uid, request):
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"
    async def chat_stream_openai(
        self, flow_uid: str, request: CommonLLMHttpRequestBody
    ) -> AsyncIterator[str]:
        """Stream chat with the AWEL flow as OpenAI-compatible SSE chunks."""
conv_uid = request.conv_uid
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=conv_uid, choices=[choice_data], model=request.model
)
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
request.incremental = True
async for output in self.safe_chat_stream_flow(flow_uid, request):
if not output.success:
yield f"data: {json.dumps(output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=output.text),
)
chunk = ChatCompletionStreamResponse(
id=conv_uid,
choices=[choice_data],
model=request.model,
)
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
yield "data: [DONE]\n\n"
async def safe_chat_flow(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> ModelOutput:
"""Chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
Returns:
ModelOutput: The output
"""
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
return await safe_chat_with_dag_task(task, request)
except HTTPException as e:
return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
return ModelOutput(error_code=1, text=str(e), incremental=incremental)
async def safe_chat_stream_flow(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> AsyncIterator[ModelOutput]:
"""Stream chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
Returns:
AsyncIterator[ModelOutput]: The output
"""
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
async for output in safe_chat_stream_with_dag_task(
task, request, incremental
):
yield output
except HTTPException as e:
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
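    # Hedged consumption sketch (placeholders for `service` and `body`):
    #
    #     # async for output in service.safe_chat_stream_flow(flow_uid, body):
    #     #     if output.error_code != 0:
    #     #         break
    #     #     print(output.text)
    #
    # Errors are yielded as ModelOutput(error_code=1, ...) instead of being raised,
    # so callers can forward them to the client as part of the stream.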
async def _get_callable_task(
self,
flow_uid: str,
) -> BaseOperator:
"""Return the callable task.
Returns:
BaseOperator: The callable task
Raises:
HTTPException: If the flow is not found
ValueError: If the flow is not a chat flow or the leaf node is not found.
"""
flow = self.get({"uid": flow_uid})
if not flow:
raise HTTPException(status_code=404, detail=f"Flow {flow_uid} not found")
dag_id = flow.dag_id
if not dag_id or dag_id not in self.dag_manager.dag_map:
raise HTTPException(
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
)
dag = self.dag_manager.dag_map[dag_id]
# if (
# flow.flow_category != FlowCategory.CHAT_FLOW
# and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
# ):
# raise ValueError(f"Flow {flow_uid} is not a chat flow")
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
return cast(BaseOperator, leaf_nodes[0])
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
"""Parse the flow category
Args:
            dag (DAG): The DAG to parse
Returns:
FlowCategory: The flow category
"""
from dbgpt.core.awel.flow.base import _get_type_cls
triggers = dag.trigger_nodes
leaf_nodes = dag.leaf_nodes
if (
not triggers
or not leaf_nodes
or len(leaf_nodes) > 1
or not isinstance(leaf_nodes[0], BaseOperator)
):
return FlowCategory.COMMON
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs:
return FlowCategory.COMMON
common_http_trigger = False
agent_trigger = False
for trigger in triggers:
if isinstance(trigger, CommonLLMHttpTrigger):
common_http_trigger = True
break
if isinstance(trigger, AgentDummyTrigger):
agent_trigger = True
break
output = leaf_node.metadata.outputs[0]
try:
real_class = _get_type_cls(output.type_cls)
if agent_trigger:
return FlowCategory.CHAT_AGENT
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW
        except Exception:
            return FlowCategory.COMMON
        return FlowCategory.COMMON
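    # Classification summary: an AgentDummyTrigger yields CHAT_AGENT; a
    # CommonLLMHttpTrigger whose single leaf output resolves to a chat flow type
    # yields CHAT_FLOW; anything else (no trigger, multiple leaves, or an
    # unresolvable output type) falls back to COMMON.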
async def debug_flow(
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
"""Debug the flow.
Args:
request (FlowDebugRequest): The request
default_incremental (Optional[bool]): The default incremental configuration
Returns:
AsyncIterator[ModelOutput]: The output
"""
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
dag = await blocking_func_to_async(
self._system_app,
self._flow_factory.build,
request.flow,
)
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
task = cast(BaseOperator, leaf_nodes[0])
dag_metadata = _parse_metadata(dag)
# TODO: Run task with variables
variables = request.variables
dag_request = request.request
if isinstance(request.request, CommonLLMHttpRequestBody):
incremental = request.request.incremental
elif isinstance(request.request, dict):
incremental = request.request.get("incremental", False)
else:
raise ValueError("Invalid request type")
if default_incremental is not None:
incremental = default_incremental
try:
async for output in safe_chat_stream_with_dag_task(
task, dag_request, incremental
):
yield output
except HTTPException as e:
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
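    # Hedged debugging sketch (only fields read above are shown; concrete values
    # are placeholders, and `FlowDebugRequest` may require more fields):
    #
    #     # debug_req = FlowDebugRequest(flow=my_flow_panel, request={"incremental": True})
    #     # async for out in service.debug_flow(debug_req):
    #     #     print(out.text)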
    async def _wrapper_chat_stream_flow_str(
        self, stream_iter: AsyncIterator[ModelOutput]
    ) -> AsyncIterator[str]:
        """Wrap a ModelOutput stream into plain-text SSE frames."""
async for output in stream_iter:
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"
def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse:
"""Parse the flow from json
Args:
json_dict (dict): The json dict
Returns:
ServerResponse: The flow
"""
flow_json = json_dict["flow"]
flow_json["editable"] = False
del flow_json["uid"]
flow_json["state"] = State.INITIALIZING
flow_json["dag_id"] = None
return ServerResponse(**flow_json)
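# Minimal, illustrative shape of a template JSON file handled above (only the keys
# this parser touches are shown; real templates carry the full flow definition):
#
#     {
#         "flow": {
#             "uid": "<dropped by the parser>",
#             "...": "remaining ServerResponse fields"
#         }
#     }
#
# The parser forces `editable=False`, resets `state` to INITIALIZING, and clears
# `dag_id` so the template can be instantiated as a brand-new flow.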