mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[hotfix] fix unit test test_module_spec (#1321)
This commit is contained in:
@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
|
||||
compute_pattern = compute_spec.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ComputeSpec
|
||||
# set its process_group, dist_spec and compute_spec
|
||||
colo_module = get_colo_module(module)
|
||||
colo_module.register(compute_pattern, pg)
|
||||
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
|
||||
continue
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
param.set_process_group(pg)
|
||||
param.set_dist_spec(dist_spec)
|
||||
param.compute_spec = compute_spec
|
||||
for mod in param.shared_param_modules:
|
||||
|
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
|
||||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
}
|
||||
|
||||
|
||||
@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
|
||||
RuntimeError:
|
||||
"""
|
||||
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
|
||||
if self.process_group.tp_world_size() != 1:
|
||||
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
|
||||
|
||||
if self.dist_spec.placement.value != 'r':
|
||||
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
|
||||
# if the new pg is the same as the old pg, just returns
|
||||
if self.process_group == pg:
|
||||
return
|
||||
assert self.process_group.tp_world_size() == 1, \
|
||||
"Can not set_process_group on a ColoTensor whose process_group has tp world group"
|
||||
assert self.dist_spec.placement.value == 'r', \
|
||||
"Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
|
||||
|
||||
self.process_group = pg
|
||||
|
||||
@@ -290,17 +292,17 @@ class ColoTensor(torch.Tensor):
|
||||
|
||||
def is_replicate(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.process_group.tp_world_size() == 1)
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.process_group.tp_world_size() == 1)
|
||||
|
||||
def is_shard_1dcol(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
|
||||
def is_shard_1drow(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def is_sharded(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
||||
|
Reference in New Issue
Block a user