mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -14,17 +14,22 @@ from ._gradient_accumulation import (
|
||||
)
|
||||
|
||||
# Public API of the gradient-accumulation package.
# NOTE(review): reconstructed from diff residue that contained both the
# pre- and post-reformat versions of this list; the post-reformat
# (one-name-per-line, double-quoted) version is kept as canonical.
__all__ = [
    "accumulate_gradient",
    "GradAccumDataloader",
    "GradAccumOptimizer",
    "GradAccumLrSchedulerByStep",
    "GradAccumGradientHandler",
]
def accumulate_gradient(model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
dataloader: Iterable,
|
||||
accumulate_size: int,
|
||||
gradient_handlers: List[BaseGradientHandler] = None,
|
||||
lr_scheduler: _LRScheduler = None):
|
||||
def accumulate_gradient(
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
dataloader: Iterable,
|
||||
accumulate_size: int,
|
||||
gradient_handlers: List[BaseGradientHandler] = None,
|
||||
lr_scheduler: _LRScheduler = None,
|
||||
):
|
||||
r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.
|
||||
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user