[fx]get communication size between partitions (#1224)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx]get communication size between partitions.

* polish
This commit is contained in:
YuliangLiu0306
2022-07-07 16:22:00 +08:00
committed by GitHub
parent 4951f7d80c
commit 2b7dca44b5
4 changed files with 209 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition):
    """Return the communication size between a parent and a child partition.

    Every node of ``next_partition`` (the child) is inspected; each of its
    inputs whose name also appears among the nodes of ``prev_partition``
    (the parent) contributes ``meta['tensor_meta'].numel`` to the total.
    An input consumed by several child nodes is only counted once.

    Args:
        prev_partition: The parent partition. Must expose ``graph.nodes``.
        next_partition: The child partition. Must expose ``graph.nodes``.

    Returns:
        The sum of ``numel`` over all distinct child inputs that match a
        parent node by name; ``0`` when there is no such input.

    Raises:
        KeyError: If a matching input node has no ``'tensor_meta'`` entry in
            its ``meta`` dict.
    """
    # NOTE(review): ``meta['tensor_meta']`` is presumably filled in by a
    # shape/meta propagation pass run before this function -- confirm at the
    # call site.

    # Running total of the communication size between parent and child.
    comm_size = 0
    # Input nodes that were already counted, so shared inputs are not
    # added twice.
    visited_nodes = set()
    # Nodes are matched across the two partitions by name. A set keeps the
    # membership test O(1) instead of scanning a list for every input.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}
    for node in next_partition.graph.nodes:
        # A dict is used as an insertion-ordered set of this node's inputs.
        input_nodes: Dict[Node, None] = {}
        # map_arg walks the (possibly nested) args/kwargs and calls the
        # lambda on every fx Node it finds.
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size