mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[example] gpt, shard init on all processes (#2366)
This commit is contained in:
@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
|
||||
def set_process_group(self, pg: ProcessGroup):
|
||||
"""set_process_group
|
||||
change the pg of the ColoTensor. Note that the valid use cases are limited.
|
||||
It is only valid when the existing pg is DP and the dist spec is REPLICATE.
|
||||
It works only when the target pg is DP or TP and the current dist spec of the Tensor is REPLICATE.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): target pg
|
||||
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
|
||||
# if the new pg is the same as the old pg, just returns
|
||||
if self.process_group == pg:
|
||||
return
|
||||
assert self.process_group.tp_world_size() == 1, \
|
||||
"Can not set_process_group on a ColoTensor whose process_group has tp world group"
|
||||
assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
|
||||
"Can not set_process_group on a ColoTensor whose process_group has both tp world size > 1 and dp world size > 1"
|
||||
assert self.dist_spec.placement.value == 'r', \
|
||||
"Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
|
||||
"Can not set_process_group on a ColoTensor whose dist spec is not Replica"
|
||||
|
||||
self.process_group = pg
|
||||
|
||||
|
Reference in New Issue
Block a user