diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index e7f199129..eb0350053 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -50,7 +50,7 @@ class ModulePolicyDescription:
                 new_weight = shard_rowwise(weight, process_group)
                 module.weight = torch.nn.Parameter(new_weight)
             ```
-        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription
+        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
             object which specifies the module to be replaced and the target module used to replacement.
         method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
     """
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 2a75d7047..2828d5175 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -92,7 +92,7 @@ class BucketStore(BaseStore):
     def get_flatten_grad(self) -> Tensor:
         """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
-        [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
+        [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]

         Returns:
             Tensor: the flattened gradients slices in the bucket
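
For context on the first hunk, a minimal sketch of how a shardformer policy might populate the `sub_module_replacement` field that the corrected docstring describes. The import paths, the `suffix`/`target_module` arguments, and the `attn.c_attn`/`attn.c_proj` names are assumptions for illustration, not part of this patch.

```python
# Hypothetical illustration (not part of this patch): filling
# ModulePolicyDescription.sub_module_replacement with SubModuleReplacementDescription
# entries. Suffix names and target modules below are made up for the example.
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

policy_description = ModulePolicyDescription(
    sub_module_replacement=[
        # Replace the attention input projection with a column-parallel linear layer.
        SubModuleReplacementDescription(
            suffix="attn.c_attn",  # assumed sub-module name
            target_module=Linear1D_Col,
        ),
        # Replace the attention output projection with a row-parallel linear layer.
        SubModuleReplacementDescription(
            suffix="attn.c_proj",  # assumed sub-module name
            target_module=Linear1D_Row,
        ),
    ],
)
```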
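
For the second hunk, a rough sketch of the rank-major layout the corrected docstring describes: all of rank 0's gradient slices come first, then all of rank 1's, and so on. The shapes, values, and chunking below are invented for the example and do not mirror the actual `BucketStore` internals.

```python
# Hypothetical illustration of the documented layout (not part of this patch):
# gradient slices are flattened rank by rank, i.e. every parameter's rank-0
# slice first, then every parameter's rank-1 slice, etc.
import torch

world_size = 2
# Two parameters' gradients, each split evenly across the two ranks.
grad0 = torch.arange(4.0)         # grad0_rank0 = [0, 1],        grad0_rank1 = [2, 3]
grad1 = torch.arange(10.0, 16.0)  # grad1_rank0 = [10, 11, 12],  grad1_rank1 = [13, 14, 15]

slices = [g.chunk(world_size) for g in (grad0, grad1)]

# Rank-major flattening: [grad0_rank0, grad1_rank0, grad0_rank1, grad1_rank1]
flat = torch.cat([s[rank] for rank in range(world_size) for s in slices])
print(flat)  # tensor([ 0.,  1., 10., 11., 12.,  2.,  3., 13., 14., 15.])
```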