[zero] adapt zero hooks for unsharded module (#699)
@@ -135,8 +135,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.sharded_param_list = []
-        self.unshard_param_list = []
+        self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
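Note: the hunk above collapses the two bookkeeping lists (`sharded_param_list`, `unshard_param_list`) into a single `param_list`; whether a parameter was sharded is recovered later from its per-parameter flags instead of from which list it sits in. A minimal sketch of that filter-at-use-time pattern, using hypothetical stand-ins (`FakeAttr`, `FakeParam`) rather than ColossalAI types:

# Hypothetical stand-ins to illustrate one-list bookkeeping; not ColossalAI API.
from dataclasses import dataclass


@dataclass
class FakeAttr:
    param_is_sharded: bool


@dataclass
class FakeParam:
    colo_attr: FakeAttr
    is_replicated: bool


# One list holds every parameter, sharded or not...
param_list = [
    FakeParam(FakeAttr(param_is_sharded=True), is_replicated=False),
    FakeParam(FakeAttr(param_is_sharded=False), is_replicated=True),
]

# ...and the sharded/unsharded split is recovered at use time, the same
# predicate the new _post_context_exec loop applies before broadcasting.
to_broadcast = [p for p in param_list
                if not p.colo_attr.param_is_sharded and p.is_replicated]
assert len(to_broadcast) == 1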
@@ -210,19 +209,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def _post_context_exec(self):
         """The callback function when exiting context.
         """
-        for param in self.sharded_param_list:
+        # broadcast replicated no-shard parameters
+        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
+        for param in self.param_list:
+            assert hasattr(param, 'colo_attr')
+            if not param.colo_attr.param_is_sharded and param.is_replicated:
+                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
             param.colo_attr.remove_torch_payload()

-        del self.sharded_param_list
-
-        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
-        for param in self.unshard_param_list:
-            assert hasattr(param, 'colo_attr')
-            if param.is_replicated:
-                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-
-        del self.unshard_param_list
+        del self.param_list

         nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
         torch.set_rng_state(self.cpu_rng_state)
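Note: on context exit, the unified loop above broadcasts rank 0's copy of every replicated, unsharded parameter over the data-parallel group, so all replicas agree on the initialization, then drops the torch-side payload. A self-contained sketch of just the broadcast step, outside ColossalAI; the gloo backend, the port, and the toy `nn.Linear` model are assumptions for illustration:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def init_process(rank: int, world_size: int):
    # Assumed rendezvous setup for a local two-process run.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)


def run(rank: int, world_size: int):
    init_process(rank, world_size)
    torch.manual_seed(rank)          # ranks start with *different* weights
    model = nn.Linear(4, 4)

    # Mirror of the loop in _post_context_exec: rank 0 is the source, and
    # every replicated, unsharded parameter is synchronized in place.
    src_rank = 0
    for param in model.parameters():
        dist.broadcast(tensor=param.data, src=src_rank)

    # All ranks now hold rank 0's initialization.
    print(f'rank {rank}: weight checksum = {model.weight.sum().item():.6f}')
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)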
@@ -264,10 +259,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         if self.shard_param:
             self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            param.data = param.colo_attr.sharded_data_tensor.payload
-            self.sharded_param_list.append(param)
-        else:
-            self.unshard_param_list.append(param)
+        param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
+
+        self.param_list.append(param)

         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
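Note: the last hunk now points `param.data` at `sharded_data_tensor.payload` for every parameter, sharded or not, so the zero hooks always find the data in one place. A toy sketch of that payload swap and the `remove_torch_payload`-style release; `TensorHolder` is a hypothetical stand-in, not ColossalAI's sharded tensor:

import torch
import torch.nn as nn


class TensorHolder:
    """Hypothetical stand-in for colo_attr.sharded_data_tensor."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor.detach().clone()


param = nn.Parameter(torch.randn(4, 4))
holder = TensorHolder(param.data)

# As in the hunk above: make param.data alias the payload, so later hooks
# and the broadcast in _post_context_exec all touch one storage.
param.data = holder.payload

# And in the spirit of remove_torch_payload: once the payload owns the
# data, release the torch-side storage by pointing param.data at an empty
# tensor (the Parameter object itself stays alive for the optimizer).
param.data = torch.empty(0, dtype=holder.payload.dtype, device=holder.payload.device)
assert param.numel() == 0 and holder.payload.numel() == 16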