mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
[polish] rename col_attr -> colo_attr (#558)
This commit is contained in:
@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters(recurse=False):
|
||||
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.col_attr.sharded_data_tensor.payload
|
||||
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
|
||||
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.remove_torch_payload()
|
||||
param.colo_attr.remove_torch_payload()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters(recurse=False):
|
||||
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.col_attr.sharded_data_tensor.payload
|
||||
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.remove_torch_payload()
|
||||
param.colo_attr.remove_torch_payload()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user