fixed typo in ShardParam (#294)

Frank Lee 2022-03-02 17:26:23 +08:00, committed by GitHub
parent a463980aab
commit 4fbb8db586

@@ -1,25 +1,28 @@
 from enum import Enum
-from optparse import Option
 import torch
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 import torch.distributed as dist
 
+
 class TensorType(Enum):
     GRAD = 1
     DATA = 2
 
+
 class ShardParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
     on different processes.
     """
-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 tensor_type: TensorType = TensorType.DATA,
-                 process_group = None,
-                 ) -> None:
+
+    def __init__(
+        self,
+        param: torch.nn.Parameter,
+        tensor_type: TensorType = TensorType.DATA,
+        process_group=None,
+    ) -> None:
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -27,27 +30,27 @@ class ShardParam(object):
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
-        self.is_shared = False
+        self.is_sharded = False
 
-    def payload(self, target_device : torch.device):
+    def payload(self, target_device: torch.device):
         return self._param_payload.to(target_device)
 
     def shard(self):
         r"""
         Distributed the payload of param to all processes.
         """
-        if self.is_shared:
+        if self.is_sharded:
             return
         self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_shared = True
+        self.is_sharded = True
 
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
         """
-        if not self.is_shared:
+        if not self.is_sharded:
             return
         buffer_list = []
         payload_numel = self._param_payload.numel()
         for i in range(self.world_size):
@@ -55,9 +58,10 @@ class ShardParam(object):
             buffer_list.append(self._param_payload.cuda())
         else:
             buffer_list.append(torch.zeros(payload_numel).cuda())
-        torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
-        print(buffer_list)
-        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_shared = False
+        torch.distributed.all_gather(buffer_list,
+                                     buffer_list[self.local_rank],
+                                     group=self.process_group,
+                                     async_op=False)
+        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
+        self.is_sharded = False
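
Note: the gather() path above rebuilds the full parameter by all-gathering each rank's flattened shard, trimming the padding with torch.narrow, and restoring the original shape. Below is a minimal standalone sketch of that round trip using plain torch.distributed; it runs on CPU with the gloo backend, and its padding scheme is only an assumption about what get_shard is expected to do, not colossalai's actual implementation.

# Standalone sketch of the shard/gather round trip that ShardParam performs.
# Run with: torchrun --nproc_per_node=2 shard_gather_sketch.py
# The padding scheme below (pad the flattened tensor so it splits evenly
# across ranks) is an assumption, not taken from the colossalai source.
import torch
import torch.distributed as dist


def shard(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Flatten, pad to a multiple of world_size, and keep only this rank's chunk."""
    flat = tensor.reshape(-1)
    padded_numel = (flat.numel() + world_size - 1) // world_size * world_size
    padded = torch.zeros(padded_numel, dtype=flat.dtype)
    padded[:flat.numel()] = flat
    return padded.chunk(world_size)[rank].clone()


def gather(shard_payload: torch.Tensor, origin_shape: torch.Size,
           origin_numel: int, world_size: int) -> torch.Tensor:
    """All-gather every rank's shard, drop the padding, restore the shape."""
    buffer_list = [torch.zeros_like(shard_payload) for _ in range(world_size)]
    dist.all_gather(buffer_list, shard_payload)
    return torch.narrow(torch.cat(buffer_list), 0, 0, origin_numel).view(origin_shape)


if __name__ == "__main__":
    # gloo keeps the sketch CPU-only; ShardParam itself moves buffers to CUDA.
    dist.init_process_group(backend="gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    param = torch.arange(10, dtype=torch.float32).view(2, 5)
    piece = shard(param, rank, world_size)
    restored = gather(piece, param.shape, param.numel(), world_size)

    assert torch.equal(restored, param), "round trip should reproduce the original tensor"
    if rank == 0:
        print("gathered shape:", tuple(restored.shape))

ShardParam wraps the same pattern in a class: it tracks _origin_shape and _origin_numel itself, keeps the buffers on CUDA, and defaults its process group to gpc.get_group(ParallelMode.DATA).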