"""The utility functions for chatting with the DAG task.""" import json import traceback from typing import Any, AsyncIterator, Dict, Optional from ...interface.llm import ModelInferenceMetrics, ModelOutput from ...schema.api import ChatCompletionResponseStreamChoice from ..operators.base import BaseOperator from ..trigger.http_trigger import CommonLLMHttpResponseBody def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool: """Check whether the output object is a chat flow type.""" if is_class: return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput) else: chat_types = (str, CommonLLMHttpResponseBody) return isinstance(output_obj, chat_types) async def safe_chat_with_dag_task( task: BaseOperator, request: Any, covert_to_str: bool = False ) -> ModelOutput: """Chat with the DAG task. Args: task (BaseOperator): The DAG task to be executed. request (Any): The request to be passed to the DAG task. covert_to_str (bool, optional): Whether to convert the output to string. Returns: ModelOutput: The model output, the result is not incremental. """ try: finish_reason = None usage = None metrics = None error_code = 0 text = "" async for output in safe_chat_stream_with_dag_task( task, request, False, covert_to_str=covert_to_str ): finish_reason = output.finish_reason usage = output.usage metrics = output.metrics error_code = output.error_code text = output.text return ModelOutput( error_code=error_code, text=text, metrics=metrics, usage=usage, finish_reason=finish_reason, ) except Exception as e: return ModelOutput(error_code=1, text=str(e), incremental=False) async def safe_chat_stream_with_dag_task( task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False ) -> AsyncIterator[ModelOutput]: """Chat with the DAG task. This function is similar to `chat_stream_with_dag_task`, but it will catch the exception and return the error message. Args: task (BaseOperator): The DAG task to be executed. request (Any): The request to be passed to the DAG task. incremental (bool): Whether the output is incremental. covert_to_str (bool, optional): Whether to convert the output to string. Yields: ModelOutput: The model output. """ try: async for output in chat_stream_with_dag_task( task, request, incremental, covert_to_str=covert_to_str ): yield output except Exception as e: simple_error_msg = str(e) if not simple_error_msg: simple_error_msg = traceback.format_exc() yield ModelOutput(error_code=1, text=simple_error_msg, incremental=incremental) finally: if task.streaming_operator and task.dag: await task.dag._after_dag_end(task.current_event_loop_task_id) def _is_sse_output(task: BaseOperator) -> bool: """Check whether the DAG task is a server-sent event output. Args: task (BaseOperator): The DAG task. Returns: bool: Whether the DAG task is a server-sent event output. """ return task.output_format is not None and task.output_format.upper() == "SSE" async def chat_stream_with_dag_task( task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False ) -> AsyncIterator[ModelOutput]: """Chat with the DAG task. Args: task (BaseOperator): The DAG task to be executed. request (Any): The request to be passed to the DAG task. incremental (bool): Whether the output is incremental. covert_to_str (bool, optional): Whether to convert the output to string. Yields: ModelOutput: The model output. """ is_sse = _is_sse_output(task) if not task.streaming_operator: try: result = await task.call(request) model_output = parse_single_output( result, is_sse, covert_to_str=covert_to_str ) model_output.incremental = incremental yield model_output except Exception as e: simple_error_msg = str(e) if not simple_error_msg: simple_error_msg = traceback.format_exc() yield ModelOutput( error_code=1, text=simple_error_msg, incremental=incremental ) else: from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator if OpenAIStreamingOutputOperator and isinstance( task, OpenAIStreamingOutputOperator ): full_text = "" async for output in await task.call_stream(request): model_output = parse_openai_output(output) # The output of the OpenAI streaming API is incremental full_text += model_output.text model_output.incremental = incremental model_output.text = model_output.text if incremental else full_text yield model_output if not model_output.success: break else: full_text = "" previous_text = "" async for output in await task.call_stream(request): model_output = parse_single_output( output, is_sse, covert_to_str=covert_to_str ) model_output.incremental = incremental if task.incremental_output: # Output is incremental, append the text full_text += model_output.text else: # Output is not incremental, last output is the full text full_text = model_output.text if not incremental: # Return the full text model_output.text = full_text else: # Return the incremental text delta_text = full_text[len(previous_text) :] previous_text = ( full_text if len(full_text) > len(previous_text) else previous_text ) model_output.text = delta_text yield model_output if not model_output.success: break def parse_single_output( output: Any, is_sse: bool, covert_to_str: bool = False ) -> ModelOutput: """Parse the single output. Args: output (Any): The output to parse. is_sse (bool): Whether the output is in SSE format. covert_to_str (bool, optional): Whether to convert the output to string. Defaults to False. Returns: ModelOutput: The parsed output. """ finish_reason: Optional[str] = None usage: Optional[Dict[str, Any]] = None metrics: Optional[ModelInferenceMetrics] = None if output is None: error_code = 1 text = "The output is None!" elif isinstance(output, str): if is_sse: sse_output = parse_sse_data(output) if sse_output is None: error_code = 1 text = "The output is not a SSE format" else: error_code = 0 text = sse_output else: error_code = 0 text = output elif isinstance(output, ModelOutput): error_code = output.error_code text = output.text finish_reason = output.finish_reason usage = output.usage metrics = output.metrics elif isinstance(output, CommonLLMHttpResponseBody): error_code = output.error_code text = output.text elif isinstance(output, dict): error_code = 0 text = json.dumps(output, ensure_ascii=False) elif covert_to_str: error_code = 0 text = str(output) else: error_code = 1 text = f"The output is not a valid format({type(output)})" return ModelOutput( error_code=error_code, text=text, finish_reason=finish_reason, usage=usage, metrics=metrics, ) def parse_openai_output(output: Any) -> ModelOutput: """Parse the OpenAI output. Args: output (Any): The output to parse. It must be a stream format. Returns: ModelOutput: The parsed output. """ text = "" if not isinstance(output, str): return ModelOutput( error_code=1, text="The output is not a stream format", ) if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]": return ModelOutput(error_code=0, text="") if not output.startswith("data:"): return ModelOutput( error_code=1, text="The output is not a stream format", ) sse_output = parse_sse_data(output) if sse_output is None: return ModelOutput(error_code=1, text="The output is not a SSE format") json_data = sse_output.strip() try: dict_data = json.loads(json_data) except Exception as e: return ModelOutput( error_code=1, text=f"Invalid JSON data: {json_data}, {e}", ) if "choices" not in dict_data: return ModelOutput( error_code=1, text=dict_data.get("text", "Unknown error"), ) choices = dict_data["choices"] finish_reason: Optional[str] = None if choices: choice = choices[0] delta_data = ChatCompletionResponseStreamChoice(**choice) if delta_data.delta.content: text = delta_data.delta.content finish_reason = delta_data.finish_reason return ModelOutput(error_code=0, text=text, finish_reason=finish_reason) def parse_sse_data(output: str) -> Optional[str]: r"""Parse the SSE data. Just keep the data part. Examples: .. code-block:: python from dbgpt.core.awel.util.chat_util import parse_sse_data assert parse_sse_data("data: [DONE]") == "[DONE]" assert parse_sse_data("data:[DONE]") == "[DONE]" assert parse_sse_data("data: Hello") == "Hello" assert parse_sse_data("data: Hello\n") == "Hello" assert parse_sse_data("data: Hello\r\n") == "Hello" assert parse_sse_data("data: Hi, what's up?") == "Hi, what's up?" Args: output (str): The output. Returns: Optional[str]: The parsed data. """ if output.startswith("data:"): output = output.strip() if output.startswith("data: "): output = output[6:] else: output = output[5:] return output else: return None