mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 12:37:14 +00:00
fix(core): Fix AWEL branch bug (#1640)
This commit is contained in:
@@ -2,12 +2,13 @@
|
||||
|
||||
import dataclasses
|
||||
from abc import ABC
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core.awel import (
|
||||
BaseOperator,
|
||||
BranchFunc,
|
||||
BranchJoinOperator,
|
||||
BranchOperator,
|
||||
CommonLLMHttpRequestBody,
|
||||
CommonLLMHttpResponseBody,
|
||||
@@ -340,24 +341,7 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
category=OperatorCategory.LLM,
|
||||
operator_type=OperatorType.BRANCH,
|
||||
description=_("Branch the workflow based on the stream flag of the request."),
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("Streaming Task Name"),
|
||||
"stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="streaming_llm_task",
|
||||
description=_("The name of the streaming task."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Non-Streaming Task Name"),
|
||||
"no_stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="llm_task",
|
||||
description=_("The name of the non-streaming task."),
|
||||
),
|
||||
],
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Model Request"),
|
||||
@@ -382,7 +366,12 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
stream_task_name: Optional[str] = None,
|
||||
no_stream_task_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new LLM branch operator.
|
||||
|
||||
Args:
|
||||
@@ -390,18 +379,13 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
no_stream_task_name (str): The name of the non-streaming task.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if not stream_task_name:
|
||||
raise ValueError("stream_task_name is not set")
|
||||
if not no_stream_task_name:
|
||||
raise ValueError("no_stream_task_name is not set")
|
||||
self._stream_task_name = stream_task_name
|
||||
self._no_stream_task_name = no_stream_task_name
|
||||
|
||||
async def branches(
|
||||
self,
|
||||
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
|
||||
"""
|
||||
Return a dict of branch function and task name.
|
||||
"""Return a dict of branch function and task name.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
|
||||
@@ -409,6 +393,18 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
If the predicate function returns True, we will run the corresponding
|
||||
task.
|
||||
"""
|
||||
if self._stream_task_name and self._no_stream_task_name:
|
||||
stream_task_name = self._stream_task_name
|
||||
no_stream_task_name = self._no_stream_task_name
|
||||
else:
|
||||
stream_task_name = ""
|
||||
no_stream_task_name = ""
|
||||
for node in self.downstream:
|
||||
task = cast(BaseOperator, node)
|
||||
if task.streaming_operator:
|
||||
stream_task_name = node.node_name
|
||||
else:
|
||||
no_stream_task_name = node.node_name
|
||||
|
||||
async def check_stream_true(r: ModelRequest) -> bool:
|
||||
# If stream is true, we will run the streaming task. otherwise, we will run
|
||||
@@ -416,8 +412,8 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
return r.stream
|
||||
|
||||
return {
|
||||
check_stream_true: self._stream_task_name,
|
||||
lambda x: not x.stream: self._no_stream_task_name,
|
||||
check_stream_true: stream_task_name,
|
||||
lambda x: not x.stream: no_stream_task_name,
|
||||
}
|
||||
|
||||
|
||||
@@ -553,3 +549,93 @@ class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):
|
||||
text=input_value,
|
||||
error_code=500,
|
||||
)
|
||||
|
||||
|
||||
class LLMBranchJoinOperator(BranchJoinOperator[ModelOutput]):
|
||||
"""The LLM Branch Join Operator.
|
||||
|
||||
Decide which output to keep(streaming or non-streaming).
|
||||
"""
|
||||
|
||||
streaming_operator = True
|
||||
metadata = ViewMetadata(
|
||||
label=_("LLM Branch Join Operator"),
|
||||
name="llm_branch_join_operator",
|
||||
category=OperatorCategory.LLM,
|
||||
operator_type=OperatorType.JOIN,
|
||||
description=_("Just keep the first non-empty output."),
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Streaming Model Output"),
|
||||
"stream_output",
|
||||
ModelOutput,
|
||||
is_list=True,
|
||||
description=_("The streaming output."),
|
||||
),
|
||||
IOField.build_from(
|
||||
_("Non-Streaming Model Output"),
|
||||
"not_stream_output",
|
||||
ModelOutput,
|
||||
description=_("The non-streaming output."),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
_("Model Output"),
|
||||
"output_value",
|
||||
ModelOutput,
|
||||
is_list=True,
|
||||
description=_("The output value of the operator."),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new LLM branch join operator."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class StringBranchJoinOperator(BranchJoinOperator[str]):
|
||||
"""The String Branch Join Operator.
|
||||
|
||||
Decide which output to keep(streaming or non-streaming).
|
||||
"""
|
||||
|
||||
streaming_operator = True
|
||||
metadata = ViewMetadata(
|
||||
label=_("String Branch Join Operator"),
|
||||
name="string_branch_join_operator",
|
||||
category=OperatorCategory.COMMON,
|
||||
operator_type=OperatorType.JOIN,
|
||||
description=_("Just keep the first non-empty output."),
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Streaming String Output"),
|
||||
"stream_output",
|
||||
str,
|
||||
is_list=True,
|
||||
description=_("The streaming output."),
|
||||
),
|
||||
IOField.build_from(
|
||||
_("Non-Streaming String Output"),
|
||||
"not_stream_output",
|
||||
str,
|
||||
description=_("The non-streaming output."),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
_("String Output"),
|
||||
"output_value",
|
||||
str,
|
||||
is_list=True,
|
||||
description=_("The output value of the operator."),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new LLM branch join operator."""
|
||||
super().__init__(**kwargs)
|
||||
|
Reference in New Issue
Block a user