[zero] adapt zero hooks for unsharded module (#699)

Author: HELSON
Date: 2022-04-08 20:23:26 +08:00
Committer: GitHub
parent 896ade15d6
commit ee112fe1da
12 changed files with 71 additions and 59 deletions


@@ -135,8 +135,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.sharded_param_list = []
-        self.unshard_param_list = []
+        self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
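
The constructor no longer keeps separate sharded_param_list and unshard_param_list; a single param_list tracks every parameter, and whether a parameter was actually sharded is read back from its colo_attr later. A minimal sketch of that bookkeeping idea, where SimpleAttr and register are hypothetical stand-ins rather than library APIs:

    # Sketch: one list plus a per-parameter flag replaces the two separate lists.
    # `SimpleAttr` and `register` are hypothetical stand-ins for the colo_attr handling.
    class SimpleAttr:
        def __init__(self, param_is_sharded: bool):
            self.param_is_sharded = param_is_sharded

    param_list = []

    def register(param, sharded: bool):
        param.colo_attr = SimpleAttr(param_is_sharded=sharded)
        param_list.append(param)   # single list; the flag records how the param was handled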
@@ -210,19 +209,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def _post_context_exec(self):
         """The callback function when exiting context.
         """
-        for param in self.sharded_param_list:
+        # broadcast replicated no-shard parameters
+        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
+        for param in self.param_list:
             assert hasattr(param, 'colo_attr')
+            if not param.colo_attr.param_is_sharded and param.is_replicated:
+                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
             param.colo_attr.remove_torch_payload()
-        del self.sharded_param_list
-        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
-        for param in self.unshard_param_list:
-            assert hasattr(param, 'colo_attr')
-            if param.is_replicated:
-                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-        del self.unshard_param_list
+        del self.param_list
         nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
         torch.set_rng_state(self.cpu_rng_state)
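
After this change the exit callback makes a single pass over param_list: parameters that stayed unsharded and are marked replicated are broadcast from the first data-parallel rank, and every parameter's temporary torch payload is then released. A standalone sketch of that pass, assuming torch.distributed is already initialized; the name sync_unsharded_params and the explicit src_rank argument are illustrative (the real code obtains the rank through gpc):

    import torch.distributed as dist

    def sync_unsharded_params(param_list, dp_process_group, src_rank):
        # src_rank: first rank of the data-parallel group (gpc provides it in the diff)
        for param in param_list:
            assert hasattr(param, 'colo_attr')
            # only replicated parameters that were left unsharded need syncing,
            # so every data-parallel rank ends up with identical values
            if not param.colo_attr.param_is_sharded and param.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=dp_process_group)
            # the temporary torch payload is dropped for every parameter
            param.colo_attr.remove_torch_payload()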
@@ -264,10 +259,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-                param.data = param.colo_attr.sharded_data_tensor.payload
-                self.sharded_param_list.append(param)
-            else:
-                self.unshard_param_list.append(param)
+            param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
+            self.param_list.append(param)
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
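
In _post_init_method the else branch disappears: sharding is applied only when shard_param is set, while redirecting param.data to the managed payload and registering the parameter now happen unconditionally. A hedged sketch of the unified control flow, where handle_param is an illustrative name rather than a library API:

    # Sketch: unified per-parameter handling after the change; `handle_param` is not a library API.
    def handle_param(param, shard_param, shard_strategy, dp_process_group, param_list):
        if shard_param:
            # shard the payload across the data-parallel group
            shard_strategy.shard([param.colo_attr.sharded_data_tensor], dp_process_group)
        # whether sharded or not, param.data points at the managed payload
        param.data = param.colo_attr.sharded_data_tensor.payload
        # a single list now tracks both sharded and unsharded parameters
        param_list.append(param)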