Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
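For context, this kind of repo-wide reformat is reproduced by running the pre-commit hooks over every tracked file. Below is a minimal sketch, assuming pre-commit is installed and a `.pre-commit-config.yaml` sits at the repository root; the wrapper function is purely illustrative, only the `pre-commit run --all-files` command itself is implied by the commit message.

```python
# Minimal sketch: shell out to the real `pre-commit run --all-files`
# command from Python. Assumes pre-commit is installed and the current
# working directory is the repository root.
import subprocess
import sys


def run_pre_commit_on_all_files() -> int:
    """Run every configured hook on every file and return the exit code."""
    result = subprocess.run(
        ["pre-commit", "run", "--all-files"],
        check=False,  # pre-commit exits non-zero when hooks modify files
    )
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_pre_commit_on_all_files())
```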
@@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook):
     Warning: this class has been deprecated after version 0.1.12
     """

-    def __init__(self,
-                 shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector] = None,
-                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
-                 process_group: Optional[dist.ProcessGroup] = None):
+    def __init__(
+        self,
+        shard_strategy: BaseShardStrategy,
+        memstarts_collector: Optional[MemStatsCollector] = None,
+        stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ):
         super().__init__()
         self.logger = get_dist_logger("ZeROHook")
         self.shard_strategy = shard_strategy
@@ -41,7 +43,7 @@ class ZeroHook(BaseOpHook):
         if module.param_is_sharded:
             tensor_list = []
             for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
+                assert hasattr(param, "colo_attr")
                 tensor_list.append(param.colo_attr.sharded_data_tensor)
             self.shard_strategy.gather(tensor_list, self.process_group)

@@ -50,7 +52,7 @@ class ZeroHook(BaseOpHook):
         if module.param_is_sharded:
             tensor_list = []
             for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
+                assert hasattr(param, "colo_attr")
                 tensor_list.append(param.colo_attr.sharded_data_tensor)
             self.shard_strategy.shard(tensor_list, self.process_group)

@@ -74,10 +76,9 @@ class ZeroHook(BaseOpHook):
         self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
-            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
+            assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA"

     def post_fwd_exec(self, module: torch.nn.Module, *args):
-
         # change tensor state to HOLD_AFTER_FWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
@@ -93,10 +94,9 @@ class ZeroHook(BaseOpHook):
         self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
-            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
+            assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA"

     def post_bwd_exec(self, module: torch.nn.Module, input):
-
         # change tensor state to HOLD_AFTER_BWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
         if self._stateful_tensor_mgr:
             self.logger.debug(
                 f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
-                ranks=[0])
+                ranks=[0],
+            )
             self._stateful_tensor_mgr.finish_iter()
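Taken together, the hunks above touch every stage of the hook lifecycle: parameters are gathered before forward/backward and their tensor state is changed afterwards. Below is a toy, self-contained sketch of that pattern; the class, the `toy_state` attribute, the enum values, and any method signature not visible in the diff are illustrative assumptions, not ColossalAI's real API.

```python
# Toy sketch of the gather-before-compute / change-state-after-compute
# lifecycle that ZeroHook implements. Simplified on purpose: the real
# ZeroHook gathers and re-shards tensors via a shard strategy and a
# process group; here we only track a per-parameter state flag.
from enum import Enum, auto

import torch


class TensorState(Enum):
    COMPUTE = auto()
    HOLD_AFTER_FWD = auto()
    HOLD_AFTER_BWD = auto()


class ToyZeroHook:
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        # ZeroHook gathers sharded parameters here so forward sees full tensors.
        for param in module.parameters(recurse=False):
            param.toy_state = TensorState.COMPUTE

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        # change tensor state to HOLD_AFTER_FWD (mirrors the diff's comment)
        for param in module.parameters(recurse=False):
            param.toy_state = TensorState.HOLD_AFTER_FWD

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        # parameters are gathered again before the backward pass
        for param in module.parameters(recurse=False):
            param.toy_state = TensorState.COMPUTE

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # change tensor state to HOLD_AFTER_BWD (mirrors the diff's comment)
        for param in module.parameters(recurse=False):
            param.toy_state = TensorState.HOLD_AFTER_BWD
```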