diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py index 37d8afbec..34513de4c 100644 --- a/colossalai/nn/parallel/layers/module_utils.py +++ b/colossalai/nn/parallel/layers/module_utils.py @@ -79,8 +79,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True): check_colo_module(submodule, recursive=True) -def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'): - compute_pattern = parallel_action.compute_pattern +def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'): + compute_pattern = compute_spec.compute_pattern if is_colo_module(module): # for each param # set DistSpec and ComputeSpec @@ -96,7 +96,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu continue param = module.get_parameter(param_name) if isinstance(param, ColoParameter): - spec = TensorSpec(dist_spec, parallel_action) + spec = TensorSpec(dist_spec, compute_spec) param.set_tensor_spec(spec) for mod in param.shared_param_modules: modules_update_param.add(mod) @@ -104,4 +104,4 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu check_colo_module(mod, recursive=False) if recursive == True: for submodule in module.children(): - init_colo_module(submodule, parallel_action, recursive=True, mode=mode) + init_colo_module(submodule, compute_spec, recursive=True, mode=mode) diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py index 4dc944eca..f847ad62f 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/tensor/tensor_spec.py @@ -9,7 +9,7 @@ class TensorSpec(object): The specification of the ColoTensor. Args: dist_spec (_DistSpec): descriping the layout among processes. - parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor. + compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor. Defaults to None. """