[example] gpt, shard init on all processes (#2366)

This commit is contained in:
Jiarui Fang
2023-01-06 15:44:50 +08:00
committed by GitHub
parent 1f8ab6f1f5
commit 1aaeb596c6
2 changed files with 18 additions and 12 deletions

View File

@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases are limited.
Only existing pg is DP and dist spec is REPLICaTE is valid.
It works only when the target pg is DP or TP and the current dist spec of the Tensor is Replica.
Args:
pg (ProcessGroup): target pg
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
# if the new pg is the same as the old pg, just returns
if self.process_group == pg:
return
assert self.process_group.tp_world_size() == 1, \
"Can not set_process_group on a ColoTensor whose process_group has tp world group"
assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
"Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
assert self.dist_spec.placement.value == 'r', \
"Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
"Can not set_process_group on a ColoTensor whose dist spec is not Replica"
self.process_group = pg