Files
DB-GPT/dbgpt/serve/flow/service/service.py
2024-05-30 18:51:57 +08:00

776 lines
27 KiB
Python

import json
import logging
import traceback
from typing import Any, AsyncIterator, List, Optional, cast
import schedule
from fastapi import HTTPException
from dbgpt._private.pydantic import model_to_json
from dbgpt.component import SystemApp
from dbgpt.core.awel import (
DAG,
BaseOperator,
CommonLLMHttpRequestBody,
CommonLLMHttpResponseBody,
)
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
State,
fill_flow_panel,
)
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.schema.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for Flow"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
self._dag_manager: Optional[DAGManager] = None
self._flow_factory: FlowFactory = FlowFactory()
self._dbgpts_loader: Optional[DBGPTsLoader] = None
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
self._dbgpts_loader = system_app.get_component(
DBGPTsLoader.name,
DBGPTsLoader,
or_register_component=DBGPTsLoader,
load_dbgpts_interval=self._serve_config.load_dbgpts_interval,
)
def before_start(self):
"""Execute before the application starts"""
self._dag_manager = DAGManager.get_instance(self._system_app)
self._pre_load_dag_from_db()
self._pre_load_dag_from_dbgpts()
def after_start(self):
"""Execute after the application starts"""
self.load_dag_from_db()
self.load_dag_from_dbgpts(is_first_load=True)
schedule.every(self._serve_config.load_dbgpts_interval).seconds.do(
self.load_dag_from_dbgpts
)
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def dag_manager(self) -> DAGManager:
"""Returns the internal DAGManager."""
if self._dag_manager is None:
raise ValueError("DAGManager is not initialized")
return self._dag_manager
@property
def dbgpts_loader(self) -> DBGPTsLoader:
"""Returns the internal DBGPTsLoader."""
if self._dbgpts_loader is None:
raise ValueError("DBGPTsLoader is not initialized")
return self._dbgpts_loader
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: ServeRequest) -> ServerResponse:
"""Create a new Flow entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
def create_and_save_dag(
self, request: ServeRequest, save_failed_flow: bool = False
) -> ServerResponse:
"""Create a new Flow entity and save the DAG
Args:
request (ServeRequest): The request
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
try:
# Build DAG from request
if request.define_type == "json":
dag = self._flow_factory.build(request)
else:
dag = request.flow_dag
request.dag_id = dag.dag_id
# Save DAG to storage
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
request.dag_id = ""
return self.dao.create(request)
else:
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
res = self.dao.create(request)
state = request.state
try:
if state == State.DEPLOYED:
# Register the DAG
self.dag_manager.register_dag(dag, request.uid)
# Update state to RUNNING
request.state = State.RUNNING
request.error_message = ""
self.dao.update({"uid": request.uid}, request)
else:
logger.info(f"Flow state is {state}, skip register DAG")
except Exception as e:
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = f"Register DAG error: {str(e)}"
request.dag_id = ""
self.dao.update({"uid": request.uid}, request)
else:
# Rollback
self.delete(request.uid)
raise e
return res
def _pre_load_dag_from_db(self):
"""Pre load DAG from db"""
entities = self.dao.get_list({})
for entity in entities:
try:
self._flow_factory.pre_load_requirements(entity)
except Exception as e:
logger.warning(
f"Pre load requirements for DAG({entity.name}, {entity.dag_id}) "
f"from db error: {str(e)}"
)
def load_dag_from_db(self):
"""Load DAG from db"""
entities = self.dao.get_list({})
for entity in entities:
try:
if entity.define_type != "json":
continue
dag = self._flow_factory.build(entity)
if entity.state in [State.DEPLOYED, State.RUNNING] or (
entity.version == "0.1.0" and entity.state == State.INITIALIZING
):
# Register the DAG
self.dag_manager.register_dag(dag, entity.uid)
# Update state to RUNNING
entity.state = State.RUNNING
entity.error_message = ""
self.dao.update({"uid": entity.uid}, entity)
except Exception as e:
logger.warning(
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
)
def _pre_load_dag_from_dbgpts(self):
"""Pre load DAG from dbgpts"""
flows = self.dbgpts_loader.get_flows()
for flow in flows:
try:
if flow.define_type == "json":
self._flow_factory.pre_load_requirements(flow)
except Exception as e:
logger.warning(
f"Pre load requirements for DAG({flow.name}) from "
f"dbgpts error: {str(e)}"
)
def load_dag_from_dbgpts(self, is_first_load: bool = False):
"""Load DAG from dbgpts"""
flows = self.dbgpts_loader.get_flows()
for flow in flows:
try:
if flow.define_type == "python" and flow.flow_dag is None:
continue
# Set state to DEPLOYED
flow.state = State.DEPLOYED
exist_inst = self.get({"name": flow.name})
if not exist_inst:
self.create_and_save_dag(flow, save_failed_flow=True)
elif is_first_load or exist_inst.state != State.RUNNING:
# TODO check version, must be greater than the exist one
flow.uid = exist_inst.uid
self.update_flow(flow, check_editable=False, save_failed_flow=True)
except Exception as e:
import traceback
message = traceback.format_exc()
logger.warning(
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
)
def update_flow(
self,
request: ServeRequest,
check_editable: bool = True,
save_failed_flow: bool = False,
) -> ServerResponse:
"""Update a Flow entity
Args:
request (ServeRequest): The request
check_editable (bool): Whether to check the editable
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
new_state = request.state
try:
# Try to build the dag from the request
if request.define_type == "json":
dag = self._flow_factory.build(request)
else:
dag = request.flow_dag
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
request.dag_id = ""
return self.dao.update({"uid": request.uid}, request)
else:
raise e
# Build the query request from the request
query_request = {"uid": request.uid}
inst = self.get(query_request)
if not inst:
raise HTTPException(status_code=404, detail=f"Flow {request.uid} not found")
if check_editable and not inst.editable:
raise HTTPException(
status_code=403, detail=f"Flow {request.uid} is not editable"
)
old_state = inst.state
if not State.can_change_state(old_state, new_state):
raise HTTPException(
status_code=400,
detail=f"Flow {request.uid} state can't change from {old_state} to "
f"{new_state}",
)
old_data: Optional[ServerResponse] = None
try:
update_obj = self.dao.update(query_request, update_request=request)
old_data = self.delete(request.uid)
old_data.state = old_state
if not old_data:
raise HTTPException(
status_code=404, detail=f"Flow detail {request.uid} not found"
)
update_obj.flow_dag = request.flow_dag
return self.create_and_save_dag(update_obj)
except Exception as e:
if old_data and old_data.state == State.RUNNING:
# Old flow is running, try to recover it
# first set the state to DEPLOYED
old_data.state = State.DEPLOYED
self.create_and_save_dag(old_data)
raise e
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]:
"""Get a Flow entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
flow = self.dao.get_one(query_request)
if flow:
fill_flow_panel(flow)
return flow
def delete(self, uid: str) -> Optional[ServerResponse]:
"""Delete a Flow entity
Args:
uid (str): The uid
Returns:
ServerResponse: The data after deletion
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {"uid": uid}
inst = self.get(query_request)
if inst is None:
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
if inst.state == State.RUNNING and not inst.dag_id:
raise HTTPException(
status_code=404, detail=f"Running flow {uid}'s dag id not found"
)
try:
if inst.dag_id:
self.dag_manager.unregister_dag(inst.dag_id)
except Exception as e:
logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}")
self.dao.delete(query_request)
return inst
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Flow entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of Flow entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
return self.dao.get_list_page(request, page, page_size)
async def chat_stream_flow_str(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> AsyncIterator[str]:
"""Stream chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
"""
# Must be non-incremental
request.incremental = False
async for output in self.safe_chat_stream_flow(flow_uid, request):
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"
async def chat_stream_openai(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> AsyncIterator[str]:
conv_uid = request.conv_uid
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=conv_uid, choices=[choice_data], model=request.model
)
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
request.incremental = True
async for output in self.safe_chat_stream_flow(flow_uid, request):
if not output.success:
yield f"data: {json.dumps(output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=output.text),
)
chunk = ChatCompletionStreamResponse(
id=conv_uid,
choices=[choice_data],
model=request.model,
)
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
yield "data: [DONE]\n\n"
async def safe_chat_flow(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> ModelOutput:
"""Chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
Returns:
ModelOutput: The output
"""
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
return await _safe_chat_with_dag_task(task, request)
except HTTPException as e:
return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
return ModelOutput(error_code=1, text=str(e), incremental=incremental)
async def safe_chat_stream_flow(
self, flow_uid: str, request: CommonLLMHttpRequestBody
) -> AsyncIterator[ModelOutput]:
"""Stream chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
Returns:
AsyncIterator[ModelOutput]: The output
"""
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
async for output in _safe_chat_stream_with_dag_task(
task, request, incremental
):
yield output
except HTTPException as e:
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
async def _get_callable_task(
self,
flow_uid: str,
) -> BaseOperator:
"""Return the callable task.
Returns:
BaseOperator: The callable task
Raises:
HTTPException: If the flow is not found
ValueError: If the flow is not a chat flow or the leaf node is not found.
"""
flow = self.get({"uid": flow_uid})
if not flow:
raise HTTPException(status_code=404, detail=f"Flow {flow_uid} not found")
dag_id = flow.dag_id
if not dag_id or dag_id not in self.dag_manager.dag_map:
raise HTTPException(
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
)
dag = self.dag_manager.dag_map[dag_id]
if (
flow.flow_category != FlowCategory.CHAT_FLOW
and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
):
raise ValueError(f"Flow {flow_uid} is not a chat flow")
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
return cast(BaseOperator, leaf_nodes[0])
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
"""Parse the flow category
Args:
flow_category (str): The flow category
Returns:
FlowCategory: The flow category
"""
from dbgpt.core.awel.flow.base import _get_type_cls
triggers = dag.trigger_nodes
leaf_nodes = dag.leaf_nodes
if (
not triggers
or not leaf_nodes
or len(leaf_nodes) > 1
or not isinstance(leaf_nodes[0], BaseOperator)
):
return FlowCategory.COMMON
common_http_trigger = False
for trigger in triggers:
if isinstance(trigger, CommonLLMHttpTrigger):
common_http_trigger = True
break
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs:
return FlowCategory.COMMON
output = leaf_node.metadata.outputs[0]
try:
real_class = _get_type_cls(output.type_cls)
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
if is_class:
return (
output_obj == str
or output_obj == CommonLLMHttpResponseBody
or output_obj == ModelOutput
)
else:
chat_types = (str, CommonLLMHttpResponseBody)
return isinstance(output_obj, chat_types)
async def _safe_chat_with_dag_task(task: BaseOperator, request: Any) -> ModelOutput:
"""Chat with the DAG task."""
try:
finish_reason = None
usage = None
metrics = None
error_code = 0
text = ""
async for output in _safe_chat_stream_with_dag_task(task, request, False):
finish_reason = output.finish_reason
usage = output.usage
metrics = output.metrics
error_code = output.error_code
text = output.text
return ModelOutput(
error_code=error_code,
text=text,
metrics=metrics,
usage=usage,
finish_reason=finish_reason,
)
except Exception as e:
return ModelOutput(error_code=1, text=str(e), incremental=False)
async def _safe_chat_stream_with_dag_task(
task: BaseOperator,
request: Any,
incremental: bool,
) -> AsyncIterator[ModelOutput]:
"""Chat with the DAG task."""
try:
async for output in _chat_stream_with_dag_task(task, request, incremental):
yield output
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
finally:
if task.streaming_operator:
if task.dag:
await task.dag._after_dag_end(task.current_event_loop_task_id)
async def _chat_stream_with_dag_task(
task: BaseOperator,
request: Any,
incremental: bool,
) -> AsyncIterator[ModelOutput]:
"""Chat with the DAG task."""
is_sse = task.output_format and task.output_format.upper() == "SSE"
if not task.streaming_operator:
try:
result = await task.call(request)
model_output = _parse_single_output(result, is_sse)
model_output.incremental = incremental
yield model_output
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
else:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
if OpenAIStreamingOutputOperator and isinstance(
task, OpenAIStreamingOutputOperator
):
full_text = ""
async for output in await task.call_stream(request):
model_output = _parse_openai_output(output)
# The output of the OpenAI streaming API is incremental
full_text += model_output.text
model_output.incremental = incremental
model_output.text = model_output.text if incremental else full_text
yield model_output
if not model_output.success:
break
else:
full_text = ""
previous_text = ""
async for output in await task.call_stream(request):
model_output = _parse_single_output(output, is_sse)
model_output.incremental = incremental
if task.incremental_output:
# Output is incremental, append the text
full_text += model_output.text
else:
# Output is not incremental, last output is the full text
full_text = model_output.text
if not incremental:
# Return the full text
model_output.text = full_text
else:
# Return the incremental text
delta_text = full_text[len(previous_text) :]
previous_text = (
full_text
if len(full_text) > len(previous_text)
else previous_text
)
model_output.text = delta_text
yield model_output
if not model_output.success:
break
def _parse_single_output(output: Any, is_sse: bool) -> ModelOutput:
"""Parse the single output."""
finish_reason = None
usage = None
metrics = None
if output is None:
error_code = 1
text = "The output is None!"
elif isinstance(output, str):
if is_sse:
sse_output = _parse_sse_data(output)
if sse_output is None:
error_code = 1
text = "The output is not a SSE format"
else:
error_code = 0
text = sse_output
else:
error_code = 0
text = output
elif isinstance(output, ModelOutput):
error_code = output.error_code
text = output.text
finish_reason = output.finish_reason
usage = output.usage
metrics = output.metrics
elif isinstance(output, CommonLLMHttpResponseBody):
error_code = output.error_code
text = output.text
elif isinstance(output, dict):
error_code = 0
text = json.dumps(output, ensure_ascii=False)
else:
error_code = 1
text = f"The output is not a valid format({type(output)})"
return ModelOutput(
error_code=error_code,
text=text,
finish_reason=finish_reason,
usage=usage,
metrics=metrics,
)
def _parse_openai_output(output: Any) -> ModelOutput:
"""Parse the OpenAI output."""
text = ""
if not isinstance(output, str):
return ModelOutput(
error_code=1,
text="The output is not a stream format",
)
if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
return ModelOutput(error_code=0, text="")
if not output.startswith("data:"):
return ModelOutput(
error_code=1,
text="The output is not a stream format",
)
sse_output = _parse_sse_data(output)
if sse_output is None:
return ModelOutput(error_code=1, text="The output is not a SSE format")
json_data = sse_output.strip()
try:
dict_data = json.loads(json_data)
except Exception as e:
return ModelOutput(
error_code=1,
text=f"Invalid JSON data: {json_data}, {e}",
)
if "choices" not in dict_data:
return ModelOutput(
error_code=1,
text=dict_data.get("text", "Unknown error"),
)
choices = dict_data["choices"]
finish_reason: Optional[str] = None
if choices:
choice = choices[0]
delta_data = ChatCompletionResponseStreamChoice(**choice)
if delta_data.delta.content:
text = delta_data.delta.content
finish_reason = delta_data.finish_reason
return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
def _parse_sse_data(output: str) -> Optional[str]:
if output.startswith("data:"):
if output.startswith("data: "):
output = output[6:]
else:
output = output[5:]
return output
else:
return None