feat(core): Support more chat flows (#1180)

This commit is contained in:
Fangyin Cheng
2024-02-22 12:19:04 +08:00
committed by GitHub
parent 16fa68d4f2
commit ab5e1c7ea1
10 changed files with 175 additions and 55 deletions

View File

@@ -632,16 +632,36 @@ class BaseMetadata(BaseResource):
runnable_parameters: Dict[str, Any] = {}
if not self.parameters or not view_parameters:
return runnable_parameters
if len(self.parameters) != len(view_parameters):
view_required_parameters = {
parameter.name: parameter
for parameter in view_parameters
if not parameter.optional
}
current_required_parameters = {
parameter.name: parameter
for parameter in self.parameters
if not parameter.optional
}
current_parameters = {
parameter.name: parameter for parameter in self.parameters
}
if len(view_required_parameters) < len(current_required_parameters):
# TODO, skip the optional parameters.
raise FlowParameterMetadataException(
f"Parameters count not match. Expected {len(self.parameters)}, "
f"Parameters count not match(current key: {self.id}). "
f"Expected {len(self.parameters)}, "
f"but got {len(view_parameters)} from JSON metadata."
f"Required parameters: {current_required_parameters.keys()}, "
f"but got {view_required_parameters.keys()}."
)
for i, parameter in enumerate(self.parameters):
view_param = view_parameters[i]
for view_param in view_parameters:
view_param_key = view_param.name
if view_param_key not in current_parameters:
raise FlowParameterMetadataException(
f"Parameter {view_param_key} not found in the metadata."
)
runnable_parameters.update(
parameter.to_runnable_parameter(
current_parameters[view_param_key].to_runnable_parameter(
view_param.get_typed_value(), resources, key_to_resource_instance
)
)

View File

@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
This class extends DAGNode by adding execution capabilities.
"""
streaming_operator: bool = False
def __init__(
self,
task_id: Optional[str] = None,

View File

@@ -10,6 +10,8 @@ from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
streaming_operator = True
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
call_data = curr_task_ctx.call_data
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
AsyncIterator[IN] to another AsyncIterator[OUT].
"""
streaming_operator = True
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[

View File

@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the messages."""
if "system_message" not in values:
raise ValueError("No system message")
values["system_message"] = "You are a helpful AI Assistant."
if "human_message" not in values:
raise ValueError("No human message")
values["human_message"] = "{user_input}"
if "message_placeholder" not in values:
raise ValueError("No message placeholder")
values["message_placeholder"] = "chat_history"
system_message = values.pop("system_message")
human_message = values.pop("human_message")
message_placeholder = values.pop("message_placeholder")