From 94eec1c5ade1bd03dcb8690400ab72751fcc32d5 Mon Sep 17 00:00:00 2001
From: Sze-qq <68757353+Sze-qq@users.noreply.github.com>
Date: Tue, 28 Mar 2023 15:55:09 +0800
Subject: [PATCH] [NFC] polish
 colossalai/engine/gradient_accumulation/_gradient_accumulation.py code style
 (#3277)

Co-authored-by: siqi
---
 .../gradient_accumulation/_gradient_accumulation.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
index 89c28c3be..cf66be1cd 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
@@ -1,21 +1,22 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Union
+from typing import Any, Iterable, Tuple, Union
+
 import torch.nn as nn
 from torch import Tensor
-from typing import Iterable, Any, Tuple
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from colossalai.utils import conditional_context
+
 from colossalai.engine import BaseGradientHandler
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import conditional_context
 
 
 class GradAccumOptimizer(ColossalaiOptimizer):
-    """A wrapper for the optimizer to enable gradient accumulation by skipping the steps 
+    """A wrapper for the optimizer to enable gradient accumulation by skipping the steps
     before accumulation size is reached.
 
     Args:
@@ -161,7 +162,7 @@ class GradAccumDataloader:
 
 
 class GradAccumLrSchedulerByStep(_LRScheduler):
-    """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps 
+    """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
     before accumulation size is reached.
 
     Args:
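
Note: the patch above only reorders imports and trims whitespace. For readers unfamiliar with the wrapped classes, the sketch below illustrates the step-skipping behavior that the `GradAccumOptimizer` docstring describes. It is a minimal, hypothetical example, not the ColossalAI implementation; the class name `SimpleGradAccumOptimizer` and its attributes are assumptions made for illustration only.

```python
# Minimal sketch of gradient accumulation by skipping optimizer steps.
# NOT the ColossalAI implementation; names here are illustrative.
import torch
import torch.nn as nn
from torch.optim import Optimizer


class SimpleGradAccumOptimizer:
    """Makes step()/zero_grad() take effect only once every
    `accumulate_size` calls; gradients accumulate in between."""

    def __init__(self, optim: Optimizer, accumulate_size: int):
        self.optim = optim
        self.accumulate_size = accumulate_size
        self.accumulate_step = 0

    def backward(self, loss: torch.Tensor) -> None:
        # Scale the loss so the accumulated gradient matches a large-batch average.
        self.accumulate_step += 1
        (loss / self.accumulate_size).backward()

    def step(self) -> None:
        # Skip the parameter update until the accumulation size is reached.
        if self.accumulate_step < self.accumulate_size:
            return
        self.accumulate_step = 0
        self.optim.step()

    def zero_grad(self) -> None:
        # Only clear gradients right after a real optimizer step.
        if self.accumulate_step == 0:
            self.optim.zero_grad()


# Usage sketch: parameters are updated on every 4th iteration.
model = nn.Linear(4, 2)
optimizer = SimpleGradAccumOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.1), accumulate_size=4)
for _ in range(8):
    loss = model(torch.randn(1, 4)).sum()
    optimizer.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```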