From 036404ca8a564d17446b59c09ba21ceffd0608f6 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Sat, 2 Apr 2022 18:30:06 +0800
Subject: [PATCH] Revert "[zero] polish init context (#645)" (#657)

---
 colossalai/zero/init_ctx/init_context.py | 38 ++++++++----------------
 1 file changed, 13 insertions(+), 25 deletions(-)

diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index ff65b3191..52a166d89 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -90,11 +90,9 @@ class ZeroContextConfig(object):
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
-            We do not need to synchronize (reduce) the grads of the replicated params among DP group.
+        replicated (bool, optional): Whether the param is replicated across data parallel group.
            Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-            The process group among which tensors are sharded is assigned as an runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -112,9 +110,6 @@ class ZeroContextConfig(object):
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
-
-        if self.is_replicated is False:
-            assert self.shard_param is True, f"ZeroContextConfig shard_param must be False when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
 
 
@@ -122,8 +117,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
 
     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type `ShardedParameter`.
-    3. Shard the param and grad according to flag `shard_param`.
+    2. The paramaters of the module are adapted to type ShardedParameter.
+    3. Shard the param and grad according to flags.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
@@ -149,8 +144,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         super().__init__()
         self.shard_strategy = shard_strategy
-        # a list contains params that could be sharded.
-        self.shardable_param_list = []
+        self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
@@ -187,17 +181,21 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.shardable_param_list:
+            for param in self.initialized_param_list:
                 assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()
 
-        del self.shardable_param_list
+        del self.initialized_param_list
 
     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+
+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'colo_attr'):
@@ -209,10 +207,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             param.is_replicated = self.is_replicated
 
             # convert parameters to half
-            param_half = cast_tensor_to_fp16(param.data)
+            param_half = half_fn(param)
             param.data = param_half
             if param.grad is not None:
-                grad_half = cast_tensor_to_fp16(param.grad)
+                grad_half = half_fn(param.grad)
                 param.grad.data = grad_half
 
             # move torch parameters to the target device
@@ -225,7 +223,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.shardable_param_list.append(param)
+            self.initialized_param_list.append(param)
 
             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
@@ -257,16 +255,6 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
 
 
 def no_shard_zero_decrator(is_replicated: bool = True):
-    """
-    A decorator used to wrap an __init__ function of Module.
-    The parameters initialized by the model will not sharded.
-    is_replicated indicates the grad of the param won't be reduced among the data parallel process group.
-
-    >>> def MyModule(torch.nn.Module):
-    >>>     @no_shard_zero_decrator(is_replicated = False)
-    >>>     def __init__(self, ...)
-    >>>     ....
-    """
 
     def _wrapper(init_func):