mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
This commit is contained in:
@@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
|
||||
# the comm size for all gather is the size of the gathered tensor
|
||||
gather_dim = comm_spec.gather_dim
|
||||
all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
|
||||
- all_gather_size = device_mesh.mesh_shape[all_gather_axis]
|
||||
+ all_gather_size = device_mesh.shape[all_gather_axis]
|
||||
comm_size_for_all_gather = comm_size * all_gather_size
|
||||
forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
|
||||
# give a tiny cost to shard
|
||||
|
Reference in New Issue
Block a user