[zero] adapt zero hooks for unsharded module (#699)
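Commit to hpcaitech/ColossalAI. In short: the ZeRO op hooks (ZeroHook) now check `module.param_is_sharded` before gathering or re-sharding parameters, so modules whose parameters stay unsharded still pass through the hooks (device placement, tensor-state transitions, memory sampling) without a needless gather/shard; ZeroInitContext tracks every parameter in a single `param_list` and broadcasts replicated, unsharded parameters on context exit; ShardedModelV2 registers the hooks on all submodules instead of filtering unsharded ones out; ShardedOptimizerV2 drops its unsharded-parameter special case; and ShardedParamV2 reuses a shared `FAKE_EMPTY_TENSOR` placeholder when a parameter's torch payload is removed.

Below is a minimal, self-contained sketch of the central guard pattern, not the ColossalAI implementation; `GatherGuardHook`, `gather_fn`, and `shard_fn` are made-up stand-ins for ZeroHook and the shard strategy:

import torch.nn as nn

class GatherGuardHook:
    """Toy op hook: gather/shard only when the module's parameters are actually sharded."""

    def __init__(self, gather_fn, shard_fn):
        self.gather_fn = gather_fn    # hypothetical callables playing the shard strategy's role
        self.shard_fn = shard_fn

    def pre_fwd_exec(self, module: nn.Module, *args):
        # unsharded modules skip the collective entirely
        if getattr(module, 'param_is_sharded', False):
            self.gather_fn(list(module.parameters(recurse=False)))

    def post_fwd_exec(self, module: nn.Module, *args):
        if getattr(module, 'param_is_sharded', False):
            self.shard_fn(list(module.parameters(recurse=False)))

# an unsharded layer never triggers the collectives
layer = nn.Linear(4, 4)
layer.param_is_sharded = False
hook = GatherGuardHook(gather_fn=lambda ps: print('gather', len(ps)),
                       shard_fn=lambda ps: print('shard', len(ps)))
hook.pre_fwd_exec(layer)    # prints nothing: gather is skipped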
@@ -36,6 +36,7 @@ class ZeroHook(BaseOpHook):
         self._stateful_tensor_mgr = stateful_tensor_mgr
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
 
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
 
@@ -45,12 +46,15 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
 
-        tensor_list = []
-        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.colo_attr.sharded_data_tensor)
-        self.shard_strategy.gather(tensor_list, self.process_group)
+        # gather sharded parameters
+        if module.param_is_sharded:
+            tensor_list = []
+            for param in module.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr')
+                tensor_list.append(param.colo_attr.sharded_data_tensor)
+            self.shard_strategy.gather(tensor_list, self.process_group)
 
         # record memory statistics
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
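The same `module.param_is_sharded` guard appears symmetrically in the hunks that follow: post_fwd_exec and post_bwd_exec re-shard only when the module was sharded, and pre_bwd_exec mirrors the gather above.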
@@ -59,18 +63,25 @@ class ZeroHook(BaseOpHook):
             assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
 
         # change tensor state to HOLD_AFTER_FWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
 
-        tensor_list = []
-        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.colo_attr.sharded_data_tensor)
-        self.shard_strategy.shard(tensor_list, self.process_group)
+        # shard gathered parameters
+        if module.param_is_sharded:
+            tensor_list = []
+            for param in module.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr')
+                tensor_list.append(param.colo_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
 
         # remove torch payload
         for param in module.parameters(recurse=False):
             param.colo_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
 
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
 
@@ -80,12 +91,15 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
 
-        tensor_list = []
-        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.colo_attr.sharded_data_tensor)
-        self.shard_strategy.gather(tensor_list, self.process_group)
+        # gather sharded parameters
+        if module.param_is_sharded:
+            tensor_list = []
+            for param in module.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr')
+                tensor_list.append(param.colo_attr.sharded_data_tensor)
+            self.shard_strategy.gather(tensor_list, self.process_group)
 
         # record memory statistics
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
@@ -94,15 +108,20 @@ class ZeroHook(BaseOpHook):
             assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
 
         # change tensor state to HOLD_AFTER_BWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
 
-        tensor_list = []
-        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.colo_attr.sharded_data_tensor)
-        self.shard_strategy.shard(tensor_list, self.process_group)
+        # shard gathered parameters
+        if module.param_is_sharded:
+            tensor_list = []
+            for param in module.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr')
+                tensor_list.append(param.colo_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
 
         # remove torch payload
         for param in module.parameters(recurse=False):
             param.colo_attr.remove_torch_payload()
 
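Across the four hooks above, each parameter's sharded_data_tensor is driven through a small state machine: COMPUTE before forward and backward, HOLD_AFTER_FWD and HOLD_AFTER_BWD afterwards. A toy sketch of that lifecycle, assuming a simplified stateful tensor rather than the real tensorful_state API:

from enum import Enum, auto

class TensorState(Enum):
    # toy re-creation of the states referenced by the hooks, not the ColossalAI enum
    HOLD = auto()
    COMPUTE = auto()
    HOLD_AFTER_FWD = auto()
    HOLD_AFTER_BWD = auto()

class ToyStatefulTensor:
    def __init__(self):
        self.state = TensorState.HOLD

    def trans_state(self, new_state: TensorState):
        self.state = new_state

t = ToyStatefulTensor()
t.trans_state(TensorState.COMPUTE)          # pre_fwd_exec
t.trans_state(TensorState.HOLD_AFTER_FWD)   # post_fwd_exec
t.trans_state(TensorState.COMPUTE)          # pre_bwd_exec
t.trans_state(TensorState.HOLD_AFTER_BWD)   # post_bwd_exec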
@@ -135,8 +135,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.sharded_param_list = []
-        self.unshard_param_list = []
+        self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
@@ -210,19 +209,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def _post_context_exec(self):
         """The callback function when exiting context.
         """
-        for param in self.sharded_param_list:
+        # broadcast replicated no-shard parameters
+        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
+        for param in self.param_list:
             assert hasattr(param, 'colo_attr')
+            if not param.colo_attr.param_is_sharded and param.is_replicated:
+                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
             param.colo_attr.remove_torch_payload()
 
-        del self.sharded_param_list
-
-        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
-        for param in self.unshard_param_list:
-            assert hasattr(param, 'colo_attr')
-            if param.is_replicated:
-                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-
-        del self.unshard_param_list
+        del self.param_list
 
         nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
         torch.set_rng_state(self.cpu_rng_state)
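On context exit, replicated parameters that were left unsharded are broadcast from the first data-parallel rank so every rank starts from identical values. A minimal sketch of that pattern with plain torch.distributed, assuming an already-initialized process group and a hypothetical replicated_params list:

import torch.distributed as dist

def broadcast_replicated_params(replicated_params, src_rank=0, group=None):
    # rank `src_rank`'s initialization wins; the other ranks overwrite their local copy
    for param in replicated_params:
        dist.broadcast(tensor=param.data, src=src_rank, group=group)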
@@ -264,10 +259,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         if self.shard_param:
             self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            param.data = param.colo_attr.sharded_data_tensor.payload
-            self.sharded_param_list.append(param)
-        else:
-            self.unshard_param_list.append(param)
+        param.data = param.colo_attr.sharded_data_tensor.payload  # set param.data to payload
+
+        self.param_list.append(param)
 
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
@@ -121,7 +121,7 @@ class ShardedModelV2(nn.Module):
         self._ophook_list = [
             ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
         ]
-        register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
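With the filter gone, the op hooks are attached to every submodule and the hooks themselves decide whether a collective is needed. A rough sketch of what recursive registration can look like using PyTorch's own hook API; register_op_hooks_recursively here is a hypothetical helper, not the ColossalAI register_ophooks_recursively:

import torch.nn as nn

def register_op_hooks_recursively(root: nn.Module, ophook):
    # attach the hook to every submodule (backward hooks omitted for brevity)
    for submodule in root.modules():
        submodule.register_forward_pre_hook(lambda m, inp, h=ophook: h.pre_fwd_exec(m, *inp))
        submodule.register_forward_hook(lambda m, inp, out, h=ophook: h.post_fwd_exec(m, *inp))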
@@ -366,14 +366,12 @@ class ShardedModelV2(nn.Module):
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-        prev_params = {}
         for p in self.sharded_params:
-            prev_params[p] = p.data
             p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
         self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         for p in self.sharded_params:
-            p.data = prev_params[p]
+            p.colo_attr.remove_torch_payload()
         return gathered_state_dict
 
     def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
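state_dict now gathers the full payloads, points each p.data at them for the wrapped module's state_dict call, re-shards, and then simply drops the temporary payloads instead of restoring a cached prev_params mapping. A condensed sketch of that gather -> snapshot -> re-shard pattern; gather_fn, shard_fn, and payload_of are hypothetical stand-ins:

import torch
import torch.nn as nn

def gathered_state_dict(module: nn.Module, managed_params, payload_of, gather_fn, shard_fn):
    gather_fn(managed_params)              # all-gather the sharded payloads
    for p in managed_params:
        p.data = payload_of(p)             # expose the full tensor to state_dict()
    state = module.state_dict()
    shard_fn(managed_params)               # re-shard to give the memory back
    for p in managed_params:
        # drop the stale full-tensor reference instead of restoring the old p.data
        p.data = torch.empty(0, dtype=p.dtype, device=p.device)
    return state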
@@ -268,10 +268,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             p.data = self.master_params[p].payload
             p.colo_attr.sharded_data_tensor.reset_payload(
                 colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
 
-            if not p.colo_attr.param_is_sharded:
-                # FIXME(hhc): add hook for unsharded parameters
-                p.data = p.colo_attr.sharded_data_tensor.payload
             p.colo_attr.remove_torch_payload()
 
     def sync_grad(self):
         pass
@@ -351,10 +348,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
                 p.colo_attr.remove_torch_payload()
 
                 if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
                     self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                     p.data = p.colo_attr.sharded_data_tensor.payload
 
                 self.master_params[p].trans_state(TensorState.HOLD)
                 p.colo_attr.saved_grad.set_null()
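After the optimizer step, the updated fp32 master weights are cast back into the fp16 payloads; for parameters that are not sharded, the full fp16 tensor is re-gathered and exposed through p.data. A toy sketch of the fp32-to-fp16 copy-back only, independent of the ShardedOptimizerV2 internals:

import torch

def copy_master_to_fp16(master_fp32: torch.Tensor, payload_fp16: torch.Tensor):
    # cast the updated master weights and write them into the fp16 payload in place
    payload_fp16.copy_(master_fp32.to(torch.float16))

master = torch.randn(4)                          # updated by the optimizer step
payload = torch.zeros(4, dtype=torch.float16)
copy_master_to_fp16(master, payload)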
@@ -5,6 +5,11 @@ from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
 from typing import List
 
+# use this tensor as empty data point for parameters
+# we do not want users use param.data when its torch payload is removed
+# empty tensor is expected to raise error when get used
+FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
+
 
 class ShardedParamV2(object):
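FAKE_EMPTY_TENSOR is a zero-element placeholder that stands in for param.data once the real payload has been removed: the large storage is released, and any accidental use of the parameter fails immediately instead of silently reading stale data. A small demo of the same idea (the placeholder and parameter here are illustrative, not the library objects):

import torch

placeholder = torch.empty(0, dtype=torch.bool)                # same idea as FAKE_EMPTY_TENSOR

weight = torch.nn.Parameter(torch.randn(4, 4))
weight.data = placeholder.to(weight.device, torch.float32)    # payload "removed"

print(weight.data.numel())       # 0 -> the parameter no longer holds the 4x4 storage
try:
    weight.data.view(4, 4)       # any real use now fails fast
except RuntimeError as err:
    print('expected error:', err)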
@@ -29,7 +34,7 @@ class ShardedParamV2(object):
         return [self._sharded_data_tensor]
 
     def remove_torch_payload(self):
-        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
+        self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
 
     @property
     def sharded_data_tensor(self):