mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
feat(core): Support more chat flows (#1180)
This commit is contained in:
@@ -632,16 +632,36 @@ class BaseMetadata(BaseResource):
|
||||
runnable_parameters: Dict[str, Any] = {}
|
||||
if not self.parameters or not view_parameters:
|
||||
return runnable_parameters
|
||||
if len(self.parameters) != len(view_parameters):
|
||||
view_required_parameters = {
|
||||
parameter.name: parameter
|
||||
for parameter in view_parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_required_parameters = {
|
||||
parameter.name: parameter
|
||||
for parameter in self.parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_parameters = {
|
||||
parameter.name: parameter for parameter in self.parameters
|
||||
}
|
||||
if len(view_required_parameters) < len(current_required_parameters):
|
||||
# TODO, skip the optional parameters.
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameters count not match. Expected {len(self.parameters)}, "
|
||||
f"Parameters count not match(current key: {self.id}). "
|
||||
f"Expected {len(self.parameters)}, "
|
||||
f"but got {len(view_parameters)} from JSON metadata."
|
||||
f"Required parameters: {current_required_parameters.keys()}, "
|
||||
f"but got {view_required_parameters.keys()}."
|
||||
)
|
||||
for i, parameter in enumerate(self.parameters):
|
||||
view_param = view_parameters[i]
|
||||
for view_param in view_parameters:
|
||||
view_param_key = view_param.name
|
||||
if view_param_key not in current_parameters:
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameter {view_param_key} not found in the metadata."
|
||||
)
|
||||
runnable_parameters.update(
|
||||
parameter.to_runnable_parameter(
|
||||
current_parameters[view_param_key].to_runnable_parameter(
|
||||
view_param.get_typed_value(), resources, key_to_resource_instance
|
||||
)
|
||||
)
|
||||
|
@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
This class extends DAGNode by adding execution capabilities.
|
||||
"""
|
||||
|
||||
streaming_operator: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
|
@@ -10,6 +10,8 @@ from .base import BaseOperator
|
||||
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
|
||||
|
||||
streaming_operator = True
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
call_data = curr_task_ctx.call_data
|
||||
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
AsyncIterator[IN] to another AsyncIterator[OUT].
|
||||
"""
|
||||
|
||||
streaming_operator = True
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
|
||||
|
Reference in New Issue
Block a user