feat(core): Multiple ways to run dbgpts (#1734)

Fangyin Cheng authored 2024-07-18 17:50:40 +08:00, committed by GitHub
parent d389fddc2f
commit f889fa3775
19 changed files with 1410 additions and 304 deletions


@@ -1,19 +1,13 @@
import json
import logging
import traceback
from typing import Any, AsyncIterator, List, Optional, cast
from typing import AsyncIterator, List, Optional, cast
import schedule
from fastapi import HTTPException
from dbgpt._private.pydantic import model_to_json
from dbgpt.component import SystemApp
from dbgpt.core.awel import (
DAG,
BaseOperator,
CommonLLMHttpRequestBody,
CommonLLMHttpResponseBody,
)
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
@@ -22,10 +16,13 @@ from dbgpt.core.awel.flow.flow_factory import (
fill_flow_panel,
)
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
from dbgpt.core.awel.util.chat_util import (
is_chat_flow_type,
safe_chat_stream_with_dag_task,
safe_chat_with_dag_task,
)
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.schema.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
@@ -333,6 +330,11 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
flow = self.dao.get_one(query_request)
if flow:
fill_flow_panel(flow)
metadata = self.dag_manager.get_dag_metadata(
flow.dag_id, alias_name=flow.uid
)
if metadata:
flow.metadata = metadata.to_dict()
return flow

def delete(self, uid: str) -> Optional[ServerResponse]:
@@ -390,7 +392,14 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
Returns:
List[ServerResponse]: The response
"""
return self.dao.get_list_page(request, page, page_size)
page_result = self.dao.get_list_page(request, page, page_size)
for item in page_result.items:
metadata = self.dag_manager.get_dag_metadata(
item.dag_id, alias_name=item.uid
)
if metadata:
item.metadata = metadata.to_dict()
return page_result
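
For readers skimming the diff: both get_one and get_list_page now attach DAG metadata after fetching. A minimal self-contained sketch of that pattern, using hypothetical FlowRecord and FakeDAGManager stand-ins rather than the real dbgpt classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FlowRecord:
    uid: str
    dag_id: str
    metadata: Optional[dict] = None

class FakeDAGManager:
    def __init__(self, metadata_by_dag: dict):
        self._metadata = metadata_by_dag

    def get_dag_metadata(self, dag_id: str, alias_name: str = "") -> Optional[dict]:
        # The real DAGManager returns an object with .to_dict(); a plain
        # dict keeps this sketch self-contained.
        return self._metadata.get(dag_id)

def enrich(records, dag_manager):
    # Fetch first, then attach metadata only when the manager resolves it.
    for item in records:
        metadata = dag_manager.get_dag_metadata(item.dag_id, alias_name=item.uid)
        if metadata:
            item.metadata = metadata
    return records

manager = FakeDAGManager({"dag-1": {"sse_output": True}})
rows = enrich([FlowRecord("u1", "dag-1"), FlowRecord("u2", "dag-2")], manager)
assert rows[0].metadata == {"sse_output": True} and rows[1].metadata is None
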
async def chat_stream_flow_str(
self, flow_uid: str, request: CommonLLMHttpRequestBody
@@ -463,7 +472,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
return await _safe_chat_with_dag_task(task, request)
return await safe_chat_with_dag_task(task, request)
except HTTPException as e:
return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
@@ -484,7 +493,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
incremental = request.incremental
try:
task = await self._get_callable_task(flow_uid)
async for output in _safe_chat_stream_with_dag_task(
async for output in safe_chat_stream_with_dag_task(
task, request, incremental
):
yield output
@@ -556,220 +565,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
output = leaf_node.metadata.outputs[0]
try:
real_class = _get_type_cls(output.type_cls)
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
if common_http_trigger and is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON

def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
if is_class:
return (
output_obj == str
or output_obj == CommonLLMHttpResponseBody
or output_obj == ModelOutput
)
else:
chat_types = (str, CommonLLMHttpResponseBody)
return isinstance(output_obj, chat_types)
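
A runnable sketch of this check's two modes, class mode for type annotations and instance mode for runtime values; looks_like_chat_output is a stand-in that uses only str so it runs without dbgpt installed:

def looks_like_chat_output(obj, is_class=False) -> bool:
    # Stand-in for (str, CommonLLMHttpResponseBody, ModelOutput).
    chat_types = (str,)
    if is_class:
        return obj in chat_types   # class mode: compare the annotation itself
    return isinstance(obj, chat_types)  # instance mode: inspect the value

assert looks_like_chat_output(str, is_class=True)
assert looks_like_chat_output("hello")
assert not looks_like_chat_output(dict, is_class=True)
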
async def _safe_chat_with_dag_task(task: BaseOperator, request: Any) -> ModelOutput:
"""Chat with the DAG task."""
try:
finish_reason = None
usage = None
metrics = None
error_code = 0
text = ""
async for output in _safe_chat_stream_with_dag_task(task, request, False):
finish_reason = output.finish_reason
usage = output.usage
metrics = output.metrics
error_code = output.error_code
text = output.text
return ModelOutput(
error_code=error_code,
text=text,
metrics=metrics,
usage=usage,
finish_reason=finish_reason,
)
except Exception as e:
return ModelOutput(error_code=1, text=str(e), incremental=False)
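
The non-streaming helper above drains the full stream and keeps the last chunk's fields as the one-shot answer. A self-contained sketch of that pattern (Chunk and fake_stream are invented stand-ins for the DAG task's output):

import asyncio
from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    text: str
    finish_reason: Optional[str] = None

async def fake_stream():
    # Non-incremental mode: each chunk carries the full text so far,
    # so the final chunk holds the complete answer.
    yield Chunk("Hello")
    yield Chunk("Hello, world", finish_reason="stop")

async def drain():
    text, finish_reason = "", None
    async for chunk in fake_stream():
        text, finish_reason = chunk.text, chunk.finish_reason
    return text, finish_reason

assert asyncio.run(drain()) == ("Hello, world", "stop")
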
async def _safe_chat_stream_with_dag_task(
task: BaseOperator,
request: Any,
incremental: bool,
) -> AsyncIterator[ModelOutput]:
"""Chat with the DAG task."""
try:
async for output in _chat_stream_with_dag_task(task, request, incremental):
yield output
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
finally:
if task.streaming_operator:
if task.dag:
await task.dag._after_dag_end(task.current_event_loop_task_id)
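
The wrapper above converts exceptions into an error output and runs DAG cleanup in finally. A minimal sketch of that exception-safe generator shape, with a print standing in for the _after_dag_end call:

import asyncio
from typing import AsyncIterator

async def flaky_source() -> AsyncIterator[str]:
    yield "ok"
    raise RuntimeError("boom")

async def safe_stream(source) -> AsyncIterator[str]:
    try:
        async for item in source:
            yield item
    except Exception as e:
        # Surface the failure as a normal item instead of raising.
        yield f"error: {e}"
    finally:
        # Cleanup always runs, mirroring the _after_dag_end call above.
        print("cleanup")

async def main():
    return [item async for item in safe_stream(flaky_source())]

assert asyncio.run(main()) == ["ok", "error: boom"]
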
async def _chat_stream_with_dag_task(
task: BaseOperator,
request: Any,
incremental: bool,
) -> AsyncIterator[ModelOutput]:
"""Chat with the DAG task."""
is_sse = task.output_format and task.output_format.upper() == "SSE"
if not task.streaming_operator:
try:
result = await task.call(request)
model_output = _parse_single_output(result, is_sse)
model_output.incremental = incremental
yield model_output
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
else:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
if OpenAIStreamingOutputOperator and isinstance(
task, OpenAIStreamingOutputOperator
):
full_text = ""
async for output in await task.call_stream(request):
model_output = _parse_openai_output(output)
# The output of the OpenAI streaming API is incremental
full_text += model_output.text
model_output.incremental = incremental
model_output.text = model_output.text if incremental else full_text
yield model_output
if not model_output.success:
break
else:
full_text = ""
previous_text = ""
async for output in await task.call_stream(request):
model_output = _parse_single_output(output, is_sse)
model_output.incremental = incremental
if task.incremental_output:
# Output is incremental, append the text
full_text += model_output.text
else:
# Output is not incremental, last output is the full text
full_text = model_output.text
if not incremental:
# Return the full text
model_output.text = full_text
else:
# Return the incremental text
delta_text = full_text[len(previous_text) :]
previous_text = (
full_text
if len(full_text) > len(previous_text)
else previous_text
)
model_output.text = delta_text
yield model_output
if not model_output.success:
break
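
The incremental branch above produces deltas by slicing off the previously seen prefix of the cumulative text. A minimal sketch:

def deltas(snapshots):
    # Each snapshot is the full text so far; emit only the new suffix.
    previous = ""
    out = []
    for full in snapshots:
        out.append(full[len(previous):])
        previous = full if len(full) > len(previous) else previous
    return out

assert deltas(["He", "Hello", "Hello!"]) == ["He", "llo", "!"]
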
def _parse_single_output(output: Any, is_sse: bool) -> ModelOutput:
"""Parse the single output."""
finish_reason = None
usage = None
metrics = None
if output is None:
error_code = 1
text = "The output is None!"
elif isinstance(output, str):
if is_sse:
sse_output = _parse_sse_data(output)
if sse_output is None:
error_code = 1
text = "The output is not a SSE format"
else:
error_code = 0
text = sse_output
else:
error_code = 0
text = output
elif isinstance(output, ModelOutput):
error_code = output.error_code
text = output.text
finish_reason = output.finish_reason
usage = output.usage
metrics = output.metrics
elif isinstance(output, CommonLLMHttpResponseBody):
error_code = output.error_code
text = output.text
elif isinstance(output, dict):
error_code = 0
text = json.dumps(output, ensure_ascii=False)
else:
error_code = 1
text = f"The output is not a valid format({type(output)})"
return ModelOutput(
error_code=error_code,
text=text,
finish_reason=finish_reason,
usage=usage,
metrics=metrics,
)
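
_parse_single_output normalizes several output shapes onto one (error_code, text) pair. A condensed, runnable sketch covering the plain-Python cases (the ModelOutput and CommonLLMHttpResponseBody branches are omitted):

import json
from typing import Any, Tuple

def normalize(output: Any) -> Tuple[int, str]:
    # Map heterogeneous outputs onto a uniform (error_code, text) pair.
    if output is None:
        return 1, "The output is None!"
    if isinstance(output, str):
        return 0, output
    if isinstance(output, dict):
        return 0, json.dumps(output, ensure_ascii=False)
    return 1, f"The output is not a valid format({type(output)})"

assert normalize({"answer": 42}) == (0, '{"answer": 42}')
assert normalize(None)[0] == 1
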
def _parse_openai_output(output: Any) -> ModelOutput:
"""Parse the OpenAI output."""
text = ""
if not isinstance(output, str):
return ModelOutput(
error_code=1,
text="The output is not a stream format",
)
if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
return ModelOutput(error_code=0, text="")
if not output.startswith("data:"):
return ModelOutput(
error_code=1,
text="The output is not a stream format",
)
sse_output = _parse_sse_data(output)
if sse_output is None:
return ModelOutput(error_code=1, text="The output is not in SSE format")
json_data = sse_output.strip()
try:
dict_data = json.loads(json_data)
except Exception as e:
return ModelOutput(
error_code=1,
text=f"Invalid JSON data: {json_data}, {e}",
)
if "choices" not in dict_data:
return ModelOutput(
error_code=1,
text=dict_data.get("text", "Unknown error"),
)
choices = dict_data["choices"]
finish_reason: Optional[str] = None
if choices:
choice = choices[0]
delta_data = ChatCompletionResponseStreamChoice(**choice)
if delta_data.delta.content:
text = delta_data.delta.content
finish_reason = delta_data.finish_reason
return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
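
A plain-dict sketch of the chunk parsing above; the real code validates each choice with ChatCompletionResponseStreamChoice, while this hypothetical parse_chunk reads the same fields directly:

import json
from typing import Optional, Tuple

def parse_chunk(line: str) -> Tuple[Optional[str], Optional[str]]:
    # Extract (delta content, finish_reason) from one OpenAI-style SSE line.
    if not line.startswith("data:"):
        raise ValueError("not a stream line")
    payload = json.loads(line[len("data:"):].strip())
    choice = payload["choices"][0]
    return choice.get("delta", {}).get("content"), choice.get("finish_reason")

line = 'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'
assert parse_chunk(line) == ("Hi", None)
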
def _parse_sse_data(output: str) -> Optional[str]:
if output.startswith("data:"):
if output.startswith("data: "):
output = output[6:]
else:
output = output[5:]
return output
else:
return None
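
A standalone restatement of _parse_sse_data with its three cases exercised:

def parse_sse_data(output: str):
    # Same logic as _parse_sse_data above, flattened into three cases.
    if output.startswith("data: "):
        return output[6:]
    if output.startswith("data:"):
        return output[5:]
    return None

assert parse_sse_data("data: hello") == "hello"
assert parse_sse_data("data:hello") == "hello"
assert parse_sse_data("hello") is None
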