mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 02:20:49 +00:00
[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)
This commit is contained in:
@@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
||||
param = module.get_parameter(param_name)
|
||||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_spec():
|
||||
cur_compute_pattern = param.spec.compute_spec.compute_pattern
|
||||
if param.has_compute_spec():
|
||||
cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
@@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
||||
cur_match = True
|
||||
for param_name, dist_spec in param_specs.items():
|
||||
param = module.get_parameter(param_name)
|
||||
if param.has_spec():
|
||||
if dist_spec != param.spec.dist_spec:
|
||||
if param.has_compute_spec():
|
||||
if dist_spec != param.tensor_spec.dist_spec:
|
||||
cur_match = False
|
||||
break
|
||||
else:
|
||||
@@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
spec = TensorSpec(dist_spec, parallel_action)
|
||||
param.set_spec(spec)
|
||||
param.set_tensor_spec(spec)
|
||||
for mod in param.shared_param_modules:
|
||||
modules_update_param.add(mod)
|
||||
for mod in modules_update_param:
|
||||
|
Reference in New Issue
Block a user