mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-21 06:42:05 +00:00
fix sharded param hook and unit test
@@ -1,10 +1,11 @@
+from typing import Optional, Tuple, Union
+
+import numpy
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
-from typing import Union, Tuple, Optional
-import numpy
 
 
 class ShardedParam(object):
@@ -28,6 +29,7 @@ class ShardedParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self.is_sharded = False
+        self.device = device
 
         # Hijack the data payload of param
         if isinstance(other, torch.nn.Parameter):
@@ -50,17 +52,19 @@ class ShardedParam(object):
 
         self._payload_numel = None
 
-    def payload(self, target_device: torch.device):
+    def payload(self, target_device: Optional[torch.device] = None):
         r"""
         get the payload and move it to target device
         """
-        return self._param_payload.to(target_device)
+        if target_device is not None:
+            return self._param_payload.to(target_device)
+        return self._param_payload
 
     def set_payload(self, data: torch.Tensor):
         r"""
         set payload as data
         """
-        assert self._param_payload.numel() == data.numel()
+        assert self._param_payload.shape == data.shape
         self._param_payload.copy_(data)
 
     def shard(self):
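
Below is a minimal usage sketch, not part of the commit, of the behaviour this diff introduces: payload() called with no argument now returns the stored tensor in place instead of requiring a target device (presumably what lets the sharded-param hook read it without forcing a copy), and set_payload() now checks the full shape rather than just the element count. The import path, the explicit process_group and device arguments, and the single-process gloo group are assumptions made so the sketch is self-contained and runnable on one machine.

import torch
import torch.distributed as dist

from colossalai.zero.sharded_param import ShardedParam  # import path assumed

# A single-process group stands in for a real data-parallel launch;
# ShardedParam is assumed to accept an explicit process_group and device,
# matching the attributes set in __init__ above.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0,
                        world_size=1)

param = torch.nn.Parameter(torch.randn(4, 4))
sharded = ShardedParam(param,
                       process_group=dist.group.WORLD,
                       device=torch.device("cpu"))

# New behaviour: no target device -> the payload tensor is returned as-is.
local_view = sharded.payload()

# An explicit device still goes through .to(target_device), as before.
cpu_copy = sharded.payload(torch.device("cpu"))

# set_payload() now asserts on shape, so a same-numel but differently
# shaped tensor (e.g. torch.zeros(16)) would be rejected.
sharded.set_payload(torch.zeros(4, 4))

dist.destroy_process_group()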