mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-21 06:42:05 +00:00
[zero] sharded tensor (#305)
* init shard param from shape tuple * add more unitest for shard param * add set_payload method for ShardedParam * [zero] add shareded tensor class * polish code
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .sharded_param import ShardedParam
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParam
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
__all__ = ['ShardedParam']
|
||||
__all__ = ['ShardedParam', 'ShardedTensor']
|
||||
|
@@ -56,6 +56,13 @@ class ShardedParam(object):
|
||||
"""
|
||||
return self._param_payload.to(target_device)
|
||||
|
||||
def set_payload(self, data: torch.Tensor):
|
||||
r"""
|
||||
set payload as data
|
||||
"""
|
||||
assert self._param_payload.numel() == data.numel()
|
||||
self._param_payload.copy_(data)
|
||||
|
||||
def shard(self):
|
||||
r"""
|
||||
Distributed the payload of param to all processes.
|
||||
|
67
colossalai/zero/sharded_param/sharded_tensor.py
Normal file
67
colossalai/zero/sharded_param/sharded_tensor.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ShardedTensor(object):
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
r"""
|
||||
A tensor sharded in multiple processes.
|
||||
"""
|
||||
self._payload = tensor
|
||||
self.process_group = process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self._is_sharded = False
|
||||
self._payload = tensor
|
||||
|
||||
self._origin_shape = tensor.shape
|
||||
self._origin_numel = tensor.numel()
|
||||
self._origin_dtype = tensor.dtype
|
||||
|
||||
@property
|
||||
def is_sharded(self):
|
||||
return self._is_sharded
|
||||
|
||||
@property
|
||||
def payload(self):
|
||||
return self._payload
|
||||
|
||||
@payload.setter
|
||||
def payload(self, tensor):
|
||||
self._payload.copy_(tensor)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._origin_dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def shard(self):
|
||||
if self._is_sharded:
|
||||
return
|
||||
self._payload, _ = get_shard(self._payload, self.local_rank, self.world_size)
|
||||
self._is_sharded = True
|
||||
|
||||
def gather(self):
|
||||
if not self._is_sharded:
|
||||
return
|
||||
|
||||
buffer_list = []
|
||||
payload_numel = self._payload.numel()
|
||||
for i in range(self.world_size):
|
||||
if i == self.local_rank:
|
||||
buffer_list.append(self._payload.cuda())
|
||||
else:
|
||||
buffer_list.append(torch.zeros(payload_numel).cuda())
|
||||
|
||||
torch.distributed.all_gather(buffer_list,
|
||||
buffer_list[self.local_rank],
|
||||
group=self.process_group,
|
||||
async_op=False)
|
||||
self._payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
|
||||
self._is_sharded = False
|
Reference in New Issue
Block a user