mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-21 23:02:07 +00:00)
[zero] update zero context init with the updated test utils (#327)
@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
 
 
 class ShardedParamV2(object):
-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self,
+                 param: torch.nn.Parameter,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 rm_torch_payload=False) -> None:
         self._data_sharded_tensor = ShardedTensor(param.data, process_group)
         if param.requires_grad and param.grad is not None:
             self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
             self._grad_sharded_tensor = None
 
-        # make sure the shared param is the only owner of payload
-        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+        # param.data may still be used to initialize other parts of the model,
+        # for example: File "resnet.py", line 190, in __init__
+        #     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        # so we cannot empty .data at this time.
         self.param = param
+        if rm_torch_payload:
+            self.remove_torch_payload()
+
+    def remove_torch_payload(self):
+        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
 
     @property
     def data(self):
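For readers skimming the change: the commit splits payload removal out of __init__ so that the dense parameter data survives until model initialization has finished. Below is a minimal usage sketch of that two-pass pattern. The import path, the toy model, and the wrapping loop are illustrative assumptions, not from the commit; only ShardedParamV2, its constructor arguments, and remove_torch_payload() are taken from the diff above.

    import torch
    import torch.nn as nn
    # Assumed import path; adjust to wherever ShardedParamV2 lives in the package.
    from colossalai.zero.sharded_param import ShardedParamV2

    model = nn.Linear(4, 4)

    # Pass 1: wrap the parameters but keep the dense torch payload alive
    # (rm_torch_payload=False, the default), so later init code can still
    # read and write param.data. Assumes a usable default process group.
    wrapped = [ShardedParamV2(p, rm_torch_payload=False) for p in model.parameters()]

    # The kind of late initialization the diff's comment warns about:
    nn.init.kaiming_normal_(model.weight, mode='fan_out', nonlinearity='relu')

    # Pass 2: initialization is done, so it is now safe to drop the payload;
    # each param.data is replaced by an empty placeholder tensor.
    for w in wrapped:
        w.remove_torch_payload()

The default of rm_torch_payload=False preserves the old constructor's call sites while moving the destructive step to an explicit, later call.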