Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] update zero context init with the updated test utils (#327)
@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.
     """

-    def __init__(
-        self,
-        convert_fp16: bool,
-        convert_cuda: bool,
-        shard_strategy: BaseShardStrategy,
-        shard_param: bool = False,
-        shard_grad: bool = False,
-    ):
+    def __init__(self,
+                 convert_fp16: bool,
+                 convert_cuda: bool,
+                 shard_strategy: BaseShardStrategy,
+                 shard_param: bool = False,
+                 shard_grad: bool = False,
+                 rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.convert_cuda = convert_cuda
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        self.initialized_param_list = []

     def _post_context_exec(self):
         """The callback function when the context exits.
         """
-        pass
+        if not self.rm_torch_payload_on_the_fly:
+            for param in self.initialized_param_list:
+                assert hasattr(param, 'ca_attr')
+                param.ca_attr.remove_torch_payload()
+
+        del self.initialized_param_list

     def _post_init_method(self, module):
         r"""The function to call at the end of the constructor of each nn.Module.
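For orientation, here is a usage sketch of the updated constructor and the new exit-time cleanup. It is illustrative only and not part of the commit: `Net` is a placeholder nn.Module and `shard_strategy` stands for any concrete BaseShardStrategy instance.

    # Hypothetical usage sketch; Net and shard_strategy are placeholders.
    with ZeroInitContext(convert_fp16=True,
                         convert_cuda=True,
                         shard_strategy=shard_strategy,
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
        model = Net()
    # With rm_torch_payload_on_the_fly=False, every collected parameter keeps its
    # torch payload until the context exits; _post_context_exec() then calls
    # param.ca_attr.remove_torch_payload() for each entry of initialized_param_list.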
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(torch.half).to(target_device)

-            param.ca_attr = ShardedParamV2(param)
+            param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+
+            self.initialized_param_list.append(param)

             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
             if param.ca_attr.grad and self.shard_grad:
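Continuing the sketch above, the effect of the per-parameter hook can be checked after the context has exited. Again, this is an illustration, not code from the commit.

    import torch

    for param in model.parameters():
        assert hasattr(param, 'ca_attr')            # attached by _post_init_method
        # Once the torch payload has been removed, param.data is a 0-dim placeholder;
        # the real payload lives in param.ca_attr._data_sharded_tensor.
        assert param.data.shape == torch.Size([])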
@@ -7,6 +7,11 @@ from typing import List, Optional
 class BaseShardStrategy(ABC):

     def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
+
+        Args:
+            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
+        """
         self.process_group = process_group
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):

     @abstractmethod
     def shard(self, tensor_list: List[ShardedTensor]):
-        r"""
-        sharded the memory of tensor on multiple processes.
-        """
         pass

     @abstractmethod
     def gather(self, tensor_list: List[ShardedTensor]):
-        r"""
-        duplicate tensor payload on each processes.
-        """
         pass
@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self,
+                 param: torch.nn.Parameter,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 rm_torch_payload=False) -> None:
         self._data_sharded_tensor = ShardedTensor(param.data, process_group)
         if param.requires_grad and param.grad is not None:
             self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
             self._grad_sharded_tensor = None

-        # make sure the shared param is the only owner of payload
-        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+        # The param.data maybe used to init the other part of the model.
+        # For example: File "resnet.py", line 190, in __init__
+        #     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        # So we can not empty the .data at this time
+        self.param = param
+        if rm_torch_payload:
+            self.remove_torch_payload()
+
+    def remove_torch_payload(self):
+        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
     def data(self):
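Finally, a small sketch of the two payload-removal modes introduced for ShardedParamV2. This is illustrative only; it assumes torch.distributed has already been initialized (for example with a single-process group), since the sharded tensor machinery is distributed-aware.

    import torch
    import torch.nn as nn

    p = nn.Parameter(torch.randn(4, 4))
    sp = ShardedParamV2(p, rm_torch_payload=True)     # payload dropped inside __init__
    assert p.data.shape == torch.Size([])             # replaced by an empty 0-dim tensor

    q = nn.Parameter(torch.randn(4, 4))
    sq = ShardedParamV2(q, rm_torch_payload=False)    # payload kept for now
    nn.init.kaiming_normal_(q)                        # .data still usable, e.g. for weight init
    sq.remove_torch_payload()                         # drop it explicitly later
    assert q.data.shape == torch.Size([])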