mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,6 +10,12 @@ from .common import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
|
||||
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
|
||||
"CheckpointModule",
|
||||
"divide",
|
||||
"ACT2FN",
|
||||
"set_tensor_parallel_attribute_by_size",
|
||||
"set_tensor_parallel_attribute_by_partition",
|
||||
"get_tensor_parallel_mode",
|
||||
"_ntuple",
|
||||
"to_2tuple",
|
||||
]
|
||||
|
@@ -14,7 +14,6 @@ from colossalai.legacy.utils import checkpoint
|
||||
|
||||
|
||||
class CheckpointModule(nn.Module):
|
||||
|
||||
def __init__(self, checkpoint: bool = True, offload: bool = False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
@@ -22,7 +21,7 @@ class CheckpointModule(nn.Module):
|
||||
self._offload = offload
|
||||
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
|
||||
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
@@ -49,9 +48,8 @@ def divide(numerator, denominator):
|
||||
Returns:
|
||||
int: the result of exact division.
|
||||
"""
|
||||
assert denominator != 0, 'denominator can not be zero'
|
||||
assert numerator % denominator == 0, \
|
||||
'{} is not divisible by {}'.format(numerator, denominator)
|
||||
assert denominator != 0, "denominator can not be zero"
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
@@ -80,7 +78,6 @@ def get_tensor_parallel_mode():
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
|
Reference in New Issue
Block a user