[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook):
Warning: this class has been deprecated after version 0.1.12
"""
def __init__(self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector] = None,
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
process_group: Optional[dist.ProcessGroup] = None):
def __init__(
self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector] = None,
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
process_group: Optional[dist.ProcessGroup] = None,
):
super().__init__()
self.logger = get_dist_logger("ZeROHook")
self.shard_strategy = shard_strategy
@@ -41,7 +43,7 @@ class ZeroHook(BaseOpHook):
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
assert hasattr(param, "colo_attr")
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
@@ -50,7 +52,7 @@ class ZeroHook(BaseOpHook):
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
assert hasattr(param, "colo_attr")
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
@@ -74,10 +76,9 @@ class ZeroHook(BaseOpHook):
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
# change tensor state to HOLD_AFTER_FWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
@@ -93,10 +94,9 @@ class ZeroHook(BaseOpHook):
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
# change tensor state to HOLD_AFTER_BWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
if self._stateful_tensor_mgr:
self.logger.debug(
f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
ranks=[0])
ranks=[0],
)
self._stateful_tensor_mgr.finish_iter()