mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -14,7 +14,6 @@ from colossalai.legacy.utils import checkpoint
|
||||
|
||||
|
||||
class CheckpointModule(nn.Module):
|
||||
|
||||
def __init__(self, checkpoint: bool = True, offload: bool = False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
@@ -22,7 +21,7 @@ class CheckpointModule(nn.Module):
|
||||
self._offload = offload
|
||||
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
|
||||
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
@@ -49,9 +48,8 @@ def divide(numerator, denominator):
|
||||
Returns:
|
||||
int: the result of exact division.
|
||||
"""
|
||||
assert denominator != 0, 'denominator can not be zero'
|
||||
assert numerator % denominator == 0, \
|
||||
'{} is not divisible by {}'.format(numerator, denominator)
|
||||
assert denominator != 0, "denominator can not be zero"
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
@@ -80,7 +78,6 @@ def get_tensor_parallel_mode():
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
|
Reference in New Issue
Block a user