refactor: Refactor proxy LLM (#1064)

This commit is contained in:
Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions

View File

@@ -384,7 +384,7 @@ class DAGContext:
return self._share_data.get(key)
async def save_to_share_data(
- self, key: str, data: Any, overwrite: Optional[str] = None
+ self, key: str, data: Any, overwrite: bool = False
) -> None:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
@@ -407,7 +407,7 @@ class DAGContext:
return self.get_from_share_data(_build_task_key(task_name, key))
async def save_task_share_data(
- self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
+ self, task_name: str, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by task name and key
@@ -415,7 +415,7 @@ class DAGContext:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
- overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists.
+ overwrite (bool): Whether overwrite the share data if the key already exists.
Defaults to None.
Raises:
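
The two hunks above tighten the `overwrite` parameter from a mistyped `Optional[str] = None` to a plain `bool = False`. A minimal usage sketch under the new signature (the key names and the wrapping coroutine are illustrative, not part of this commit):

async def store_results(dag_ctx: DAGContext) -> None:
    # First write succeeds; writing the same key again without overwrite=True
    # raises ValueError("Share data key ... already exists").
    await dag_ctx.save_to_share_data("user_input", "hello")
    await dag_ctx.save_to_share_data("user_input", "hello again", overwrite=True)
    # Task-scoped variant: the key is namespaced by the task name internally.
    await dag_ctx.save_task_share_data("parse_task", "tokens", ["hello"], overwrite=True)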

View File

@@ -46,7 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
node: "BaseOperator",
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
- dag_ctx: Optional[DAGContext] = None,
+ exist_dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
"""Execute the workflow starting from a given operator.
@@ -54,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
node (RunnableDAGNode): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data pass to root operator node.
streaming_call (bool): Whether the call is a streaming call.
- dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
+ exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
@@ -190,7 +190,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Returns:
OUT: The output of the node after execution.
"""
- out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
+ out_ctx = await self._runner.execute_workflow(
+     self, call_data, exist_dag_ctx=dag_ctx
+ )
return out_ctx.current_task_context.task_output.output
def _blocking_call(
@@ -230,7 +232,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(
- self, call_data, streaming_call=True, dag_ctx=dag_ctx
+ self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
)
return out_ctx.current_task_context.task_output.output_stream
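
These hunks rename the runner-level parameter from `dag_ctx` to `exist_dag_ctx`, so the pre-existing context handed to `execute_workflow` no longer shares a name with the context the runner builds and returns. A hedged sketch of the call after the rename (`runner`, `leaf_node`, and `parent_ctx` are placeholder names):

out_ctx = await runner.execute_workflow(
    leaf_node, call_data=None, streaming_call=False, exist_dag_ctx=parent_ctx
)
result = out_ctx.current_task_context.task_output.output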

View File

@@ -9,6 +9,12 @@ from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+ call_data = curr_task_ctx.call_data
+ if call_data:
+     call_data = await curr_task_ctx._call_data_to_output()
+     output = await call_data.streamify(self.streamify)
+     curr_task_ctx.set_task_output(output)
+     return output
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
self.streamify
)
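
The added block lets a streaming operator run as a root node: when `call_data` is present it is converted to a task output via `_call_data_to_output()` and fed straight into `streamify`, instead of reading parent outputs that a root node does not have. A hedged sketch of an operator relying on this (class name, import path, and the `call_stream` invocation are assumptions, not taken from this commit):

from typing import AsyncIterator

from dbgpt.core.awel import StreamifyAbsOperator  # assumed import path

class NumberStreamOperator(StreamifyAbsOperator[int, int]):
    async def streamify(self, n: int) -> AsyncIterator[int]:
        # Emit 0..n-1 as a stream.
        for i in range(n):
            yield i

# With the change above, triggering the operator directly with call data should work:
# stream = await NumberStreamOperator().call_stream(call_data=5)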

View File

@@ -76,12 +76,12 @@ def _save_call_data(
return id2call_data
if len(root_nodes) == 1:
node = root_nodes[0]
logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
logger.debug(f"Save call data to node {node.node_id}, call_data: {call_data}")
id2call_data[node.node_id] = call_data
else:
for node in root_nodes:
node_id = node.node_id
- logger.info(
+ logger.debug(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
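
This hunk only lowers the per-node call_data messages from INFO to DEBUG to reduce log noise. If those messages are still wanted, the relevant logger can be raised to DEBUG; a sketch, assuming the standard `logging` hierarchy and that the module sits under `dbgpt.core.awel`:

import logging

logging.getLogger("dbgpt.core.awel").setLevel(logging.DEBUG)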

View File

@@ -19,24 +19,24 @@ class DefaultWorkflowRunner(WorkflowRunner):
node: BaseOperator,
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
- dag_ctx: Optional[DAGContext] = None,
+ exist_dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
# Save node output
# dag = node.dag
job_manager = JobManager.build_from_end_node(node, call_data)
- if not dag_ctx:
+ if not exist_dag_ctx:
# Create DAG context
node_outputs: Dict[str, TaskContext] = {}
- dag_ctx = DAGContext(
-     streaming_call=streaming_call,
-     node_to_outputs=node_outputs,
-     node_name_to_ids=job_manager._node_name_to_ids,
- )
else:
- node_outputs = dag_ctx._node_to_outputs
- logger.info(
-     f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
- )
+     # Share node output with exist dag context
+     node_outputs = exist_dag_ctx._node_to_outputs
+ dag_ctx = DAGContext(
+     streaming_call=streaming_call,
+     node_to_outputs=node_outputs,
+     node_name_to_ids=job_manager._node_name_to_ids,
+ )
+ logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
+ logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
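
Previously, a passed-in context was reused as-is; with this change the runner always builds a fresh `DAGContext` per run but, when `exist_dag_ctx` is given, it aliases that context's `_node_to_outputs` mapping, so a nested run and its parent observe the same task outputs. A plain-Python illustration of the aliasing (names are illustrative):

parent_outputs = {}                  # stands in for exist_dag_ctx._node_to_outputs
child_outputs = parent_outputs       # what the runner does when exist_dag_ctx is given
child_outputs["node-123"] = "task context written by the nested run"
assert "node-123" in parent_outputs  # the parent context sees the nested run's output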