mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[zero] add zero context manager to change config during initialization (#546)
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .init_context import ZeroInitContext
|
||||
from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
|
||||
|
||||
__all__ = ['ZeroInitContext']
|
||||
__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']
|
||||
|
@@ -1,9 +1,11 @@
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
pass
|
||||
|
||||
|
||||
class ZeroContextConfig(object):
|
||||
"""The configuration used to control zero context initialization.
|
||||
|
||||
Args:
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
|
||||
This will reduce memory usage when initializing model.
|
||||
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
|
||||
If set to `False`, remove tensor payload on param.data afther the context exist.
|
||||
This is used when you add some logic to operate tensors in __init__ of module.
|
||||
See torchvision resnet18. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
|
||||
super().__init__()
|
||||
self.shard_param: bool = shard_param
|
||||
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
|
||||
|
||||
|
||||
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""A context to initialize model.
|
||||
|
||||
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
3. Shard the param and grad according to flags.
|
||||
|
||||
Args:
|
||||
convert_fp16 (bool): Whether to convert params to fp16.
|
||||
target_device (torch.device): The device where param data after exiting the context.
|
||||
shard_strategy (BaseShardStrategy): Shard strategy instance.
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
|
||||
This will reduce memory usage when initializing model.
|
||||
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
|
||||
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
super().__init__()
|
||||
self.target_device = target_device
|
||||
self.shard_param = shard_param
|
||||
self.shard_strategy = shard_strategy
|
||||
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
|
||||
self.initialized_param_list = []
|
||||
self.model_numel_tensor = model_numel_tensor
|
||||
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
self.config = ZeroContextConfig(shard_param=shard_param,
|
||||
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
|
||||
ZeroContextMgr().current_context = self
|
||||
|
||||
@property
|
||||
def shard_param(self):
|
||||
return self.config.shard_param
|
||||
|
||||
@property
|
||||
def rm_torch_payload_on_the_fly(self):
|
||||
return self.config.rm_torch_payload_on_the_fly
|
||||
|
||||
def _pre_context_exec(self):
|
||||
"""
|
||||
The Callback function when entering the context
|
||||
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
The function to call at the end of the constructor of each module.
|
||||
NOTE() The module may be passed to this function multiple times.
|
||||
"""
|
||||
|
||||
def half_fn(t: torch.Tensor):
|
||||
return t.half() if t.is_floating_point() else t
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
# avoid adapting a param to ShardedParam twice
|
||||
if hasattr(param, 'col_attr'):
|
||||
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
self.model_numel_tensor += param.numel()
|
||||
|
||||
target_device = self.target_device
|
||||
|
||||
# convert to fp16
|
||||
param.data = param.data.to(torch.half)
|
||||
# convert parameters to half
|
||||
param_half = half_fn(param)
|
||||
param.data = param_half
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(torch.half)
|
||||
grad_half = half_fn(param.grad)
|
||||
param.grad.data = grad_half
|
||||
|
||||
# move torch parameters to the target device
|
||||
target_device = self.target_device
|
||||
param.data = param.data.to(target_device)
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(target_device)
|
||||
|
||||
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
|
||||
|
||||
self.initialized_param_list.append(param)
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.initialized_param_list.append(param)
|
||||
|
||||
# We must cast buffers
|
||||
# If we use BN, buffers may be on CPU and Float
|
||||
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
for buffer in module.buffers(recurse=False):
|
||||
buffer.data = buffer.data.to(device=torch.cuda.current_device())
|
||||
buffer.data = cast_tensor_to_fp16(buffer.data)
|
||||
|
||||
|
||||
class ZeroContextMgr(metaclass=SingletonMeta):
|
||||
current_context: Optional[ZeroInitContext] = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hijack_context_config(self, **kwargs):
|
||||
if self.current_context is None:
|
||||
yield
|
||||
else:
|
||||
old_config = self.current_context.config
|
||||
self.current_context.config = ZeroContextConfig(**kwargs)
|
||||
yield
|
||||
self.current_context.config = old_config
|
||||
|
||||
|
||||
def no_shard_zero_context():
|
||||
return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
|
||||
|
||||
|
||||
def no_shard_zero_decrator(init_func):
|
||||
|
||||
def _no_shard(*args, **kwargs):
|
||||
with no_shard_zero_context():
|
||||
init_func(*args, **kwargs)
|
||||
|
||||
return _no_shard
|
||||
|
Reference in New Issue
Block a user