[zero] polish sharded param name (#484)

* [zero] polish sharded param name

* polish code

* polish

* polish code

* polish

* polish

* polish

Author: Jiarui Fang
Date: 2022-03-22 14:36:16 +08:00
Committed by: GitHub
Parent: 9caa8b6481
Commit: b334822163
12 changed files with 55 additions and 222 deletions
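
In short, the commit renames the handle for the sharded tensor on `col_attr`: the private `_data_sharded_tensor` field and the ambiguous `data` property become `sharded_data_tensor`, every call site in the diff below is updated accordingly, and the obsolete `ShardedParam` class is deleted. A minimal sketch of the renamed accessor, using simplified stand-in classes rather than the real colossalai types:

import torch


class ShardedTensorStub:
    """Simplified stand-in for colossalai's ShardedTensor: wraps a payload tensor."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self.payload = tensor
        self.is_sharded = False


class ShardedParamV2Stub:
    """Stand-in for ShardedParamV2 after this commit: the wrapped tensor is exposed
    as `sharded_data_tensor` instead of the old `data` property."""

    def __init__(self, param: torch.nn.Parameter) -> None:
        self._sharded_data_tensor = ShardedTensorStub(param.data)

    @property
    def sharded_data_tensor(self) -> ShardedTensorStub:
        return self._sharded_data_tensor

    @property
    def param_is_sharded(self) -> bool:
        return self._sharded_data_tensor.is_sharded


p = torch.nn.Parameter(torch.randn(4, 4))
col_attr = ShardedParamV2Stub(p)
# Call sites now read `col_attr.sharded_data_tensor` where they used to read `col_attr.data`.
print(col_attr.sharded_data_tensor.payload.shape, col_attr.param_is_sharded)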


@@ -160,8 +160,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.initialized_param_list.append(param)

             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
+                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
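
The hunk above is a pure rename; the underlying flow is unchanged: when `shard_param` is set, each freshly initialized parameter's tensor is sharded over the data-parallel group and its payload is registered with the memory tracer. A rough single-process sketch of that flow, with hypothetical stand-ins for the shard strategy and `GLOBAL_MODEL_DATA_TRACER`:

import torch


class _ShardedTensor:
    """Minimal stand-in for colossalai's ShardedTensor: a payload plus a sharded flag."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self.payload, self.is_sharded = tensor, False


def shard_on_rank(t: _ShardedTensor, rank: int, world_size: int) -> None:
    """Keep only this rank's chunk of the flattened payload (roughly what shard_strategy.shard does per rank)."""
    if not t.is_sharded:
        t.payload = torch.chunk(t.payload.flatten(), world_size)[rank].clone()
        t.is_sharded = True


class _Tracer:
    """Hypothetical stand-in for GLOBAL_MODEL_DATA_TRACER: accumulates bytes of registered payloads."""

    def __init__(self) -> None:
        self.model_data_bytes = 0

    def add_tensor(self, payload: torch.Tensor) -> None:
        self.model_data_bytes += payload.numel() * payload.element_size()


tracer = _Tracer()
model = torch.nn.Linear(8, 8)
for param in model.parameters():
    sharded = _ShardedTensor(param.data)           # what param.col_attr.sharded_data_tensor wraps
    shard_on_rank(sharded, rank=0, world_size=4)   # shard_strategy.shard([...], dp_process_group)
    tracer.add_tensor(sharded.payload)             # GLOBAL_MODEL_DATA_TRACER.add_tensor(...payload)
print(f"tracked model data: {tracer.model_data_bytes} bytes")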


@@ -165,7 +165,7 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data], self.process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -249,13 +249,15 @@ class ShardedModelV2(nn.Module):
         param.col_attr.fp16_grad = reduced_grad.data

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                   self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.data.payload
+            p.data = p.col_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                  self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
         return gathered_state_dict
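
`state_dict` keeps its gather → swap → dump → re-shard → restore pattern: gather full payloads, temporarily point each `p.data` at the gathered payload so the plain `module.state_dict()` sees full tensors, then shard again and put the original `p.data` back. A self-contained sketch of the same pattern, with a simplified stand-in that fakes gather/shard on a single process:

import torch


class _ShardedTensor:
    """Stand-in holding a full copy so gather/shard can be simulated on one process."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self._full = tensor.clone()
        self.payload = tensor
        self.is_sharded = False

    def shard(self, rank: int, world_size: int) -> None:
        self.payload = torch.chunk(self._full.flatten(), world_size)[rank].clone()
        self.is_sharded = True

    def gather(self) -> None:
        self.payload = self._full  # the real code would all_gather the shards here
        self.is_sharded = False


def sharded_state_dict(module: torch.nn.Module, col_attrs: dict) -> dict:
    """Mimics ShardedModelV2.state_dict: gather, swap p.data, dump, re-shard, restore."""
    for t in col_attrs.values():
        t.gather()
    prev = {}
    for p in module.parameters():
        prev[p] = p.data
        p.data = col_attrs[p].payload          # p.data = p.col_attr.sharded_data_tensor.payload
    state = {k: v.clone() for k, v in module.state_dict().items()}
    for p in module.parameters():
        col_attrs[p].shard(rank=0, world_size=4)
        p.data = prev[p]                       # restore the original data pointer
    return state


model = torch.nn.Linear(4, 4)
attrs = {p: _ShardedTensor(p.data) for p in model.parameters()}
print({k: tuple(v.shape) for k, v in sharded_state_dict(model, attrs).items()})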


@@ -11,9 +11,9 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
"""
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
assert hasattr(zero_param, 'col_attr')
shard_flag = zero_param.col_attr.data.is_sharded
shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.col_attr.data])
param.data = copy.deepcopy(zero_param.col_attr.data.payload)
sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.col_attr.data])
sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
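
`col_model_deepcopy` follows the same round trip to copy a sharded model's weights into an ordinary model of the same architecture: gather if sharded, deepcopy the payload into the target parameter, then re-shard so the source model is left as it was. A brief sketch under the same single-process, stand-in assumptions:

import copy
import torch


class _ShardedTensor:
    """Stand-in: keeps the full tensor so gather/shard can be faked locally."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self._full, self.payload, self.is_sharded = tensor.clone(), tensor, False

    def gather(self):
        self.payload, self.is_sharded = self._full, False

    def shard(self, rank=0, world_size=4):
        self.payload = torch.chunk(self._full.flatten(), world_size)[rank].clone()
        self.is_sharded = True


def copy_payloads(src_attrs: dict, dst_model: torch.nn.Module) -> None:
    """Mimics col_model_deepcopy: gather if sharded, deepcopy payload into dst, re-shard."""
    for dst_param, sharded in zip(dst_model.parameters(), src_attrs.values()):
        was_sharded = sharded.is_sharded        # shard_flag = ...sharded_data_tensor.is_sharded
        if was_sharded:
            sharded.gather()
        dst_param.data = copy.deepcopy(sharded.payload)
        if was_sharded:
            sharded.shard()


src = torch.nn.Linear(4, 4)
dst = torch.nn.Linear(4, 4)
attrs = {p: _ShardedTensor(p.data) for p in src.parameters()}
copy_payloads(attrs, dst)
print(torch.equal(src.weight.data, dst.weight.data))  # True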


@@ -109,17 +109,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is not sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -149,24 +149,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.data` saves full fp16 param
+                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.data is fp16
+                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.data.copy_payload(p.data)
+                p.col_attr.sharded_data_tensor.copy_payload(p.data)
                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
-                p.data = p.col_attr.data.payload
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.col_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:
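
The `step()` epilogue bridges fp32 master shards and fp16 payloads: when the parameter is not sharded (ZeRO-2) the full fp16 tensor is first sharded so the updated fp32 shard can be copied into the matching slice via `copy_payload` (which casts on copy), then gathered back; when it is sharded (ZeRO-3) only the copy is needed. A simplified single-process sketch of that copy-back, with hypothetical helpers and assumed `WORLD_SIZE`/`RANK` constants standing in for the shard strategy and process group:

import torch

WORLD_SIZE, RANK = 4, 0


def shard_slice(full: torch.Tensor) -> torch.Tensor:
    """This rank's chunk of the flattened tensor (what shard_strategy.shard would keep)."""
    return torch.chunk(full.flatten(), WORLD_SIZE)[RANK]


def copy_master_to_payload(fp16_payload: torch.Tensor, fp32_master_shard: torch.Tensor,
                           is_param_sharded: bool) -> torch.Tensor:
    """Mimics the step() epilogue: write the updated fp32 master shard back into the fp16 payload."""
    if not is_param_sharded:
        # ZeRO-2: payload is the full fp16 param; shard it, overwrite this rank's slice, "gather" back.
        full = fp16_payload.flatten()
        local = shard_slice(full).clone()
        local.copy_(fp32_master_shard)                 # copy_payload(p.data): fp32 -> fp16 cast on copy_
        # A real run would all_gather the updated shards; locally we just patch the slice back in.
        numel = local.numel()
        full[RANK * numel:(RANK + 1) * numel] = local
        return full.view_as(fp16_payload)
    # ZeRO-3: payload already is this rank's fp16 shard, so a straight cast-and-copy suffices.
    return fp32_master_shard.to(torch.float16)


fp16_param = torch.randn(16, dtype=torch.float16)
fp32_master = shard_slice(fp16_param.float()).clone() * 1.01   # pretend the optimizer updated it
print(copy_master_to_payload(fp16_param, fp32_master, is_param_sharded=False).dtype)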


@@ -1,4 +1,4 @@
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

-__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
+__all__ = ['ShardedTensor', 'ShardedParamV2']


@@ -1,12 +1,7 @@
-from typing import Optional, Tuple, Union
-import numpy
 import torch
 import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
+from typing import Optional


 class ShardedParamV2(object):
@@ -15,7 +10,7 @@ class ShardedParamV2(object):
                  param: torch.nn.Parameter,
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
-        self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
@@ -37,105 +32,9 @@ class ShardedParamV2(object):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
-    def data(self):
-        return self._data_sharded_tensor
+    def sharded_data_tensor(self):
+        return self._sharded_data_tensor

     @property
     def param_is_sharded(self):
-        return self._data_sharded_tensor.is_sharded
-
-
-class ShardedParam(object):
-    r"""
-    A wrapper to torch.nn.Parameter. Shard a param
-    on memory space of different processes.
-    """
-
-    def __init__(self,
-                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 is_sharded: bool = False,
-                 device: Optional[torch.device] = None) -> None:
-        r"""
-        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
-        process_group: the process group storing the shared data.
-        is_sharded: is shared the param during __init__.
-        device: the device to place param data payload on
-        """
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
-        self.is_sharded = False
-        self.device = device
-
-        # Hijack the data payload of param
-        if isinstance(other, torch.nn.Parameter):
-            self._param_payload = other.data.to(device)
-            self._origin_shape = other.shape
-            self._origin_numel = other.numel()
-            if is_sharded:
-                self.shard()
-        elif isinstance(other, tuple):
-            self._origin_shape = other
-            self._origin_numel = numpy.prod(other)
-
-            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
-            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
-            self._param_payload = torch.empty(self._origin_shape, device=device)
-            if is_sharded:
-                self.shard()
-        else:
-            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
-
-        self._payload_numel = None
-
-    def payload(self, target_device: Optional[torch.device] = None):
-        r"""
-        get the payload and move it to target device
-        """
-        if target_device is not None:
-            return self._param_payload.to(target_device)
-        return self._param_payload
-
-    def set_payload(self, data: torch.Tensor):
-        r"""
-        set payload as data
-        """
-        assert self._param_payload.shape == data.shape
-        self._param_payload.copy_(data)
-
-    def shard(self):
-        r"""
-        Distributed the payload of param to all processes.
-        """
-        if self.is_sharded:
-            return
-        self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_sharded = True
-
-    def gather(self):
-        r"""
-        Collect the payload of param from different processes to process of local rank.
-        The payload has to be moved to cuda memory before communication.
-        """
-        if not self.is_sharded:
-            return
-
-        buffer_list = []
-        payload_numel = self._param_payload.numel()
-        for i in range(self.world_size):
-            if i == self.local_rank:
-                buffer_list.append(self._param_payload.cuda())
-            else:
-                buffer_list.append(torch.zeros(payload_numel).cuda())
-
-        torch.distributed.all_gather(buffer_list,
-                                     buffer_list[self.local_rank],
-                                     group=self.process_group,
-                                     async_op=False)
-        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_sharded = False
-
-    @property
-    def origin_dtype(self):
-        return self._origin_dtype
+        return self._sharded_data_tensor.is_sharded
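
The deleted `ShardedParam.gather` is still a useful reference for the reconstruction idiom the shard strategies rely on: all_gather the flat per-rank shards, concatenate, `narrow` to the original element count (shards may be padded), and `view` back to the original shape. A single-process sketch of that reconstruction, with plain tensors standing in for the buffers that `torch.distributed.all_gather` would fill:

import torch


def gather_flat_shards(shards, origin_numel: int, origin_shape) -> torch.Tensor:
    """Rebuild a full tensor from per-rank flat shards (the narrow+view step of ShardedParam.gather).

    In the real code the `shards` list is filled by torch.distributed.all_gather;
    here we pass the shards in directly to keep the sketch single-process.
    """
    flat = torch.cat(shards)
    return torch.narrow(flat, 0, 0, origin_numel).view(origin_shape)


world_size = 4
original = torch.randn(3, 5)                      # 15 elements: not divisible by world_size
flat = original.flatten()
pad = (-flat.numel()) % world_size                # pad so every rank holds an equal-sized shard
padded = torch.cat([flat, flat.new_zeros(pad)])
shards = list(torch.chunk(padded, world_size))    # what each rank would contribute via all_gather

rebuilt = gather_flat_shards(shards, original.numel(), original.shape)
print(torch.equal(rebuilt, original))  # True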