[hotfix] add shard dim to aviod backward communication error (#2954)

This commit is contained in:
YuliangLiu0306
2023-03-01 11:41:53 +08:00
committed by GitHub
parent 090f14fd6b
commit 47fb214b3b
2 changed files with 3 additions and 0 deletions

View File

@@ -343,6 +343,7 @@ class DefaultReshapeGenerator(ReshapeGenerator):
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]