feat(core): Multiple ways to run dbgpts (#1734)
@@ -1,19 +1,13 @@
 import json
 import logging
 import traceback
-from typing import Any, AsyncIterator, List, Optional, cast
+from typing import AsyncIterator, List, Optional, cast

 import schedule
 from fastapi import HTTPException

 from dbgpt._private.pydantic import model_to_json
 from dbgpt.component import SystemApp
-from dbgpt.core.awel import (
-    DAG,
-    BaseOperator,
-    CommonLLMHttpRequestBody,
-    CommonLLMHttpResponseBody,
-)
+from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
 from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import (
     FlowCategory,
@@ -22,10 +16,13 @@ from dbgpt.core.awel.flow.flow_factory import (
     fill_flow_panel,
 )
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
+from dbgpt.core.awel.util.chat_util import (
+    is_chat_flow_type,
+    safe_chat_stream_with_dag_task,
+    safe_chat_with_dag_task,
+)
 from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.core.schema.api import (
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
     DeltaMessage,
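
Note: the private chat helpers this file used to define at module level (`_safe_chat_with_dag_task`, `_safe_chat_stream_with_dag_task`, `_is_chat_flow_type`) are evidently relocated to `dbgpt.core.awel.util.chat_util` as public functions, so the other ways to run dbgpts added in this PR can share them. A minimal usage sketch of the public API, assuming `task` is an AWEL `BaseOperator` and `request` a prepared `CommonLLMHttpRequestBody` obtained elsewhere (both hypothetical here):

    from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody
    from dbgpt.core.awel.util.chat_util import (
        safe_chat_stream_with_dag_task,
        safe_chat_with_dag_task,
    )

    async def demo(task: BaseOperator, request: CommonLLMHttpRequestBody) -> None:
        # One-shot call: errors come back inside the ModelOutput, not as exceptions.
        output = await safe_chat_with_dag_task(task, request)
        print(output.text)
        # Streaming call: incremental=True requests delta chunks instead of full text.
        async for chunk in safe_chat_stream_with_dag_task(task, request, True):
            print(chunk.text, end="")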
@@ -333,6 +330,11 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         flow = self.dao.get_one(query_request)
         if flow:
             fill_flow_panel(flow)
+            metadata = self.dag_manager.get_dag_metadata(
+                flow.dag_id, alias_name=flow.uid
+            )
+            if metadata:
+                flow.metadata = metadata.to_dict()
         return flow

     def delete(self, uid: str) -> Optional[ServerResponse]:
@@ -390,7 +392,14 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         Returns:
             List[ServerResponse]: The response
         """
-        return self.dao.get_list_page(request, page, page_size)
+        page_result = self.dao.get_list_page(request, page, page_size)
+        for item in page_result.items:
+            metadata = self.dag_manager.get_dag_metadata(
+                item.dag_id, alias_name=item.uid
+            )
+            if metadata:
+                item.metadata = metadata.to_dict()
+        return page_result

     async def chat_stream_flow_str(
         self, flow_uid: str, request: CommonLLMHttpRequestBody
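
Note: both `get` (previous hunk) and `get_list_page` now enrich their responses with live DAG metadata from the `DAGManager`. A hypothetical caller-side view (`service` and `request` are illustrative names, not part of this diff):

    page_result = service.get_list_page(request, page=1, page_size=20)
    for item in page_result.items:
        # metadata is only populated while the flow's DAG is registered
        if item.metadata:
            print(item.uid, item.metadata)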
@@ -463,7 +472,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         incremental = request.incremental
         try:
             task = await self._get_callable_task(flow_uid)
-            return await _safe_chat_with_dag_task(task, request)
+            return await safe_chat_with_dag_task(task, request)
         except HTTPException as e:
             return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
         except Exception as e:
@@ -484,7 +493,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         incremental = request.incremental
         try:
             task = await self._get_callable_task(flow_uid)
-            async for output in _safe_chat_stream_with_dag_task(
+            async for output in safe_chat_stream_with_dag_task(
                 task, request, incremental
             ):
                 yield output
@@ -556,220 +565,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         output = leaf_node.metadata.outputs[0]
         try:
             real_class = _get_type_cls(output.type_cls)
-            if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
+            if common_http_trigger and is_chat_flow_type(real_class, is_class=True):
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
-
-
-def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
-    if is_class:
-        return (
-            output_obj == str
-            or output_obj == CommonLLMHttpResponseBody
-            or output_obj == ModelOutput
-        )
-    else:
-        chat_types = (str, CommonLLMHttpResponseBody)
-        return isinstance(output_obj, chat_types)
-
-
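Note: `_is_chat_flow_type` is deleted here and survives as the public `is_chat_flow_type` in `chat_util`. A behavior sketch, using only values the function above already names:

    assert _is_chat_flow_type(str, is_class=True)
    assert _is_chat_flow_type(ModelOutput, is_class=True)
    assert _is_chat_flow_type("hello")   # instance path: str is a chat type
    assert not _is_chat_flow_type(42)    # int is not
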
-async def _safe_chat_with_dag_task(task: BaseOperator, request: Any) -> ModelOutput:
-    """Chat with the DAG task."""
-    try:
-        finish_reason = None
-        usage = None
-        metrics = None
-        error_code = 0
-        text = ""
-        async for output in _safe_chat_stream_with_dag_task(task, request, False):
-            finish_reason = output.finish_reason
-            usage = output.usage
-            metrics = output.metrics
-            error_code = output.error_code
-            text = output.text
-        return ModelOutput(
-            error_code=error_code,
-            text=text,
-            metrics=metrics,
-            usage=usage,
-            finish_reason=finish_reason,
-        )
-    except Exception as e:
-        return ModelOutput(error_code=1, text=str(e), incremental=False)
-
-
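Note: this wrapper drains the stream with `incremental=False`, so every chunk carries the full accumulated text and keeping only the last chunk's fields yields the final answer. A consumption sketch (hypothetical `task` and `request`):

    async def run_once(task, request):
        final = await _safe_chat_with_dag_task(task, request)
        # error_code=1 plus the exception text signals failure; nothing is raised.
        return final.text if final.error_code == 0 else None
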
-async def _safe_chat_stream_with_dag_task(
-    task: BaseOperator,
-    request: Any,
-    incremental: bool,
-) -> AsyncIterator[ModelOutput]:
-    """Chat with the DAG task."""
-    try:
-        async for output in _chat_stream_with_dag_task(task, request, incremental):
-            yield output
-    except Exception as e:
-        yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
-    finally:
-        if task.streaming_operator:
-            if task.dag:
-                await task.dag._after_dag_end(task.current_event_loop_task_id)
-
-
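Note: the `finally` block is the point of this wrapper: for a streaming operator, the DAG's end-of-run hook must fire even when the consumer stops early or the underlying stream raises. Sketch of a consumer relying on that guarantee (hypothetical names):

    async def consume(task, request):
        async for out in _safe_chat_stream_with_dag_task(task, request, True):
            if out.error_code != 0:
                break  # errors arrive as ModelOutput chunks, not exceptions
        # by here the wrapper has awaited task.dag._after_dag_end(...)
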
-async def _chat_stream_with_dag_task(
-    task: BaseOperator,
-    request: Any,
-    incremental: bool,
-) -> AsyncIterator[ModelOutput]:
-    """Chat with the DAG task."""
-    is_sse = task.output_format and task.output_format.upper() == "SSE"
-    if not task.streaming_operator:
-        try:
-            result = await task.call(request)
-            model_output = _parse_single_output(result, is_sse)
-            model_output.incremental = incremental
-            yield model_output
-        except Exception as e:
-            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
-    else:
-        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
-
-        if OpenAIStreamingOutputOperator and isinstance(
-            task, OpenAIStreamingOutputOperator
-        ):
-            full_text = ""
-            async for output in await task.call_stream(request):
-                model_output = _parse_openai_output(output)
-                # The output of the OpenAI streaming API is incremental
-                full_text += model_output.text
-                model_output.incremental = incremental
-                model_output.text = model_output.text if incremental else full_text
-                yield model_output
-                if not model_output.success:
-                    break
-        else:
-            full_text = ""
-            previous_text = ""
-            async for output in await task.call_stream(request):
-                model_output = _parse_single_output(output, is_sse)
-                model_output.incremental = incremental
-                if task.incremental_output:
-                    # Output is incremental, append the text
-                    full_text += model_output.text
-                else:
-                    # Output is not incremental, last output is the full text
-                    full_text = model_output.text
-                if not incremental:
-                    # Return the full text
-                    model_output.text = full_text
-                else:
-                    # Return the incremental text
-                    delta_text = full_text[len(previous_text) :]
-                    previous_text = (
-                        full_text
-                        if len(full_text) > len(previous_text)
-                        else previous_text
-                    )
-                    model_output.text = delta_text
-                yield model_output
-                if not model_output.success:
-                    break
-
-
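Note: the delta bookkeeping in the non-OpenAI branch is worth tracing. With `incremental=True` over an upstream that emits the full text each time, only the newly appended suffix is yielded:

    chunks = ["Hello", "Hello, wor", "Hello, world!"]
    previous_text = ""
    for full_text in chunks:
        delta_text = full_text[len(previous_text):]
        previous_text = full_text if len(full_text) > len(previous_text) else previous_text
        print(repr(delta_text))  # 'Hello' then ', wor' then 'ld!'
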
-def _parse_single_output(output: Any, is_sse: bool) -> ModelOutput:
-    """Parse the single output."""
-    finish_reason = None
-    usage = None
-    metrics = None
-    if output is None:
-        error_code = 1
-        text = "The output is None!"
-    elif isinstance(output, str):
-        if is_sse:
-            sse_output = _parse_sse_data(output)
-            if sse_output is None:
-                error_code = 1
-                text = "The output is not a SSE format"
-            else:
-                error_code = 0
-                text = sse_output
-        else:
-            error_code = 0
-            text = output
-    elif isinstance(output, ModelOutput):
-        error_code = output.error_code
-        text = output.text
-        finish_reason = output.finish_reason
-        usage = output.usage
-        metrics = output.metrics
-    elif isinstance(output, CommonLLMHttpResponseBody):
-        error_code = output.error_code
-        text = output.text
-    elif isinstance(output, dict):
-        error_code = 0
-        text = json.dumps(output, ensure_ascii=False)
-    else:
-        error_code = 1
-        text = f"The output is not a valid format({type(output)})"
-    return ModelOutput(
-        error_code=error_code,
-        text=text,
-        finish_reason=finish_reason,
-        usage=usage,
-        metrics=metrics,
-    )
-
-
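Note: `_parse_single_output` normalizes whatever a non-streaming task returns into a `ModelOutput`. Quick dispatch sketch with `is_sse=False`:

    _parse_single_output("hi", False).text          # 'hi'
    _parse_single_output({"k": "v"}, False).text    # '{"k": "v"}' (JSON-encoded)
    _parse_single_output(None, False).error_code    # 1
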
-def _parse_openai_output(output: Any) -> ModelOutput:
-    """Parse the OpenAI output."""
-    text = ""
-    if not isinstance(output, str):
-        return ModelOutput(
-            error_code=1,
-            text="The output is not a stream format",
-        )
-    if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
-        return ModelOutput(error_code=0, text="")
-    if not output.startswith("data:"):
-        return ModelOutput(
-            error_code=1,
-            text="The output is not a stream format",
-        )
-
-    sse_output = _parse_sse_data(output)
-    if sse_output is None:
-        return ModelOutput(error_code=1, text="The output is not a SSE format")
-    json_data = sse_output.strip()
-    try:
-        dict_data = json.loads(json_data)
-    except Exception as e:
-        return ModelOutput(
-            error_code=1,
-            text=f"Invalid JSON data: {json_data}, {e}",
-        )
-    if "choices" not in dict_data:
-        return ModelOutput(
-            error_code=1,
-            text=dict_data.get("text", "Unknown error"),
-        )
-    choices = dict_data["choices"]
-    finish_reason: Optional[str] = None
-    if choices:
-        choice = choices[0]
-        delta_data = ChatCompletionResponseStreamChoice(**choice)
-        if delta_data.delta.content:
-            text = delta_data.delta.content
-        finish_reason = delta_data.finish_reason
-    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
-
-
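Note: `_parse_openai_output` expects OpenAI-compatible SSE chunks. An illustrative input the parser above accepts (chunk shape assumed from `ChatCompletionResponseStreamChoice`):

    chunk = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": null}]}'
    out = _parse_openai_output(chunk)
    # out.error_code == 0, out.text == "Hi", out.finish_reason is None
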
-def _parse_sse_data(output: str) -> Optional[str]:
-    if output.startswith("data:"):
-        if output.startswith("data: "):
-            output = output[6:]
-        else:
-            output = output[5:]
-
-        return output
-    else:
-        return None
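
Note: `_parse_sse_data` only strips the SSE `data:` prefix, tolerating an optional space. Behavior sketch:

    assert _parse_sse_data("data: hello") == "hello"
    assert _parse_sse_data("data:hello") == "hello"
    assert _parse_sse_data("hello") is None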