mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 02:25:08 +00:00
feat(core): Support i18n (#1327)
This commit is contained in:
@@ -228,7 +228,7 @@ class DAGLifecycle:
|
||||
"""Execute before DAG run."""
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
async def after_dag_end(self, event_loop_task_id: int):
|
||||
"""Execute after DAG end.
|
||||
|
||||
This method may be called multiple times, please make sure it is idempotent.
|
||||
@@ -464,6 +464,7 @@ class DAGContext:
|
||||
self,
|
||||
node_to_outputs: Dict[str, TaskContext],
|
||||
share_data: Dict[str, Any],
|
||||
event_loop_task_id: int,
|
||||
streaming_call: bool = False,
|
||||
node_name_to_ids: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
@@ -483,6 +484,7 @@ class DAGContext:
|
||||
self._share_data: Dict[str, Any] = share_data
|
||||
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
|
||||
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
|
||||
self._event_loop_task_id = event_loop_task_id
|
||||
|
||||
@property
|
||||
def _task_outputs(self) -> Dict[str, TaskContext]:
|
||||
@@ -600,6 +602,9 @@ class DAGContext:
|
||||
raise ValueError("key can't be None")
|
||||
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
|
||||
|
||||
async def _clean_all(self):
|
||||
pass
|
||||
|
||||
|
||||
class DAG:
|
||||
"""The DAG class.
|
||||
@@ -618,6 +623,8 @@ class DAG:
|
||||
self._leaf_nodes: List[DAGNode] = []
|
||||
self._trigger_nodes: List[DAGNode] = []
|
||||
self._resource_group: Optional[ResourceGroup] = resource_group
|
||||
self._lock = asyncio.Lock()
|
||||
self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {}
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
if node.node_id in self.node_map:
|
||||
@@ -689,13 +696,41 @@ class DAG:
|
||||
self._build()
|
||||
return self._trigger_nodes
|
||||
|
||||
async def _after_dag_end(self) -> None:
|
||||
async def _save_dag_ctx(self, dag_ctx: DAGContext) -> None:
|
||||
async with self._lock:
|
||||
event_loop_task_id = dag_ctx._event_loop_task_id
|
||||
current_task = asyncio.current_task()
|
||||
task_name = current_task.get_name() if current_task else None
|
||||
self._event_loop_task_id_to_ctx[event_loop_task_id] = dag_ctx
|
||||
logger.debug(
|
||||
f"Save DAG context {dag_ctx} to event loop task {event_loop_task_id}, "
|
||||
f"task_name: {task_name}"
|
||||
)
|
||||
|
||||
async def _after_dag_end(self, event_loop_task_id: Optional[int] = None) -> None:
|
||||
"""Execute after DAG end."""
|
||||
tasks = []
|
||||
event_loop_task_id = event_loop_task_id or id(asyncio.current_task())
|
||||
for node in self.node_map.values():
|
||||
tasks.append(node.after_dag_end())
|
||||
tasks.append(node.after_dag_end(event_loop_task_id))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Clear the DAG context
|
||||
async with self._lock:
|
||||
current_task = asyncio.current_task()
|
||||
task_name = current_task.get_name() if current_task else None
|
||||
if event_loop_task_id not in self._event_loop_task_id_to_ctx:
|
||||
raise RuntimeError(
|
||||
f"DAG context not found with event loop task id "
|
||||
f"{event_loop_task_id}, task_name: {task_name}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Clean DAG context with event loop task id {event_loop_task_id}, "
|
||||
f"task_name: {task_name}"
|
||||
)
|
||||
dag_ctx = self._event_loop_task_id_to_ctx.pop(event_loop_task_id)
|
||||
await dag_ctx._clean_all()
|
||||
|
||||
def print_tree(self) -> None:
|
||||
"""Print the DAG tree""" # noqa: D400
|
||||
_print_format_dag_tree(self)
|
||||
|
Reference in New Issue
Block a user