mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[hotfix] fix unit test test_module_spec (#1321)
This commit is contained in:
@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
|
||||
compute_pattern = compute_spec.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ComputeSpec
|
||||
# set its process_group, dist_spec and compute_spec
|
||||
colo_module = get_colo_module(module)
|
||||
colo_module.register(compute_pattern, pg)
|
||||
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
|
||||
continue
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
param.set_process_group(pg)
|
||||
param.set_dist_spec(dist_spec)
|
||||
param.compute_spec = compute_spec
|
||||
for mod in param.shared_param_modules:
|
||||
|
Reference in New Issue
Block a user