[hotfix] fix unit test test_module_spec (#1321)

This commit is contained in:
HELSON
2022-07-15 14:02:32 +08:00
committed by GitHub
parent 9e4c6449b0
commit 1b41686461
3 changed files with 29 additions and 22 deletions

View File

@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
compute_pattern = compute_spec.compute_pattern
if is_colo_module(module):
# for each param
# set DistSpec and ComputeSpec
# set its process_group, dist_spec and compute_spec
colo_module = get_colo_module(module)
colo_module.register(compute_pattern, pg)
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
continue
param = module.get_parameter(param_name)
if isinstance(param, ColoParameter):
param.set_process_group(pg)
param.set_dist_spec(dist_spec)
param.compute_spec = compute_spec
for mod in param.shared_param_modules: