diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py index a112f2d95..f2297304f 100644 --- a/colossalai/nn/layer/utils/common.py +++ b/colossalai/nn/layer/utils/common.py @@ -13,7 +13,8 @@ from torch import Tensor, nn class CheckpointModule(nn.Module): - def __init__(self, checkpoint: bool = True, offload : bool = False): + + def __init__(self, checkpoint: bool = True, offload: bool = False): super().__init__() self.checkpoint = checkpoint self._use_checkpoint = checkpoint @@ -78,6 +79,7 @@ def get_tensor_parallel_mode(): def _ntuple(n): + def parse(x): if isinstance(x, collections.abc.Iterable): return x