[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
# the comm size for all gather is the size of the gathered tensor
gather_dim = comm_spec.gather_dim
all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
-all_gather_size = device_mesh.mesh_shape[all_gather_axis]
+all_gather_size = device_mesh.shape[all_gather_axis]
comm_size_for_all_gather = comm_size * all_gather_size
forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
# give a tiny cost to shard