feat: Run AWEL flow in CLI (#1341)

This commit is contained in:
Fangyin Cheng
2024-03-27 12:50:05 +08:00
committed by GitHub
parent 340a9fbc35
commit 3a7a2cbbb8
42 changed files with 1454 additions and 422 deletions

View File

@@ -127,6 +127,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""
streaming_operator: bool = False
incremental_output: bool = False
output_format: Optional[str] = None
def __init__(
self,
@@ -147,6 +149,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
from dbgpt.core.awel import DefaultWorkflowRunner
runner = DefaultWorkflowRunner()
if "incremental_output" in kwargs:
self.incremental_output = bool(kwargs["incremental_output"])
if "output_format" in kwargs:
self.output_format = kwargs["output_format"]
self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None

View File

@@ -132,9 +132,11 @@ class IteratorTrigger(Trigger[List[Tuple[Any, Any]]]):
task_id = self.node_id
async def call_stream(call_data: Any):
async for out in await end_node.call_stream(call_data):
yield out
await dag._after_dag_end(end_node.current_event_loop_task_id)
try:
async for out in await end_node.call_stream(call_data):
yield out
finally:
await dag._after_dag_end(end_node.current_event_loop_task_id)
async def run_node(call_data: Any) -> Tuple[Any, Any]:
async with semaphore:

View File

@@ -161,6 +161,7 @@ class ModelOutput:
error_code: int
"""The error code of the model inference. If the model inference is successful,
the error code is 0."""
incremental: bool = False
model_context: Optional[Dict] = None
finish_reason: Optional[str] = None
usage: Optional[Dict[str, Any]] = None
@@ -171,6 +172,11 @@ class ModelOutput:
"""Convert the model output to dict."""
return asdict(self)
@property
def success(self) -> bool:
"""Check if the model inference is successful."""
return self.error_code == 0
_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]

View File

@@ -470,6 +470,8 @@ class CommonStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
Transform model output to the string output to show in DB-GPT chat flow page.
"""
output_format = "SSE"
metadata = ViewMetadata(
label=_("Common Streaming Output Operator"),
name="common_streaming_output_operator",
@@ -510,8 +512,8 @@ class CommonStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
yield f"data:{error_msg}"
return
decoded_unicode = model_output.text.replace("\ufffd", "")
msg = decoded_unicode.replace("\n", "\\n")
yield f"data:{msg}\n\n"
# msg = decoded_unicode.replace("\n", "\\n")
yield f"data:{decoded_unicode}\n\n"
class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):

View File

@@ -114,3 +114,11 @@ class ChatCompletionResponse(BaseModel):
..., description="Chat completion response choices"
)
usage: UsageInfo = Field(..., description="Usage info")
class ErrorResponse(BaseModel):
"""Error response entity."""
object: str = Field("error", description="Object type")
message: str = Field(..., description="Error message")
code: int = Field(..., description="Error code")