fix sharded param hook and unit test

ver217
2022-03-04 13:40:48 +08:00
committed by Frank Lee
parent 001ca624dd
commit 36f9a74ab2
6 changed files with 49 additions and 66 deletions


@@ -1,10 +1,11 @@
+from typing import Optional, Tuple, Union
+
+import numpy
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
-from typing import Union, Tuple, Optional
-import numpy
 
 
 class ShardedParam(object):
@@ -28,6 +29,7 @@ class ShardedParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self.is_sharded = False
+        self.device = device
 
         # Hijack the data payload of param
         if isinstance(other, torch.nn.Parameter):
@@ -50,17 +52,19 @@ class ShardedParam(object):
         self._payload_numel = None
 
-    def payload(self, target_device: torch.device):
+    def payload(self, target_device: Optional[torch.device] = None):
         r"""
         get the payload and move it to target device
         """
-        return self._param_payload.to(target_device)
+        if target_device is not None:
+            return self._param_payload.to(target_device)
+        return self._param_payload
 
     def set_payload(self, data: torch.Tensor):
         r"""
         set payload as data
         """
-        assert self._param_payload.numel() == data.numel()
+        assert self._param_payload.shape == data.shape
         self._param_payload.copy_(data)
 
     def shard(self):
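
Taken together, the hunks above make target_device optional in payload() (returning the payload in place when no device is given) and tighten set_payload() to check the full shape rather than only the element count. The snippet below is a minimal sketch, not part of this commit, of how the updated methods could be exercised; the import path, tensor shapes, and the helper name check_payload_roundtrip are assumptions, and it presumes a ColossalAI distributed context (and hence the default DATA process group) has already been initialised, e.g. via colossalai.launch.

import torch

from colossalai.zero.sharded_param import ShardedParam  # assumed import path


def check_payload_roundtrip():
    # Wrap an ordinary parameter; ShardedParam falls back to the default
    # DATA process group, so distributed init must have happened already.
    param = torch.nn.Parameter(torch.randn(2, 3))
    sharded_param = ShardedParam(param)

    # target_device is now optional: with no argument the payload is
    # returned on whatever device it currently lives on.
    payload = sharded_param.payload()
    assert payload.shape == (2, 3)

    # Passing a device still moves the payload, as before.
    cpu_payload = sharded_param.payload(torch.device('cpu'))
    assert cpu_payload.device.type == 'cpu'

    # set_payload() now asserts on shape, not just numel, so a transposed
    # (3, 2) tensor with the same number of elements would be rejected.
    sharded_param.set_payload(torch.zeros(2, 3))
    assert torch.equal(sharded_param.payload(), torch.zeros(2, 3))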