mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat: Run AWEL flow in CLI (#1341)
This commit is contained in:
@@ -127,6 +127,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
"""
|
||||
|
||||
streaming_operator: bool = False
|
||||
incremental_output: bool = False
|
||||
output_format: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -147,6 +149,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
from dbgpt.core.awel import DefaultWorkflowRunner
|
||||
|
||||
runner = DefaultWorkflowRunner()
|
||||
if "incremental_output" in kwargs:
|
||||
self.incremental_output = bool(kwargs["incremental_output"])
|
||||
if "output_format" in kwargs:
|
||||
self.output_format = kwargs["output_format"]
|
||||
|
||||
self._runner: WorkflowRunner = runner
|
||||
self._dag_ctx: Optional[DAGContext] = None
|
||||
|
@@ -132,9 +132,11 @@ class IteratorTrigger(Trigger[List[Tuple[Any, Any]]]):
|
||||
task_id = self.node_id
|
||||
|
||||
async def call_stream(call_data: Any):
|
||||
async for out in await end_node.call_stream(call_data):
|
||||
yield out
|
||||
await dag._after_dag_end(end_node.current_event_loop_task_id)
|
||||
try:
|
||||
async for out in await end_node.call_stream(call_data):
|
||||
yield out
|
||||
finally:
|
||||
await dag._after_dag_end(end_node.current_event_loop_task_id)
|
||||
|
||||
async def run_node(call_data: Any) -> Tuple[Any, Any]:
|
||||
async with semaphore:
|
||||
|
@@ -161,6 +161,7 @@ class ModelOutput:
|
||||
error_code: int
|
||||
"""The error code of the model inference. If the model inference is successful,
|
||||
the error code is 0."""
|
||||
incremental: bool = False
|
||||
model_context: Optional[Dict] = None
|
||||
finish_reason: Optional[str] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
@@ -171,6 +172,11 @@ class ModelOutput:
|
||||
"""Convert the model output to dict."""
|
||||
return asdict(self)
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
"""Check if the model inference is successful."""
|
||||
return self.error_code == 0
|
||||
|
||||
|
||||
_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]
|
||||
|
||||
|
@@ -470,6 +470,8 @@ class CommonStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
|
||||
Transform model output to the string output to show in DB-GPT chat flow page.
|
||||
"""
|
||||
|
||||
output_format = "SSE"
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("Common Streaming Output Operator"),
|
||||
name="common_streaming_output_operator",
|
||||
@@ -510,8 +512,8 @@ class CommonStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
|
||||
yield f"data:{error_msg}"
|
||||
return
|
||||
decoded_unicode = model_output.text.replace("\ufffd", "")
|
||||
msg = decoded_unicode.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
# msg = decoded_unicode.replace("\n", "\\n")
|
||||
yield f"data:{decoded_unicode}\n\n"
|
||||
|
||||
|
||||
class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):
|
||||
|
@@ -114,3 +114,11 @@ class ChatCompletionResponse(BaseModel):
|
||||
..., description="Chat completion response choices"
|
||||
)
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response entity."""
|
||||
|
||||
object: str = Field("error", description="Object type")
|
||||
message: str = Field(..., description="Error message")
|
||||
code: int = Field(..., description="Error code")
|
||||
|
Reference in New Issue
Block a user