[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
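For context, below is a minimal sketch of the kind of .pre-commit-config.yaml this commit message describes. The hooks shown (black, and clang-format via pre-commit's standard mirror), their revisions, and the CUDA exclude pattern are illustrative assumptions, not the repository's actual configuration; "run all files" typically corresponds to running `pre-commit run --all-files` once after updating the hooks.

# Illustrative sketch only -- hook revisions and the exclude pattern are assumptions.
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6
    hooks:
      - id: clang-format
        # "ignore cuda for clang-format": skip CUDA sources (assumed pattern)
        exclude: \.(cu|cuh)$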
@@ -49,11 +49,13 @@ class ColoDDP(torch.nn.Module):
             If it's None, the default data parallel group will be used. Defaults to None.
     """
 
-    def __init__(self,
-                 module: torch.nn.Module,
-                 process_group: ColoProcessGroup,
-                 bucket_cap_mb: int = 25,
-                 rebuild_bucket: bool = True) -> None:
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        process_group: ColoProcessGroup,
+        bucket_cap_mb: int = 25,
+        rebuild_bucket: bool = True,
+    ) -> None:
         assert not isinstance(module, ColoDDP)
         super().__init__()
         self.module = module
@@ -74,19 +76,18 @@ class ColoDDP(torch.nn.Module):
     def parameters(self, recurse: bool = True):
         return self.module.parameters(recurse)
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True):
+    def named_parameters(self, prefix: str = "", recurse: bool = True):
         return self.module.named_parameters(prefix, recurse)
 
-    def named_buffers(self, prefix: str = '', recurse: bool = True):
+    def named_buffers(self, prefix: str = "", recurse: bool = True):
         return self.module.named_buffers(prefix, recurse)
 
     def named_children(self):
         return self.module.named_children()
 
-    def named_modules(self,
-                      memo: Optional[Set[torch.nn.Module]] = None,
-                      prefix: str = '',
-                      remove_duplicate: bool = True):
+    def named_modules(
+        self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    ):
         return self.module.named_modules(memo, prefix, remove_duplicate)
 
     def forward(self, *args, **kwargs):
@@ -114,9 +115,9 @@ class ColoDDP(torch.nn.Module):
                 grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
-                    self.reducer.all_reduce_async(grad,
-                                                  group=self.process_group.dp_process_group(),
-                                                  callback_fn=partial(self._save_grad, p))
+                    self.reducer.all_reduce_async(
+                        grad, group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p)
+                    )
                 grad.record_stream(self.comm_stream)
             else:
                 ColoDDP._save_grad(p, grad)
@@ -130,7 +131,7 @@ class ColoDDP(torch.nn.Module):
 
     @staticmethod
     def _save_grad(p, grad):
-        if hasattr(p, '_saved_grad'):
+        if hasattr(p, "_saved_grad"):
             p._saved_grad.add_(grad)
         else:
             p._saved_grad = grad
@@ -138,7 +139,7 @@ class ColoDDP(torch.nn.Module):
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
         for p in self.module.parameters():
-            if getattr(p, '_saved_grad', None) is not None:
+            if getattr(p, "_saved_grad", None) is not None:
                 if set_to_none:
                     p._saved_grad = None
                 else:
@@ -167,8 +168,8 @@ class ColoDDP(torch.nn.Module):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
         return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+    def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
         return self.module.load_state_dict(state_dict, strict)
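For reference, the call reformatted in the third hunk follows a standard stream-synchronized asynchronous all-reduce pattern. The sketch below reproduces that pattern with plain torch.distributed rather than ColossalAI's Reducer; the function name and callback are hypothetical, and it assumes a process group has already been initialized with the NCCL backend and that `grad` lives on the current GPU.

# A minimal sketch of the stream-synchronized async all-reduce pattern used above.
# Not ColossalAI's Reducer: plain torch.distributed, assuming dist.init_process_group
# has already been called with the NCCL backend.
import torch
import torch.distributed as dist


def async_allreduce_grad(grad: torch.Tensor, comm_stream: torch.cuda.Stream, callback_fn) -> None:
    # Make the communication stream wait for the compute stream that produced `grad`.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # Launch the collective on the communication stream; with NCCL, wait() only
        # enqueues a dependency on comm_stream and does not block the Python thread.
        work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
        work.wait()
        callback_fn(grad)  # e.g. stash a reference to the reduced gradient on the parameter
    # Keep the caching allocator from reusing `grad`'s memory before comm_stream is done with it.
    grad.record_stream(comm_stream)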