fix(core): Fix AWEL branch bug (#1640)

This commit is contained in:
Fangyin Cheng
2024-06-18 11:11:43 +08:00
committed by GitHub
parent 49b56b4576
commit ace169ac46
32 changed files with 870 additions and 481 deletions

View File

@@ -2,12 +2,13 @@
import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchJoinOperator,
BranchOperator,
CommonLLMHttpRequestBody,
CommonLLMHttpResponseBody,
@@ -340,24 +341,7 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
category=OperatorCategory.LLM,
operator_type=OperatorType.BRANCH,
description=_("Branch the workflow based on the stream flag of the request."),
parameters=[
Parameter.build_from(
_("Streaming Task Name"),
"stream_task_name",
str,
optional=True,
default="streaming_llm_task",
description=_("The name of the streaming task."),
),
Parameter.build_from(
_("Non-Streaming Task Name"),
"no_stream_task_name",
str,
optional=True,
default="llm_task",
description=_("The name of the non-streaming task."),
),
],
parameters=[],
inputs=[
IOField.build_from(
_("Model Request"),
@@ -382,7 +366,12 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
],
)
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
def __init__(
self,
stream_task_name: Optional[str] = None,
no_stream_task_name: Optional[str] = None,
**kwargs,
):
"""Create a new LLM branch operator.
Args:
@@ -390,18 +379,13 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
no_stream_task_name (str): The name of the non-streaming task.
"""
super().__init__(**kwargs)
if not stream_task_name:
raise ValueError("stream_task_name is not set")
if not no_stream_task_name:
raise ValueError("no_stream_task_name is not set")
self._stream_task_name = stream_task_name
self._no_stream_task_name = no_stream_task_name
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""
Return a dict of branch function and task name.
"""Return a dict of branch function and task name.
Returns:
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
@@ -409,6 +393,18 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
If the predicate function returns True, we will run the corresponding
task.
"""
if self._stream_task_name and self._no_stream_task_name:
stream_task_name = self._stream_task_name
no_stream_task_name = self._no_stream_task_name
else:
stream_task_name = ""
no_stream_task_name = ""
for node in self.downstream:
task = cast(BaseOperator, node)
if task.streaming_operator:
stream_task_name = node.node_name
else:
no_stream_task_name = node.node_name
async def check_stream_true(r: ModelRequest) -> bool:
# If stream is true, we will run the streaming task. otherwise, we will run
@@ -416,8 +412,8 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
return r.stream
return {
check_stream_true: self._stream_task_name,
lambda x: not x.stream: self._no_stream_task_name,
check_stream_true: stream_task_name,
lambda x: not x.stream: no_stream_task_name,
}
@@ -553,3 +549,93 @@ class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):
text=input_value,
error_code=500,
)
class LLMBranchJoinOperator(BranchJoinOperator[ModelOutput]):
"""The LLM Branch Join Operator.
Decide which output to keep(streaming or non-streaming).
"""
streaming_operator = True
metadata = ViewMetadata(
label=_("LLM Branch Join Operator"),
name="llm_branch_join_operator",
category=OperatorCategory.LLM,
operator_type=OperatorType.JOIN,
description=_("Just keep the first non-empty output."),
parameters=[],
inputs=[
IOField.build_from(
_("Streaming Model Output"),
"stream_output",
ModelOutput,
is_list=True,
description=_("The streaming output."),
),
IOField.build_from(
_("Non-Streaming Model Output"),
"not_stream_output",
ModelOutput,
description=_("The non-streaming output."),
),
],
outputs=[
IOField.build_from(
_("Model Output"),
"output_value",
ModelOutput,
is_list=True,
description=_("The output value of the operator."),
),
],
)
def __init__(self, **kwargs):
"""Create a new LLM branch join operator."""
super().__init__(**kwargs)
class StringBranchJoinOperator(BranchJoinOperator[str]):
"""The String Branch Join Operator.
Decide which output to keep(streaming or non-streaming).
"""
streaming_operator = True
metadata = ViewMetadata(
label=_("String Branch Join Operator"),
name="string_branch_join_operator",
category=OperatorCategory.COMMON,
operator_type=OperatorType.JOIN,
description=_("Just keep the first non-empty output."),
parameters=[],
inputs=[
IOField.build_from(
_("Streaming String Output"),
"stream_output",
str,
is_list=True,
description=_("The streaming output."),
),
IOField.build_from(
_("Non-Streaming String Output"),
"not_stream_output",
str,
description=_("The non-streaming output."),
),
],
outputs=[
IOField.build_from(
_("String Output"),
"output_value",
str,
is_list=True,
description=_("The output value of the operator."),
),
],
)
def __init__(self, **kwargs):
"""Create a new LLM branch join operator."""
super().__init__(**kwargs)