[zero] yet another improvement to the sharded param (#311)

Author:    Jiarui Fang
Date:      2022-03-04 15:49:23 +08:00
Committer: Frank Lee
Parent:    c9e7d9582d
Commit:    90d3aef62c

3 changed files with 82 additions and 21 deletions

colossalai/zero/sharded_param/__init__.py

@@ -1,4 +1,4 @@
-from colossalai.zero.sharded_param.sharded_param import ShardedParam
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2

-__all__ = ['ShardedParam', 'ShardedTensor']
+__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
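
With the re-export in place, the new class becomes reachable from the package root. A one-line sketch of the intended import (a usage note, not part of the commit):

from colossalai.zero.sharded_param import ShardedParamV2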

colossalai/zero/sharded_param/sharded_param.py

@@ -6,6 +6,40 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.sharded_param import ShardedTensor
+from typing import Union, Tuple, Optional
+
+import numpy
+
+
+class ShardedParamV2(object):
+
+    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        self._data_sharded_tensor = ShardedTensor(param.data, process_group)
+        if param.requires_grad and param.grad is not None:
+            self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
+            param.grad = None
+        else:
+            self._grad_sharded_tensor = None
+
+        # make sure the sharded param is the only owner of the payload
+        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+
+    @property
+    def data(self):
+        return self._data_sharded_tensor.payload
+
+    @data.setter
+    def data(self, t: torch.Tensor):
+        self._data_sharded_tensor.payload = t
+
+    @property
+    def grad(self):
+        return self._grad_sharded_tensor.payload
+
+    @grad.setter
+    def grad(self, t: torch.Tensor):
+        self._grad_sharded_tensor.payload = t
+
+
 class ShardedParam(object):
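
For orientation, a minimal, hypothetical usage sketch of ShardedParamV2 follows. It assumes the ColossalAI distributed context that ShardedTensor relies on has already been initialized (e.g. via colossalai.launch); the toy parameter and shapes are illustrative only and not part of the commit:

import torch
import torch.nn as nn

from colossalai.zero.sharded_param import ShardedParamV2

# Toy parameter standing in for a real model weight (illustrative only).
param = nn.Parameter(torch.randn(4, 4))
param.grad = torch.zeros(4, 4)

# The wrapper moves param.data (and param.grad, if present) into
# ShardedTensor payloads, then leaves an empty 0-dim placeholder behind
# so the wrapper is the sole owner of the original storage.
sp = ShardedParamV2(param, process_group=None)

print(param.data.shape)  # torch.Size([]) -- placeholder, payload moved out
print(param.grad)        # None -- grad payload moved into the wrapper
print(sp.data.shape)     # torch.Size([4, 4]) via the `data` property
print(sp.grad.shape)     # torch.Size([4, 4]) via the `grad` property

The point of the placeholder assignment in __init__ is the comment in the diff: once param.data is swapped for an empty tensor, the ShardedTensor payload is the only live copy, so ZeRO can manage (and eventually shard or offload) that storage without a duplicate lingering on the parameter.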