[hotfix] fx get comm size bugs (#1233)

* init a checkpoint dir

* [checkpoint]support resume for cosinewarmuplr

* [checkpoint]add unit test

* fix some bugs but still not OK

* fix bugs

* make it faster

* [checkpoint]support generalized scheduler

* polish

* [tensor] torch function return colotensor

* polish

* fix bugs

* remove debug info

* polish

* polish

* [tensor] test_model pass unittests

* polish

* [hotfix] fx get comm size bug

Co-authored-by: ZhaoYi1222 <zhaoyi9499@gmail.com>
Jiarui Fang
2022-07-08 10:54:41 +08:00
committed by GitHub
parent 42ab36b762
commit 0e199d71e8
2 changed files with 6 additions and 8 deletions

@@ -15,8 +15,8 @@ def get_comm_size(prev_partition, next_partition):
     # If a node has input nodes from the parent partition,
     # the output size of those input nodes will be counted
     # and added to comm_size
-    parent_node_names = [n.name for n in parent_partition.graph.nodes]
-    for node in child_partition.graph.nodes:
+    parent_node_names = [n.name for n in prev_partition.graph.nodes]
+    for node in next_partition.graph.nodes:
         input_nodes: Dict[Node, None] = {}
         map_arg(node.args, lambda n: input_nodes.setdefault(n))
         map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
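
The counting rule described in the comments of this hunk can be illustrated with a minimal, dependency-free sketch. It is not the repository's implementation: `ToyNode`, its `output_size` and `inputs` fields, and `toy_comm_size` are hypothetical stand-ins for the torch.fx graph nodes and for `get_comm_size`, and counting each parent node only once is an assumption, since the hunk above is cut off before the accumulation step.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for a graph node (not the torch.fx Node class)."""
    name: str
    output_size: int = 0                  # assumed size of this node's output
    inputs: Tuple["ToyNode", ...] = ()    # nodes whose outputs this node consumes


def toy_comm_size(prev_nodes: List[ToyNode], next_nodes: List[ToyNode]) -> int:
    """Sum the output sizes of prev-partition nodes that feed the next partition."""
    prev_node_names = {n.name for n in prev_nodes}
    visited = set()    # assumption: each producing node is counted only once
    comm_size = 0
    for node in next_nodes:
        # A dict is used as an ordered set, mirroring the setdefault idiom in the diff.
        input_nodes: Dict[ToyNode, None] = {}
        for inp in node.inputs:
            input_nodes.setdefault(inp)
        for inp in input_nodes:
            if inp.name in prev_node_names and inp.name not in visited:
                comm_size += inp.output_size
                visited.add(inp.name)
    return comm_size


# Two nodes in the previous partition, two in the next one.
a = ToyNode("a", output_size=4)
b = ToyNode("b", output_size=8)
c = ToyNode("c", output_size=2, inputs=(a, b))
d = ToyNode("d", output_size=1, inputs=(a, c))

total = toy_comm_size([a, b], [c, d])      # a and b cross the boundary
reversed_total = toy_comm_size([c, d], [a, b])  # nothing flows backwards here
```

In this toy graph, `c` also feeds `d`, but both live in the next partition, so that edge adds nothing. The sketch also shows why the argument names matter: the function must iterate over the partitions it was actually given (`prev_partition`, `next_partition`), and swapping them changes the result.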