mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[fx]get communication size between partitions (#1224)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [fx]get communication size between partitions.
* polish
This commit is contained in:
27
colossalai/fx/passes/utils.py
Normal file
27
colossalai/fx/passes/utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
from typing import Dict, Set
|
||||
from torch.fx.node import Node, map_arg
|
||||
|
||||
|
||||
def get_comm_size(prev_partition, next_partition):
    """Calculate the communication size between two partitions.

    Every node of ``next_partition`` (the child) is inspected. Each distinct
    input node whose name also appears among the nodes of ``prev_partition``
    (the parent) contributes ``meta['tensor_meta'].numel`` exactly once,
    no matter how many child nodes consume it.

    Args:
        prev_partition: The parent partition; must expose ``graph.nodes``.
        next_partition: The child partition; must expose ``graph.nodes``.

    Returns:
        The sum of ``numel`` over all counted input nodes; ``0`` when the
        child consumes nothing that is named in the parent.
    """
    # Running total of the communication size between parent and child.
    comm_size = 0
    # Input nodes that were already counted, so each is added only once.
    visited_nodes: Set[Node] = set()
    # Names are collected into a set so the membership test inside the
    # nested loop below is O(1) instead of a list scan.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}
    for node in next_partition.graph.nodes:
        # A dict is used as an insertion-ordered set of this node's inputs;
        # map_arg only invokes the callback on fx Nodes found in args/kwargs.
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                # NOTE(review): assumes a shape-propagation pass stored a
                # 'tensor_meta' object with a ``numel`` attribute on the
                # node beforehand -- confirm against the caller.
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size
|
Reference in New Issue
Block a user