mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[hotfix] fix unit test test_module_spec (#1321)
This commit is contained in:
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
|
||||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
}
|
||||
|
||||
|
||||
@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
|
||||
RuntimeError:
|
||||
"""
|
||||
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
|
||||
if self.process_group.tp_world_size() != 1:
|
||||
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
|
||||
|
||||
if self.dist_spec.placement.value != 'r':
|
||||
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
|
||||
# if the new pg is the same as the old pg, just returns
|
||||
if self.process_group == pg:
|
||||
return
|
||||
assert self.process_group.tp_world_size() == 1, \
|
||||
"Can not set_process_group on a ColoTensor whose process_group has tp world group"
|
||||
assert self.dist_spec.placement.value == 'r', \
|
||||
"Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
|
||||
|
||||
self.process_group = pg
|
||||
|
||||
@@ -290,17 +292,17 @@ class ColoTensor(torch.Tensor):
|
||||
|
||||
def is_replicate(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.process_group.tp_world_size() == 1)
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.process_group.tp_world_size() == 1)
|
||||
|
||||
def is_shard_1dcol(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
|
||||
def is_shard_1drow(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def is_sharded(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
||||
|
Reference in New Issue
Block a user