From 571f12eff39405d067d2002db43717ba823f7faa Mon Sep 17 00:00:00 2001 From: Ziheng Qin <37519855+henryqin1997@users.noreply.github.com> Date: Tue, 17 May 2022 08:01:06 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/layer/utils/common.py code style (#983) --- colossalai/nn/layer/utils/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py index a112f2d95..f2297304f 100644 --- a/colossalai/nn/layer/utils/common.py +++ b/colossalai/nn/layer/utils/common.py @@ -13,7 +13,8 @@ from torch import Tensor, nn class CheckpointModule(nn.Module): - def __init__(self, checkpoint: bool = True, offload : bool = False): + + def __init__(self, checkpoint: bool = True, offload: bool = False): super().__init__() self.checkpoint = checkpoint self._use_checkpoint = checkpoint @@ -78,6 +79,7 @@ def get_tensor_parallel_mode(): def _ntuple(n): + def parse(x): if isinstance(x, collections.abc.Iterable): return x