[zero] add zero context manager to change config during initialization (#546)

This commit is contained in:
HELSON
2022-03-29 17:57:59 +08:00
committed by GitHub
parent ec5086c49c
commit 8c90d4df54
5 changed files with 185 additions and 18 deletions

View File

@@ -1,3 +1,3 @@
from .init_context import ZeroInitContext
from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
__all__ = ['ZeroInitContext']
__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']

View File

@@ -1,9 +1,11 @@
import contextlib
import functools
from typing import Optional
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
pass
class ZeroContextConfig(object):
"""The configuration used to control zero context initialization.
Args:
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
If set to `False`, remove tensor payload on param.data afther the context exist.
This is used when you add some logic to operate tensors in __init__ of module.
See torchvision resnet18. Defaults to False.
"""
def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
super().__init__()
self.shard_param: bool = shard_param
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
Args:
convert_fp16 (bool): Whether to convert params to fp16.
target_device (torch.device): The device where param data after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self.target_device = target_device
self.shard_param = shard_param
self.shard_strategy = shard_strategy
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
self.initialized_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(shard_param=shard_param,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
ZeroContextMgr().current_context = self
@property
def shard_param(self):
return self.config.shard_param
@property
def rm_torch_payload_on_the_fly(self):
return self.config.rm_torch_payload_on_the_fly
def _pre_context_exec(self):
"""
The Callback function when entering the context
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'col_attr'):
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.model_numel_tensor += param.numel()
target_device = self.target_device
# convert to fp16
param.data = param.data.to(torch.half)
# convert parameters to half
param_half = half_fn(param)
param.data = param_half
if param.grad is not None:
param.grad = param.grad.to(torch.half)
grad_half = half_fn(param.grad)
param.grad.data = grad_half
# move torch parameters to the target device
target_device = self.target_device
param.data = param.data.to(target_device)
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
self.initialized_param_list.append(param)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
self.initialized_param_list.append(param)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_tensor_to_fp16(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):
current_context: Optional[ZeroInitContext] = None
@contextlib.contextmanager
def hijack_context_config(self, **kwargs):
if self.current_context is None:
yield
else:
old_config = self.current_context.config
self.current_context.config = ZeroContextConfig(**kwargs)
yield
self.current_context.config = old_config
def no_shard_zero_context():
return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
def no_shard_zero_decrator(init_func):
def _no_shard(*args, **kwargs):
with no_shard_zero_context():
init_func(*args, **kwargs)
return _no_shard