Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-10-22 17:39:02 +00:00)
feat(core): Multiple ways to run dbgpts (#1734)
dbgpt/core/awel/util/chat_util.py | 323 lines added (new file)
@@ -0,0 +1,323 @@
"""The utility functions for chatting with the DAG task."""

import json
import traceback
from typing import Any, AsyncIterator, Dict, Optional

from ...interface.llm import ModelInferenceMetrics, ModelOutput
from ...schema.api import ChatCompletionResponseStreamChoice
from ..operators.base import BaseOperator
from ..trigger.http_trigger import CommonLLMHttpResponseBody


def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
    """Check whether the output object is a chat flow type."""
    if is_class:
        return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
    else:
        chat_types = (str, CommonLLMHttpResponseBody)
        return isinstance(output_obj, chat_types)


async def safe_chat_with_dag_task(
    task: BaseOperator, request: Any, covert_to_str: bool = False
) -> ModelOutput:
    """Chat with the DAG task.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Returns:
        ModelOutput: The model output, the result is not incremental.
    """
    try:
        finish_reason = None
        usage = None
        metrics = None
        error_code = 0
        text = ""
        async for output in safe_chat_stream_with_dag_task(
            task, request, False, covert_to_str=covert_to_str
        ):
            finish_reason = output.finish_reason
            usage = output.usage
            metrics = output.metrics
            error_code = output.error_code
            text = output.text
        return ModelOutput(
            error_code=error_code,
            text=text,
            metrics=metrics,
            usage=usage,
            finish_reason=finish_reason,
        )
    except Exception as e:
        return ModelOutput(error_code=1, text=str(e), incremental=False)


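# NOTE: Illustrative usage sketch added by the editor; it is not part of the original
# commit. `task` is assumed to be the terminal BaseOperator of an already-built AWEL
# DAG and `request` whatever input shape that DAG expects (both hypothetical here).
async def _example_safe_chat(task: BaseOperator, request: Any) -> str:
    # Run the DAG once and get a single, non-incremental ModelOutput back.
    output = await safe_chat_with_dag_task(task, request, covert_to_str=True)
    if output.error_code != 0:
        # Failures are reported as a ModelOutput with error_code=1, never raised.
        raise RuntimeError(output.text)
    return output.text

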
async def safe_chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task.

    This function is similar to `chat_stream_with_dag_task`, but it will catch the
    exception and return the error message.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    try:
        async for output in chat_stream_with_dag_task(
            task, request, incremental, covert_to_str=covert_to_str
        ):
            yield output
    except Exception as e:
        simple_error_msg = str(e)
        if not simple_error_msg:
            simple_error_msg = traceback.format_exc()
        yield ModelOutput(error_code=1, text=simple_error_msg, incremental=incremental)
    finally:
        if task.streaming_operator and task.dag:
            await task.dag._after_dag_end(task.current_event_loop_task_id)


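# NOTE: Illustrative usage sketch added by the editor; not part of the original commit.
# Shows how the safe streaming variant is consumed: errors arrive as a final chunk
# with error_code=1 instead of an exception. `task` and `request` are hypothetical.
async def _example_safe_chat_stream(task: BaseOperator, request: Any) -> str:
    full_text = ""
    async for output in safe_chat_stream_with_dag_task(task, request, incremental=True):
        if output.error_code != 0:
            raise RuntimeError(output.text)
        # With incremental=True each chunk carries only the newly generated delta.
        full_text += output.text
    return full_text

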
def _is_sse_output(task: BaseOperator) -> bool:
    """Check whether the DAG task is a server-sent event output.

    Args:
        task (BaseOperator): The DAG task.

    Returns:
        bool: Whether the DAG task is a server-sent event output.
    """
    return task.output_format is not None and task.output_format.upper() == "SSE"


async def chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    is_sse = _is_sse_output(task)
    if not task.streaming_operator:
        try:
            result = await task.call(request)
            model_output = parse_single_output(
                result, is_sse, covert_to_str=covert_to_str
            )
            model_output.incremental = incremental
            yield model_output
        except Exception as e:
            simple_error_msg = str(e)
            if not simple_error_msg:
                simple_error_msg = traceback.format_exc()
            yield ModelOutput(
                error_code=1, text=simple_error_msg, incremental=incremental
            )
    else:
        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator

        if OpenAIStreamingOutputOperator and isinstance(
            task, OpenAIStreamingOutputOperator
        ):
            full_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_openai_output(output)
                # The output of the OpenAI streaming API is incremental
                full_text += model_output.text
                model_output.incremental = incremental
                model_output.text = model_output.text if incremental else full_text
                yield model_output
                if not model_output.success:
                    break
        else:
            full_text = ""
            previous_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_single_output(
                    output, is_sse, covert_to_str=covert_to_str
                )
                model_output.incremental = incremental
                if task.incremental_output:
                    # Output is incremental, append the text
                    full_text += model_output.text
                else:
                    # Output is not incremental, last output is the full text
                    full_text = model_output.text
                if not incremental:
                    # Return the full text
                    model_output.text = full_text
                else:
                    # Return the incremental text
                    delta_text = full_text[len(previous_text) :]
                    previous_text = (
                        full_text
                        if len(full_text) > len(previous_text)
                        else previous_text
                    )
                    model_output.text = delta_text
                yield model_output
                if not model_output.success:
                    break


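# NOTE: Illustrative sketch added by the editor; not part of the original commit.
# Demonstrates how the `incremental` flag changes what each yielded chunk contains;
# `task` and `request` are hypothetical placeholders.
async def _example_print_stream(task: BaseOperator, request: Any) -> None:
    # incremental=True: every ModelOutput.text is only the delta since the last chunk,
    # so printing the chunks back to back reassembles the full answer.
    # incremental=False would instead put the full accumulated text in every chunk.
    async for output in chat_stream_with_dag_task(task, request, incremental=True):
        print(output.text, end="", flush=True)

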
def parse_single_output(
    output: Any, is_sse: bool, covert_to_str: bool = False
) -> ModelOutput:
    """Parse the single output.

    Args:
        output (Any): The output to parse.
        is_sse (bool): Whether the output is in SSE format.
        covert_to_str (bool, optional): Whether to convert the output to string.
            Defaults to False.

    Returns:
        ModelOutput: The parsed output.
    """
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, Any]] = None
    metrics: Optional[ModelInferenceMetrics] = None

    if output is None:
        error_code = 1
        text = "The output is None!"
    elif isinstance(output, str):
        if is_sse:
            sse_output = parse_sse_data(output)
            if sse_output is None:
                error_code = 1
                text = "The output is not a SSE format"
            else:
                error_code = 0
                text = sse_output
        else:
            error_code = 0
            text = output
    elif isinstance(output, ModelOutput):
        error_code = output.error_code
        text = output.text
        finish_reason = output.finish_reason
        usage = output.usage
        metrics = output.metrics
    elif isinstance(output, CommonLLMHttpResponseBody):
        error_code = output.error_code
        text = output.text
    elif isinstance(output, dict):
        error_code = 0
        text = json.dumps(output, ensure_ascii=False)
    elif covert_to_str:
        error_code = 0
        text = str(output)
    else:
        error_code = 1
        text = f"The output is not a valid format({type(output)})"
    return ModelOutput(
        error_code=error_code,
        text=text,
        finish_reason=finish_reason,
        usage=usage,
        metrics=metrics,
    )


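# NOTE: Illustrative sketch added by the editor; not part of the original commit.
# Shows how parse_single_output normalizes the different result types it accepts.
def _example_parse_single_output() -> None:
    # Plain strings pass through unchanged.
    assert parse_single_output("hello", is_sse=False).text == "hello"
    # With is_sse=True only the payload after the SSE "data:" prefix is kept.
    assert parse_single_output("data: hello", is_sse=True).text == "hello"
    # Dicts are serialized to JSON.
    assert parse_single_output({"answer": 42}, is_sse=False).error_code == 0
    # Unknown types are an error unless covert_to_str=True forces str() conversion.
    assert parse_single_output(object(), is_sse=False).error_code == 1
    assert parse_single_output(object(), is_sse=False, covert_to_str=True).error_code == 0

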
def parse_openai_output(output: Any) -> ModelOutput:
    """Parse the OpenAI output.

    Args:
        output (Any): The output to parse. It must be a stream format.

    Returns:
        ModelOutput: The parsed output.
    """
    text = ""
    if not isinstance(output, str):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )
    if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
        return ModelOutput(error_code=0, text="")
    if not output.startswith("data:"):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )

    sse_output = parse_sse_data(output)
    if sse_output is None:
        return ModelOutput(error_code=1, text="The output is not a SSE format")
    json_data = sse_output.strip()
    try:
        dict_data = json.loads(json_data)
    except Exception as e:
        return ModelOutput(
            error_code=1,
            text=f"Invalid JSON data: {json_data}, {e}",
        )
    if "choices" not in dict_data:
        return ModelOutput(
            error_code=1,
            text=dict_data.get("text", "Unknown error"),
        )
    choices = dict_data["choices"]
    finish_reason: Optional[str] = None
    if choices:
        choice = choices[0]
        delta_data = ChatCompletionResponseStreamChoice(**choice)
        if delta_data.delta.content:
            text = delta_data.delta.content
        finish_reason = delta_data.finish_reason
    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)


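# NOTE: Illustrative sketch added by the editor; not part of the original commit.
# Shows the inputs parse_openai_output handles, assuming the standard OpenAI
# streaming choice shape ("index", "delta.content", "finish_reason") in the chunk.
def _example_parse_openai_output() -> None:
    chunk = (
        'data: {"choices": [{"index": 0, '
        '"delta": {"content": "Hi"}, "finish_reason": null}]}'
    )
    # A well-formed streaming chunk yields the delta content.
    assert parse_openai_output(chunk).text == "Hi"
    # The terminating sentinel yields an empty, successful output.
    assert parse_openai_output("data: [DONE]").error_code == 0
    # Anything that is not a "data:" line is reported as an error.
    assert parse_openai_output("not-sse").error_code == 1

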
def parse_sse_data(output: str) -> Optional[str]:
    r"""Parse the SSE data.

    Just keep the data part.

    Examples:
        .. code-block:: python

            from dbgpt.core.awel.util.chat_util import parse_sse_data

            assert parse_sse_data("data: [DONE]") == "[DONE]"
            assert parse_sse_data("data:[DONE]") == "[DONE]"
            assert parse_sse_data("data: Hello") == "Hello"
            assert parse_sse_data("data: Hello\n") == "Hello"
            assert parse_sse_data("data: Hello\r\n") == "Hello"
            assert parse_sse_data("data: Hi, what's up?") == "Hi, what's up?"

    Args:
        output (str): The output.

    Returns:
        Optional[str]: The parsed data.
    """
    if output.startswith("data:"):
        output = output.strip()
        if output.startswith("data: "):
            output = output[6:]
        else:
            output = output[5:]

        return output
    else:
        return None