mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -12,7 +12,7 @@ from .apex_amp import convert_to_apex_amp
 from .naive_amp import convert_to_naive_amp
 from .torch_amp import convert_to_torch_amp
 
-__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
+__all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"]
 
 
 def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
@@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod
         For ``torch_amp``, please check
         `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
     """
-    assert isinstance(mode, AMP_TYPE), \
-        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
+    assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}"
 
     if amp_config is None:
         amp_config = Config()
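The hunks above are pure formatting from pre-commit (single quotes normalized to double quotes, the multi-line assert collapsed onto one line), so the `convert_to_amp` helper keeps its signature. A minimal usage sketch under assumptions: the signature, the `AMP_TYPE` check, and the exported names come from the diff, while the import path `colossalai.amp`, the three-tuple return value, the toy model/optimizer/criterion, and the `AMP_TYPE.TORCH` mode choice are illustrative assumptions, not taken from this commit.

import torch
import torch.nn as nn

from colossalai.amp import AMP_TYPE, convert_to_amp  # assumed module path

# Placeholder training components, for illustration only.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# mode must be an AMP_TYPE member, otherwise the assert shown in the diff fires.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, mode=AMP_TYPE.TORCH)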