Polish sharded parameter (#297)

* init shard param from shape tuple

* add more unit tests for shard param

* add more unit tests for sharded param
Jiarui Fang
2022-03-03 12:42:57 +08:00
committed by Frank Lee
parent 7aef75ca42
commit e17e92c54d
5 changed files with 106 additions and 61 deletions


@@ -1,4 +1,3 @@
import functools
from typing import Any, Optional
@@ -7,11 +6,10 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
class ShardedModelV2(nn.Module):
-def __init__(self,
-module: nn.Module,
-process_group: Optional[ProcessGroup] = None,
-reduce_scatter_process_group: Optional[ProcessGroup] = None,
-reduce_scatter_bucket_size_mb: int = 25,
-reshard_after_forward: bool = True,
-mixed_precision: bool = False,
-fp32_reduce_scatter: bool = False,
-offload_config: Optional[dict] = None,
-gradient_predivide_factor: Optional[float] = 1.0,
-):
+def __init__(
+self,
+module: nn.Module,
+process_group: Optional[ProcessGroup] = None,
+reduce_scatter_process_group: Optional[ProcessGroup] = None,
+reduce_scatter_bucket_size_mb: int = 25,
+reshard_after_forward: bool = True,
+mixed_precision: bool = False,
+fp32_reduce_scatter: bool = False,
+offload_config: Optional[dict] = None,
+gradient_predivide_factor: Optional[float] = 1.0,
+):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
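
For orientation, here is a minimal sketch of how the reworked constructor above might be called. The toy model, the import path, and the chosen arguments are assumptions for illustration, not part of this commit.

```python
# Hypothetical usage of the ShardedModelV2 constructor shown in the hunk above.
# The toy model and the import path are assumptions; arguments mirror the signature.
import torch.nn as nn

from colossalai.zero.sharded_model import ShardedModelV2  # import path assumed

net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))

sharded_net = ShardedModelV2(
    net,
    process_group=None,                 # fall back to the default process group
    reduce_scatter_bucket_size_mb=25,   # bucket size for gradient reduce-scatter
    reshard_after_forward=True,         # free full params after each forward
    mixed_precision=False,
    gradient_predivide_factor=1.0,      # keep 1.0, per the precision note below
)
```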
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
# Shard the parameters at first
for _, param in self.module.named_parameters():
-param.ca_attr = ShardParam(param)
+param.ca_attr = ShardedParam(param)
param.ca_attr.shard()
param._sharded_grad = ShardedGradient(param, self, offload_config)
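
The `ShardedParam(param).shard()` call above is what splits each parameter across the data-parallel ranks. The helper below is only a rough sketch of that chunk-per-rank idea under a simple padding assumption; it is not the ShardedParam API.

```python
# Illustrative sketch of chunk-per-rank parameter sharding; names, padding and
# layout are assumptions, not the actual ShardedParam implementation.
import torch
import torch.nn.functional as F


def shard_flat_tensor(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's slice of the flattened tensor, padding the tail if needed."""
    flat = t.detach().flatten()
    shard_len = (flat.numel() + world_size - 1) // world_size
    flat = F.pad(flat, (0, shard_len * world_size - flat.numel()))
    return flat[rank * shard_len:(rank + 1) * shard_len].clone()


# e.g. a (64, 32) weight on 4 ranks -> each rank keeps a 512-element shard
local = shard_flat_tensor(torch.randn(64, 32), rank=0, world_size=4)
assert local.numel() == 512
```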
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
-# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+# However, if you set gradient_predivide_factor to None,
+# we will set gradient_predivide_factor to a value >= 1.0 automatically
+self.gradient_predivide_factor: float = \
+gradient_predivide_factor if gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
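
The pre/post split above divides gradients by `gradient_predivide_factor` before the reduce-scatter and by `world_size / gradient_predivide_factor` afterwards, so the net effect is still an average over `world_size` while keeping intermediate fp16 values smaller. The heuristic below is an assumption about what `get_gradient_predivide_factor` does (a FairScale-style power-of-two split); only the `pre * post == world_size` invariant is implied by the code.

```python
# Assumed FairScale-style heuristic for the predivide factor; only the invariant
# pre * post == world_size is guaranteed by the surrounding code.
def guess_predivide_factor(world_size: int) -> float:
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)


world_size = 8
pre = guess_predivide_factor(world_size)   # 4.0
post = world_size / pre                    # 2.0, so pre * post == world_size
# grad.div_(pre)   before the reduce-scatter limits fp16 overflow
# grad.div_(post)  after the reduce-scatter completes the averaging
```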
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
"""
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+full gradient for the local batch. The reduce-scatter op will save
+a single shard of the summed gradient across all
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
orig_grad_data = new_grad.data
if self.world_size > 1:
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-self.reducer.reduce_scatter_async(
-grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+self.reducer.reduce_scatter_async(grad_chunks,
+group=self.reduce_scatter_process_group,
+callback_fn=functools.partial(self._reduce_scatter_callback, param))
else:
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
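
The bucketed `reduce_scatter_async` call above is the asynchronous path. A synchronous stand-in using the plain `torch.distributed.reduce_scatter` API shows the same end state described in the `_grad_post_backward_hook` docstring: each rank keeps one rank-aligned shard of the summed gradient. The padding mirrors what `chunk_and_pad` is named for; the exact helper behaviour is an assumption.

```python
# Synchronous illustration of the reduce-scatter step; requires an initialized
# process group (e.g. launched via torchrun). Padding behaviour is assumed to
# mirror chunk_and_pad; this is not the bucketed ReduceScatterBucketer path.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def reduce_scatter_grad(full_grad: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    flat = full_grad.flatten()
    shard_len = (flat.numel() + world_size - 1) // world_size
    flat = F.pad(flat, (0, shard_len * world_size - flat.numel()))
    chunks = list(flat.chunk(world_size))
    out = torch.empty_like(chunks[rank])
    dist.reduce_scatter(out, chunks)   # each rank gets the sum of its own chunk
    return out                         # rank-aligned shard of the summed gradient
```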