Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-03 18:17:45 +00:00
feat(core): Multiple ways to run dbgpts (#1734)
dbgpt/core/awel/dag/base.py
@@ -2,6 +2,7 @@
 
 DAG is the core component of AWEL, it is used to define the relationship between tasks.
 """
 
 import asyncio
+import contextvars
 import logging
@@ -613,10 +614,14 @@ class DAG:
     """
 
     def __init__(
-        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
+        self,
+        dag_id: str,
+        resource_group: Optional[ResourceGroup] = None,
+        tags: Optional[Dict[str, str]] = None,
     ) -> None:
         """Initialize a DAG."""
         self._dag_id = dag_id
+        self._tags: Dict[str, str] = tags or {}
         self.node_map: Dict[str, DAGNode] = {}
         self.node_name_to_node: Dict[str, DAGNode] = {}
         self._root_nodes: List[DAGNode] = []
@@ -651,6 +656,22 @@ class DAG:
         """Return the dag id of current DAG."""
         return self._dag_id
 
+    @property
+    def tags(self) -> Dict[str, str]:
+        """Return the tags of current DAG."""
+        return self._tags
+
+    @property
+    def dev_mode(self) -> bool:
+        """Whether the current DAG is in dev mode.
+
+        Returns:
+            bool: Whether the current DAG is in dev mode
+        """
+        from ..operators.base import _dev_mode
+
+        return _dev_mode()
+
     def _build(self) -> None:
         from ..operators.common_operator import TriggerOperator
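With this change a DAG can be labeled at construction time and the labels read back through the new tags property. A minimal sketch of the intended use, assuming the public AWEL exports (DAG and MapOperator from dbgpt.core.awel); the operator body is hypothetical:

    from dbgpt.core.awel import DAG, MapOperator

    with DAG("example_dag", tags={"owner": "data-team", "env": "dev"}) as dag:
        # Any operator works here; a trivial map keeps the sketch small.
        task = MapOperator(map_function=lambda x: x + 1)

    assert dag.tags == {"owner": "data-team", "env": "dev"}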
dbgpt/core/awel/dag/dag_manager.py
@@ -3,18 +3,49 @@
 DAGManager will load DAGs from dag_dirs, and register the trigger nodes
 to TriggerManager.
 """
 
 import logging
 import threading
-from typing import Dict, List, Optional
+from collections import defaultdict
+from typing import Dict, List, Optional, Set
 
+from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
 
+from .. import BaseOperator
+from ..trigger.base import TriggerMetadata
 from .base import DAG
 from .loader import LocalFileDAGLoader
 
 logger = logging.getLogger(__name__)
 
 
+class DAGMetadata(BaseModel):
+    """Metadata for the DAG."""
+
+    triggers: List[TriggerMetadata] = Field(
+        default_factory=list, description="The trigger metadata"
+    )
+    sse_output: bool = Field(
+        default=False, description="Whether the DAG is a server-sent event output"
+    )
+    streaming_output: bool = Field(
+        default=False, description="Whether the DAG is a streaming output"
+    )
+    tags: Optional[Dict[str, str]] = Field(
+        default=None, description="The tags of the DAG"
+    )
+
+    def to_dict(self):
+        """Convert the metadata to dict."""
+        triggers_dict = []
+        for trigger in self.triggers:
+            triggers_dict.append(trigger.dict())
+        dict_value = model_to_dict(self, exclude={"triggers"})
+        dict_value["triggers"] = triggers_dict
+        return dict_value
+
+
 class DAGManager(BaseComponent):
     """The component of DAGManager."""
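to_dict serializes each trigger through its own dict() so that subclass fields (such as path and methods on the HttpTriggerMetadata added below) survive serialization. A small illustration, assuming the classes from this diff:

    from dbgpt.core.awel.dag.dag_manager import DAGMetadata
    from dbgpt.core.awel.trigger.base import TriggerMetadata

    metadata = DAGMetadata(
        triggers=[TriggerMetadata(trigger_type="http")],
        streaming_output=True,
        tags={"team": "platform"},
    )
    # Roughly: {'sse_output': False, 'streaming_output': True,
    #           'tags': {'team': 'platform'}, 'triggers': [{'trigger_type': 'http'}]}
    print(metadata.to_dict())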
@@ -35,6 +66,8 @@ class DAGManager(BaseComponent):
         self.system_app = system_app
         self.dag_map: Dict[str, DAG] = {}
         self.dag_alias_map: Dict[str, str] = {}
+        self._dag_metadata_map: Dict[str, DAGMetadata] = {}
+        self._tags_to_dag_ids: Dict[str, Dict[str, Set[str]]] = {}
         self._trigger_manager: Optional["DefaultTriggerManager"] = None
 
     def init_app(self, system_app: SystemApp):
@@ -73,12 +106,26 @@ class DAGManager(BaseComponent):
         if alias_name:
             self.dag_alias_map[alias_name] = dag_id
 
+        trigger_metadata: List["TriggerMetadata"] = []
+        dag_metadata = _parse_metadata(dag)
         if self._trigger_manager:
             for trigger in dag.trigger_nodes:
-                self._trigger_manager.register_trigger(trigger, self.system_app)
+                tm = self._trigger_manager.register_trigger(
+                    trigger, self.system_app
+                )
+                if tm:
+                    trigger_metadata.append(tm)
             self._trigger_manager.after_register()
         else:
             logger.warning("No trigger manager, not register dag trigger")
+        dag_metadata.triggers = trigger_metadata
+        self._dag_metadata_map[dag_id] = dag_metadata
+        tags = dag_metadata.tags
+        if tags:
+            for tag_key, tag_value in tags.items():
+                if tag_key not in self._tags_to_dag_ids:
+                    self._tags_to_dag_ids[tag_key] = defaultdict(set)
+                self._tags_to_dag_ids[tag_key][tag_value].add(dag_id)
 
     def unregister_dag(self, dag_id: str):
         """Unregister a DAG."""
@@ -104,7 +151,13 @@
             for trigger in dag.trigger_nodes:
                 self._trigger_manager.unregister_trigger(trigger, self.system_app)
         # Finally remove the DAG from the map
+        metadata = self._dag_metadata_map[dag_id]
         del self.dag_map[dag_id]
+        del self._dag_metadata_map[dag_id]
+        if metadata.tags:
+            for tag_key, tag_value in metadata.tags.items():
+                if tag_key in self._tags_to_dag_ids:
+                    self._tags_to_dag_ids[tag_key][tag_value].remove(dag_id)
 
     def get_dag(
         self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
@@ -116,3 +169,33 @@
         if alias_name in self.dag_alias_map:
             return self.dag_map.get(self.dag_alias_map[alias_name])
         return None
+
+    def get_dags_by_tag(self, tag_key: str, tag_value) -> List[DAG]:
+        """Get all DAGs with the given tag."""
+        with self.lock:
+            dag_ids = self._tags_to_dag_ids.get(tag_key, {}).get(tag_value, set())
+            return [self.dag_map[dag_id] for dag_id in dag_ids]
+
+    def get_dag_metadata(
+        self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
+    ) -> Optional[DAGMetadata]:
+        """Get a DAGMetadata by dag_id or alias_name."""
+        dag = self.get_dag(dag_id, alias_name)
+        if not dag:
+            return None
+        return self._dag_metadata_map.get(dag.dag_id)
+
+
+def _parse_metadata(dag: DAG):
+    from ..util.chat_util import _is_sse_output
+
+    metadata = DAGMetadata()
+    metadata.tags = dag.tags
+    if not dag.leaf_nodes:
+        return metadata
+    end_node = dag.leaf_nodes[0]
+    if not isinstance(end_node, BaseOperator):
+        return metadata
+    metadata.sse_output = _is_sse_output(end_node)
+    metadata.streaming_output = end_node.streaming_operator
+    return metadata
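Together with the tags added to DAG above, these helpers give a tag-based lookup path. A hedged usage sketch, assuming dag_manager is the registered DAGManager component:

    # Find all DAGs labeled for a team, then inspect their parsed metadata.
    for dag in dag_manager.get_dags_by_tag("owner", "data-team"):
        meta = dag_manager.get_dag_metadata(dag.dag_id)
        print(dag.dag_id, meta.streaming_output if meta else None)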
dbgpt/core/awel/flow/flow_factory.py
@@ -18,6 +18,7 @@ from dbgpt._private.pydantic import (
     model_validator,
 )
 from dbgpt.core.awel.dag.base import DAG, DAGNode
+from dbgpt.core.awel.dag.dag_manager import DAGMetadata
 
 from .base import (
     OperatorType,
@@ -352,6 +353,9 @@ class FlowPanel(BaseModel):
         description="The flow panel modified time.",
         examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
     )
+    metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field(
+        default=None, description="The metadata of the flow"
+    )
 
     @model_validator(mode="before")
     @classmethod
dbgpt/core/awel/operators/base.py
@@ -69,6 +69,15 @@ class WorkflowRunner(ABC, Generic[T]):
 
 default_runner: Optional[WorkflowRunner] = None
 
 
+def _dev_mode() -> bool:
+    """Check if the operator is in dev mode.
+
+    In production mode, the default runner is not None, and the operator will run in
+    the same process with the DB-GPT webserver.
+    """
+    return default_runner is None
+
+
 class BaseOperatorMeta(ABCMeta):
     """Metaclass of BaseOperator."""
@@ -86,7 +95,9 @@ class BaseOperatorMeta(ABCMeta):
         if not executor:
             if system_app:
                 executor = system_app.get_component(
-                    ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
+                    ComponentType.EXECUTOR_DEFAULT,
+                    DefaultExecutorFactory,
+                    default_component=DefaultExecutorFactory(),
                 ).create()  # type: ignore
             else:
                 executor = DefaultExecutorFactory().create()
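The extra default_component argument presumably makes the component lookup non-fatal: when no executor factory has been registered, the supplied instance is used instead. The behavior is roughly equivalent to this hand-written fallback (a sketch, not the SystemApp implementation):

    try:
        factory = system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
        )
    except Exception:  # No registered executor factory
        factory = DefaultExecutorFactory()
    executor = factory.create()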
@@ -173,13 +184,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def dev_mode(self) -> bool:
         """Whether the operator is in dev mode.
 
-        In production mode, the default runner is not None.
+        In production mode, the default runner is not None, and the operator will run in
+        the same process with the DB-GPT webserver.
 
         Returns:
             bool: Whether the operator is in dev mode. True if the
                 default runner is None.
         """
-        return default_runner is None
+        return _dev_mode()
 
     async def _run(self, dag_ctx: DAGContext, task_log_id: str) -> TaskOutput[OUT]:
         if not self.node_id:
dbgpt/core/awel/trigger/base.py
@@ -1,13 +1,22 @@
 """Base class for all trigger classes."""
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Generic
+from typing import Any, Generic, Optional
+
+from dbgpt._private.pydantic import BaseModel, Field
 
 from ..operators.common_operator import TriggerOperator
 from ..task.base import OUT
 
 
+class TriggerMetadata(BaseModel):
+    """Metadata for the trigger."""
+
+    trigger_type: Optional[str] = Field(
+        default=None, description="The type of the trigger"
+    )
+
+
 class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
     """Base class for all trigger classes.
dbgpt/core/awel/trigger/http_trigger.py
@@ -43,7 +43,7 @@ from ..operators.base import BaseOperator
 from ..operators.common_operator import MapOperator
 from ..util._typing_util import _parse_bool
 from ..util.http_util import join_paths
-from .base import Trigger
+from .base import Trigger, TriggerMetadata
 
 if TYPE_CHECKING:
     from fastapi import APIRouter, FastAPI
@@ -82,6 +82,17 @@ def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
     return _parse_bool(streaming)
 
 
+class HttpTriggerMetadata(TriggerMetadata):
+    """Trigger metadata."""
+
+    path: str = Field(..., description="The path of the trigger")
+    methods: List[str] = Field(..., description="The methods of the trigger")
+
+    trigger_type: Optional[str] = Field(
+        default="http", description="The type of the trigger"
+    )
+
+
 class BaseHttpBody(BaseModel):
     """Http body.
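HttpTriggerMetadata extends the base TriggerMetadata with the mounted route and pins trigger_type to "http". For example (the path value is illustrative):

    meta = HttpTriggerMetadata(path="/api/v1/awel/trigger/hello", methods=["GET"])
    assert meta.trigger_type == "http"
    # meta.dict() now carries path and methods alongside trigger_type.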
@@ -444,7 +455,7 @@ class HttpTrigger:
 
     def mount_to_router(
         self, router: "APIRouter", global_prefix: Optional[str] = None
-    ) -> None:
+    ) -> HttpTriggerMetadata:
         """Mount the trigger to a router.
 
         Args:
@@ -466,8 +477,11 @@ class HttpTrigger:
         )(dynamic_route_function)
 
         logger.info(f"Mount http trigger success, path: {path}")
+        return HttpTriggerMetadata(path=path, methods=self._methods)
 
-    def mount_to_app(self, app: "FastAPI", global_prefix: Optional[str] = None) -> None:
+    def mount_to_app(
+        self, app: "FastAPI", global_prefix: Optional[str] = None
+    ) -> HttpTriggerMetadata:
         """Mount the trigger to a FastAPI app.
 
         TODO: The performance of this method is not good, need to be optimized.
@@ -498,6 +512,7 @@ class HttpTrigger:
         app.openapi_schema = None
         app.middleware_stack = None
         logger.info(f"Mount http trigger success, path: {path}")
+        return HttpTriggerMetadata(path=path, methods=self._methods)
 
     def remove_from_app(
         self, app: "FastAPI", global_prefix: Optional[str] = None
dbgpt/core/awel/trigger/trigger_manager.py
@@ -2,15 +2,16 @@
 
 After DB-GPT started, the trigger manager will be initialized and register all triggers
 """
 
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
 
 from ..util.http_util import join_paths
-from .base import Trigger
+from .base import Trigger, TriggerMetadata
 
 if TYPE_CHECKING:
     from fastapi import APIRouter
@@ -23,7 +24,9 @@ class TriggerManager(ABC):
     """Base class for trigger manager."""
 
     @abstractmethod
-    def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
+    def register_trigger(
+        self, trigger: Any, system_app: SystemApp
+    ) -> Optional[TriggerMetadata]:
         """Register a trigger to current manager."""
 
     @abstractmethod
@@ -65,10 +68,12 @@ class HttpTriggerManager(TriggerManager):
         self._inited = False
         self._router_prefix = router_prefix
         self._router = router
-        self._trigger_map: Dict[str, Trigger] = {}
+        self._trigger_map: Dict[str, Tuple[Trigger, TriggerMetadata]] = {}
         self._router_tables: Dict[str, Set[str]] = defaultdict(set)
 
-    def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
+    def register_trigger(
+        self, trigger: Any, system_app: SystemApp
+    ) -> Optional[TriggerMetadata]:
         """Register a trigger to current manager."""
         from .http_trigger import HttpTrigger
@@ -86,13 +91,17 @@ class HttpTriggerManager(TriggerManager):
                 if not app:
                     raise ValueError("System app not initialized")
                 # Mount to app, support dynamic route.
-                trigger.mount_to_app(app, self._router_prefix)
+                trigger_metadata = trigger.mount_to_app(app, self._router_prefix)
             else:
-                trigger.mount_to_router(self._router, self._router_prefix)
-            self._trigger_map[trigger_id] = trigger
+                trigger_metadata = trigger.mount_to_router(
+                    self._router, self._router_prefix
+                )
+            self._trigger_map[trigger_id] = (trigger, trigger_metadata)
+            return trigger_metadata
         except Exception as e:
             self._unregister_route_tables(path, methods)
             raise e
+        return None
 
     def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Unregister a trigger to current manager."""
@@ -183,7 +192,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         if system_app and self.system_app.app:
             self._http_trigger = HttpTriggerManager()
 
-    def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
+    def register_trigger(
+        self, trigger: Any, system_app: SystemApp
+    ) -> Optional[TriggerMetadata]:
         """Register a trigger to current manager."""
         from .http_trigger import HttpTrigger
@@ -191,7 +202,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
             logger.info(f"Register trigger {trigger}")
             if not self._http_trigger:
                 raise ValueError("Http trigger manager not initialized")
-            self._http_trigger.register_trigger(trigger, system_app)
+            return self._http_trigger.register_trigger(trigger, system_app)
+        else:
+            return None
 
     def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Unregister a trigger to current manager."""
dbgpt/core/awel/util/chat_util.py (new file, 323 lines; all lines added)
@@ -0,0 +1,323 @@
"""The utility functions for chatting with the DAG task."""

import json
import traceback
from typing import Any, AsyncIterator, Dict, Optional

from ...interface.llm import ModelInferenceMetrics, ModelOutput
from ...schema.api import ChatCompletionResponseStreamChoice
from ..operators.base import BaseOperator
from ..trigger.http_trigger import CommonLLMHttpResponseBody


def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
    """Check whether the output object is a chat flow type."""
    if is_class:
        return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
    else:
        chat_types = (str, CommonLLMHttpResponseBody)
        return isinstance(output_obj, chat_types)


async def safe_chat_with_dag_task(
    task: BaseOperator, request: Any, covert_to_str: bool = False
) -> ModelOutput:
    """Chat with the DAG task.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Returns:
        ModelOutput: The model output, the result is not incremental.
    """
    try:
        finish_reason = None
        usage = None
        metrics = None
        error_code = 0
        text = ""
        async for output in safe_chat_stream_with_dag_task(
            task, request, False, covert_to_str=covert_to_str
        ):
            finish_reason = output.finish_reason
            usage = output.usage
            metrics = output.metrics
            error_code = output.error_code
            text = output.text
        return ModelOutput(
            error_code=error_code,
            text=text,
            metrics=metrics,
            usage=usage,
            finish_reason=finish_reason,
        )
    except Exception as e:
        return ModelOutput(error_code=1, text=str(e), incremental=False)
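A sketch of calling the non-streaming helper; task is assumed to be the DAG's end operator, and the request shape depends on that operator:

    import asyncio

    from dbgpt.core.awel.util.chat_util import safe_chat_with_dag_task


    async def run_once(task):
        # The request payload here is hypothetical; use whatever the
        # end operator expects.
        output = await safe_chat_with_dag_task(task, {"messages": "hello"})
        if output.error_code == 0:
            print(output.text)
        else:
            print("error:", output.text)

    # asyncio.run(run_once(task))  # with a real end operator in hand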
async def safe_chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task.

    This function is similar to `chat_stream_with_dag_task`, but it will catch the
    exception and return the error message.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    try:
        async for output in chat_stream_with_dag_task(
            task, request, incremental, covert_to_str=covert_to_str
        ):
            yield output
    except Exception as e:
        simple_error_msg = str(e)
        if not simple_error_msg:
            simple_error_msg = traceback.format_exc()
        yield ModelOutput(error_code=1, text=simple_error_msg, incremental=incremental)
    finally:
        if task.streaming_operator and task.dag:
            await task.dag._after_dag_end(task.current_event_loop_task_id)
def _is_sse_output(task: BaseOperator) -> bool:
    """Check whether the DAG task is a server-sent event output.

    Args:
        task (BaseOperator): The DAG task.

    Returns:
        bool: Whether the DAG task is a server-sent event output.
    """
    return task.output_format is not None and task.output_format.upper() == "SSE"
async def chat_stream_with_dag_task(
    task: BaseOperator, request: Any, incremental: bool, covert_to_str: bool = False
) -> AsyncIterator[ModelOutput]:
    """Chat with the DAG task.

    Args:
        task (BaseOperator): The DAG task to be executed.
        request (Any): The request to be passed to the DAG task.
        incremental (bool): Whether the output is incremental.
        covert_to_str (bool, optional): Whether to convert the output to string.

    Yields:
        ModelOutput: The model output.
    """
    is_sse = _is_sse_output(task)
    if not task.streaming_operator:
        try:
            result = await task.call(request)
            model_output = parse_single_output(
                result, is_sse, covert_to_str=covert_to_str
            )
            model_output.incremental = incremental
            yield model_output
        except Exception as e:
            simple_error_msg = str(e)
            if not simple_error_msg:
                simple_error_msg = traceback.format_exc()
            yield ModelOutput(
                error_code=1, text=simple_error_msg, incremental=incremental
            )
    else:
        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator

        if OpenAIStreamingOutputOperator and isinstance(
            task, OpenAIStreamingOutputOperator
        ):
            full_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_openai_output(output)
                # The output of the OpenAI streaming API is incremental
                full_text += model_output.text
                model_output.incremental = incremental
                model_output.text = model_output.text if incremental else full_text
                yield model_output
                if not model_output.success:
                    break
        else:
            full_text = ""
            previous_text = ""
            async for output in await task.call_stream(request):
                model_output = parse_single_output(
                    output, is_sse, covert_to_str=covert_to_str
                )
                model_output.incremental = incremental
                if task.incremental_output:
                    # Output is incremental, append the text
                    full_text += model_output.text
                else:
                    # Output is not incremental, last output is the full text
                    full_text = model_output.text
                if not incremental:
                    # Return the full text
                    model_output.text = full_text
                else:
                    # Return the incremental text
                    delta_text = full_text[len(previous_text) :]
                    previous_text = (
                        full_text
                        if len(full_text) > len(previous_text)
                        else previous_text
                    )
                    model_output.text = delta_text
                yield model_output
                if not model_output.success:
                    break
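The non-OpenAI streaming branch converts between the two streaming conventions: when the upstream operator yields full snapshots but the caller asked for increments, each yield is the suffix the caller has not seen yet. Worked through by hand:

    # Upstream yields snapshots: "He", "Hello", "Hello!"
    previous_text = ""
    for full_text in ["He", "Hello", "Hello!"]:
        delta_text = full_text[len(previous_text):]
        previous_text = full_text if len(full_text) > len(previous_text) else previous_text
        print(repr(delta_text))  # 'He', then 'llo', then '!'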
def parse_single_output(
    output: Any, is_sse: bool, covert_to_str: bool = False
) -> ModelOutput:
    """Parse the single output.

    Args:
        output (Any): The output to parse.
        is_sse (bool): Whether the output is in SSE format.
        covert_to_str (bool, optional): Whether to convert the output to string.
            Defaults to False.

    Returns:
        ModelOutput: The parsed output.
    """
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, Any]] = None
    metrics: Optional[ModelInferenceMetrics] = None

    if output is None:
        error_code = 1
        text = "The output is None!"
    elif isinstance(output, str):
        if is_sse:
            sse_output = parse_sse_data(output)
            if sse_output is None:
                error_code = 1
                text = "The output is not a SSE format"
            else:
                error_code = 0
                text = sse_output
        else:
            error_code = 0
            text = output
    elif isinstance(output, ModelOutput):
        error_code = output.error_code
        text = output.text
        finish_reason = output.finish_reason
        usage = output.usage
        metrics = output.metrics
    elif isinstance(output, CommonLLMHttpResponseBody):
        error_code = output.error_code
        text = output.text
    elif isinstance(output, dict):
        error_code = 0
        text = json.dumps(output, ensure_ascii=False)
    elif covert_to_str:
        error_code = 0
        text = str(output)
    else:
        error_code = 1
        text = f"The output is not a valid format({type(output)})"
    return ModelOutput(
        error_code=error_code,
        text=text,
        finish_reason=finish_reason,
        usage=usage,
        metrics=metrics,
    )
def parse_openai_output(output: Any) -> ModelOutput:
    """Parse the OpenAI output.

    Args:
        output (Any): The output to parse. It must be a stream format.

    Returns:
        ModelOutput: The parsed output.
    """
    text = ""
    if not isinstance(output, str):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )
    if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
        return ModelOutput(error_code=0, text="")
    if not output.startswith("data:"):
        return ModelOutput(
            error_code=1,
            text="The output is not a stream format",
        )

    sse_output = parse_sse_data(output)
    if sse_output is None:
        return ModelOutput(error_code=1, text="The output is not a SSE format")
    json_data = sse_output.strip()
    try:
        dict_data = json.loads(json_data)
    except Exception as e:
        return ModelOutput(
            error_code=1,
            text=f"Invalid JSON data: {json_data}, {e}",
        )
    if "choices" not in dict_data:
        return ModelOutput(
            error_code=1,
            text=dict_data.get("text", "Unknown error"),
        )
    choices = dict_data["choices"]
    finish_reason: Optional[str] = None
    if choices:
        choice = choices[0]
        delta_data = ChatCompletionResponseStreamChoice(**choice)
        if delta_data.delta.content:
            text = delta_data.delta.content
        finish_reason = delta_data.finish_reason
    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
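For reference, a well-formed OpenAI-style SSE chunk and what the parser extracts from it; the payload below is illustrative:

    chunk = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": null}]}'
    out = parse_openai_output(chunk)
    assert out.error_code == 0 and out.text == "Hi"
    assert parse_openai_output("data: [DONE]").text == ""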
def parse_sse_data(output: str) -> Optional[str]:
    r"""Parse the SSE data.

    Just keep the data part.

    Examples:
        .. code-block:: python

            from dbgpt.core.awel.util.chat_util import parse_sse_data

            assert parse_sse_data("data: [DONE]") == "[DONE]"
            assert parse_sse_data("data:[DONE]") == "[DONE]"
            assert parse_sse_data("data: Hello") == "Hello"
            assert parse_sse_data("data: Hello\n") == "Hello"
            assert parse_sse_data("data: Hello\r\n") == "Hello"
            assert parse_sse_data("data: Hi, what's up?") == "Hi, what's up?"

    Args:
        output (str): The output.

    Returns:
        Optional[str]: The parsed data.
    """
    if output.startswith("data:"):
        output = output.strip()
        if output.startswith("data: "):
            output = output[6:]
        else:
            output = output[5:]

        return output
    else:
        return None