Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -14,17 +14,22 @@ from ._gradient_accumulation import (
 )

 __all__ = [
-    'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
-    'GradAccumGradientHandler'
+    "accumulate_gradient",
+    "GradAccumDataloader",
+    "GradAccumOptimizer",
+    "GradAccumLrSchedulerByStep",
+    "GradAccumGradientHandler",
 ]


-def accumulate_gradient(model: nn.Module,
-                        optimizer: Optimizer,
-                        dataloader: Iterable,
-                        accumulate_size: int,
-                        gradient_handlers: List[BaseGradientHandler] = None,
-                        lr_scheduler: _LRScheduler = None):
+def accumulate_gradient(
+    model: nn.Module,
+    optimizer: Optimizer,
+    dataloader: Iterable,
+    accumulate_size: int,
+    gradient_handlers: List[BaseGradientHandler] = None,
+    lr_scheduler: _LRScheduler = None,
+):
     r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.

     Args:
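For readers unfamiliar with the pattern, the wrappers exported above implement gradient accumulation: summing scaled micro-batch gradients and stepping the optimizer only once per cycle. A minimal sketch of that behavior in plain PyTorch follows; the model, data, and accumulate_size=4 are hypothetical stand-ins for illustration, not ColossalAI's implementation.

# Minimal gradient-accumulation sketch in plain PyTorch.
# Hypothetical model/data; accumulate_size=4 means one optimizer
# step per 4 micro-batches.
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 2)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
loader = DataLoader(data, batch_size=4)

accumulate_size = 4
optimizer.zero_grad()
for step, (x, y) in enumerate(loader, start=1):
    loss = criterion(model(x), y) / accumulate_size  # scale so grads average
    loss.backward()                                  # grads accumulate in .grad
    if step % accumulate_size == 0:
        optimizer.step()       # apply the accumulated update
        optimizer.zero_grad()  # reset for the next cycle

Dividing each micro-batch loss by accumulate_size makes the accumulated gradient equal the gradient of the mean loss over the full effective batch, which is presumably what the GradAccum* wrappers automate behind the ordinary step() call.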
@@ -272,8 +272,9 @@ class GradAccumGradientHandler:
     """

     def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None:
-        assert isinstance(grad_handler, BaseGradientHandler), \
-            f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}'
+        assert isinstance(
+            grad_handler, BaseGradientHandler
+        ), f"expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}"
         self.grad_handler = grad_handler
         self.accumulate_size = accumulate_size
         self.accumulate_step = 0
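The fields initialized above (accumulate_size, accumulate_step) suggest the handler counts micro-steps and forwards to the wrapped grad_handler only once per accumulation cycle. Below is a sketch of that dispatch logic, assuming a handle_gradient() entry point on the wrapped handler; that method name is an assumption about BaseGradientHandler's interface, not something this diff confirms.

# Sketch of the counting/dispatch pattern implied by the __init__ above.
# handle_gradient() is an assumed interface; any object exposing that
# method works with this sketch.
class AccumulatingHandlerSketch:
    def __init__(self, grad_handler, accumulate_size: int) -> None:
        self.grad_handler = grad_handler
        self.accumulate_size = accumulate_size
        self.accumulate_step = 0

    def handle_gradient(self) -> None:
        self.accumulate_step += 1
        if self.accumulate_step == self.accumulate_size:
            # Run the expensive gradient synchronization (e.g. an
            # all-reduce) only on the last micro-step of each cycle.
            self.accumulate_step = 0
            self.grad_handler.handle_gradient()

Skipping the wrapped handler on intermediate micro-steps avoids redundant communication while gradients are still being accumulated locally.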