mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[zero] yet an improved sharded param (#311)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParam
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
|
||||
|
||||
__all__ = ['ShardedParam', 'ShardedTensor']
|
||||
__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
|
||||
|
@@ -6,6 +6,40 @@ import torch.distributed as dist
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Union, Tuple, Optional
|
||||
import numpy
|
||||
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
|
||||
def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
self._data_sharded_tensor = ShardedTensor(param.data, process_group)
|
||||
if param.requires_grad and param.grad is not None:
|
||||
self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
|
||||
param.grad = None
|
||||
else:
|
||||
self._grad_sharded_tensor = None
|
||||
|
||||
# make sure the shared param is the only owner of payload
|
||||
param.data = torch.empty([], dtype=param.dtype, device=param.device)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data_sharded_tensor.payload
|
||||
|
||||
@data.setter
|
||||
def data(self, t: torch.Tensor):
|
||||
self._data_sharded_tensor.payload = t
|
||||
|
||||
@property
|
||||
def grad(self):
|
||||
return self._grad_sharded_tensor.payload
|
||||
|
||||
@grad.setter
|
||||
def grad(self, t: torch.Tensor):
|
||||
self._grad_sharded_tensor.payload = t
|
||||
|
||||
|
||||
class ShardedParam(object):
|
||||
|
Reference in New Issue
Block a user