mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 03:31:56 +00:00
[zero] polish sharded param name (#484)
* [zero] polish sharded param name * polish code * polish * polish code * polish * polsih * polish
This commit is contained in:
@@ -160,8 +160,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
self.initialized_param_list.append(param)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
|
||||
# if param.col_attr.grad and self.shard_grad:
|
||||
# self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
|
||||
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
|
@@ -165,7 +165,7 @@ class ShardedModelV2(nn.Module):
|
||||
if self.shard_param:
|
||||
for p in self.module.parameters():
|
||||
if not p.col_attr.param_is_sharded:
|
||||
self.shard_strategy.shard([p.col_attr.data], self.process_group)
|
||||
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
|
||||
for p in self.module.parameters():
|
||||
p.col_attr.bwd_count = 0
|
||||
if not p.requires_grad:
|
||||
@@ -249,13 +249,15 @@ class ShardedModelV2(nn.Module):
|
||||
param.col_attr.fp16_grad = reduced_grad.data
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
|
||||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||
self.process_group)
|
||||
prev_params = {}
|
||||
for p in self.module.parameters():
|
||||
prev_params[p] = p.data
|
||||
p.data = p.col_attr.data.payload
|
||||
p.data = p.col_attr.sharded_data_tensor.payload
|
||||
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
||||
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
|
||||
self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||
self.process_group)
|
||||
for p in self.module.parameters():
|
||||
p.data = prev_params[p]
|
||||
return gathered_state_dict
|
||||
|
@@ -11,9 +11,9 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
|
||||
"""
|
||||
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
|
||||
assert hasattr(zero_param, 'col_attr')
|
||||
shard_flag = zero_param.col_attr.data.is_sharded
|
||||
shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.gather([zero_param.col_attr.data])
|
||||
param.data = copy.deepcopy(zero_param.col_attr.data.payload)
|
||||
sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
|
||||
param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.shard([zero_param.col_attr.data])
|
||||
sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
|
||||
|
@@ -109,17 +109,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
|
||||
is_param_sharded = p.col_attr.data.is_sharded
|
||||
is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
|
||||
if not is_param_sharded:
|
||||
# TODO (ver217): we may not use shard / gather here
|
||||
# Param is no sharded, which means we use ZeRO-2 here
|
||||
# As we only store param shard, we shard it here
|
||||
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
|
||||
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
|
||||
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
|
||||
if not is_param_sharded:
|
||||
# In this branch, there's no need to shard param
|
||||
# So we gather here
|
||||
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
|
||||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
# unscale grads if scaled
|
||||
@@ -149,24 +149,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
# a chunk.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
is_param_sharded = p.col_attr.data.is_sharded
|
||||
is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
|
||||
if not is_param_sharded:
|
||||
# We use ZeRO-2 here
|
||||
# The `p.col_attr.data` saves full fp16 param
|
||||
# The `p.col_attr.sharded_data_tensor` saves full fp16 param
|
||||
# But we only have updated fp32 param shard here
|
||||
# So we first shard full fp16 param and copy fp32 param shard to it
|
||||
# Then we will gather them
|
||||
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
|
||||
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
# We have to use `copy_payload` instead of `reset_payload`
|
||||
# Since p.data is fp32 and p.col_attr.data is fp16
|
||||
# Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.col_attr.data.copy_payload(p.data)
|
||||
p.col_attr.sharded_data_tensor.copy_payload(p.data)
|
||||
|
||||
if not is_param_sharded:
|
||||
# We gather full fp16 param here
|
||||
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
|
||||
p.data = p.col_attr.data.payload
|
||||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
p.data = p.col_attr.sharded_data_tensor.payload
|
||||
return ret
|
||||
|
||||
def backward(self, loss: Tensor) -> None:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
|
||||
__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
|
||||
__all__ = ['ShardedTensor', 'ShardedParamV2']
|
||||
|
@@ -1,12 +1,7 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
@@ -15,7 +10,7 @@ class ShardedParamV2(object):
|
||||
param: torch.nn.Parameter,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
rm_torch_payload=False) -> None:
|
||||
self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
|
||||
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
|
||||
self.fp16_grad: Optional[torch.Tensor] = None
|
||||
self.fp32_grad: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -37,105 +32,9 @@ class ShardedParamV2(object):
|
||||
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data_sharded_tensor
|
||||
def sharded_data_tensor(self):
|
||||
return self._sharded_data_tensor
|
||||
|
||||
@property
|
||||
def param_is_sharded(self):
|
||||
return self._data_sharded_tensor.is_sharded
|
||||
|
||||
|
||||
class ShardedParam(object):
|
||||
r"""
|
||||
A wrapper to torch.nn.Parameter. Shard a param
|
||||
on memory space of different processes.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
other: Union[torch.nn.Parameter, Tuple[int, ...]],
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
is_sharded: bool = False,
|
||||
device: Optional[torch.device] = None) -> None:
|
||||
r"""
|
||||
other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
|
||||
process_group: the process group storing the shared data.
|
||||
is_sharded: is shared the param during __init__.
|
||||
device: the device to place param data payload on
|
||||
"""
|
||||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.is_sharded = False
|
||||
self.device = device
|
||||
|
||||
# Hijack the data payload of param
|
||||
if isinstance(other, torch.nn.Parameter):
|
||||
self._param_payload = other.data.to(device)
|
||||
self._origin_shape = other.shape
|
||||
self._origin_numel = other.numel()
|
||||
if is_sharded:
|
||||
self.shard()
|
||||
elif isinstance(other, tuple):
|
||||
self._origin_shape = other
|
||||
self._origin_numel = numpy.prod(other)
|
||||
|
||||
# TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
|
||||
assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
|
||||
self._param_payload = torch.empty(self._origin_shape, device=device)
|
||||
if is_sharded:
|
||||
self.shard()
|
||||
else:
|
||||
raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
|
||||
|
||||
self._payload_numel = None
|
||||
|
||||
def payload(self, target_device: Optional[torch.device] = None):
|
||||
r"""
|
||||
get the payload and move it to target device
|
||||
"""
|
||||
if target_device is not None:
|
||||
return self._param_payload.to(target_device)
|
||||
return self._param_payload
|
||||
|
||||
def set_payload(self, data: torch.Tensor):
|
||||
r"""
|
||||
set payload as data
|
||||
"""
|
||||
assert self._param_payload.shape == data.shape
|
||||
self._param_payload.copy_(data)
|
||||
|
||||
def shard(self):
|
||||
r"""
|
||||
Distributed the payload of param to all processes.
|
||||
"""
|
||||
if self.is_sharded:
|
||||
return
|
||||
self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
|
||||
self.is_sharded = True
|
||||
|
||||
def gather(self):
|
||||
r"""
|
||||
Collect the payload of param from different processes to process of local rank.
|
||||
The payload has to be moved to cuda memory before communication.
|
||||
"""
|
||||
if not self.is_sharded:
|
||||
return
|
||||
|
||||
buffer_list = []
|
||||
payload_numel = self._param_payload.numel()
|
||||
for i in range(self.world_size):
|
||||
if i == self.local_rank:
|
||||
buffer_list.append(self._param_payload.cuda())
|
||||
else:
|
||||
buffer_list.append(torch.zeros(payload_numel).cuda())
|
||||
|
||||
torch.distributed.all_gather(buffer_list,
|
||||
buffer_list[self.local_rank],
|
||||
group=self.process_group,
|
||||
async_op=False)
|
||||
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
|
||||
self.is_sharded = False
|
||||
|
||||
@property
|
||||
def origin_dtype(self):
|
||||
return self._origin_dtype
|
||||
return self._sharded_data_tensor.is_sharded
|
||||
|
Reference in New Issue
Block a user