Mirror of https://github.com/hpcaitech/ColossalAI.git
Polish sharded parameter (#297)
* init sharded param from a shape tuple
* add more unit tests for sharded param
* add more unit tests to sharded param
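The first point refers to building a sharded parameter straight from a shape tuple instead of from an existing nn.Parameter. Below is a minimal single-process sketch of that idea in plain PyTorch; SimpleShardedParam and its methods are illustrative stand-ins, not the ShardedParam API touched by this commit.

# Conceptual sketch only; SimpleShardedParam is a hypothetical stand-in and
# does not mirror colossalai.zero.sharded_param.ShardedParam exactly.
from typing import Tuple, Union

import torch
import torch.nn as nn


class SimpleShardedParam:
    """Holds a flat payload built either from a Parameter or from a shape tuple."""

    def __init__(self, src: Union[nn.Parameter, Tuple[int, ...]], device: str = 'cpu'):
        if isinstance(src, nn.Parameter):
            self.origin_shape = src.shape
            self.payload = src.data.detach().clone().flatten()
        else:
            # New path: a bare shape tuple allocates uninitialised storage.
            self.origin_shape = torch.Size(src)
            self.payload = torch.empty(self.origin_shape, device=device).flatten()

    def to_tensor(self) -> torch.Tensor:
        return self.payload.view(self.origin_shape)


# Both construction paths produce the same logical shape.
p = SimpleShardedParam(nn.Parameter(torch.randn(4, 8)))
q = SimpleShardedParam((4, 8))
assert p.to_tensor().shape == q.to_tensor().shape == torch.Size([4, 8])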
@@ -1,4 +1,3 @@
 import functools
 from typing import Any, Optional
@@ -7,11 +6,10 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
 
 
 class ShardedModelV2(nn.Module):
-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 ):
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        reshard_after_forward: bool = True,
+        mixed_precision: bool = False,
+        fp32_reduce_scatter: bool = False,
+        offload_config: Optional[dict] = None,
+        gradient_predivide_factor: Optional[float] = 1.0,
+    ):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
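For orientation, the constructor above is used roughly as follows. This is a hedged usage sketch: it assumes the class is importable from colossalai.zero.sharded_model.sharded_model_v2, that a ColossalAI/torch.distributed environment has already been initialised on each rank, and the flag values are illustrative.

# Hedged usage sketch; requires an initialised distributed context
# (e.g. started through ColossalAI's launch utilities), so it will not run standalone.
import torch.nn as nn

from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2  # assumed import path

net = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))

sharded_net = ShardedModelV2(
    net,
    process_group=None,                  # None: fall back to the default process group
    reduce_scatter_bucket_size_mb=25,    # bucket small gradients before reduce-scatter
    reshard_after_forward=True,          # drop gathered full params once forward finishes
    mixed_precision=False,
    fp32_reduce_scatter=False,
    offload_config={'device': 'cpu'},    # matches the offload_config.get('device') check below
    gradient_predivide_factor=None,      # None: a factor >= 1.0 is chosen automatically (see below)
)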
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
 
         # Shard the parameters at first
         for _, param in self.module.named_parameters():
-            param.ca_attr = ShardParam(param)
+            param.ca_attr = ShardedParam(param)
            param.ca_attr.shard()
            param._sharded_grad = ShardedGradient(param, self, offload_config)
 
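The loop above simply hangs extra objects off each nn.Parameter and immediately shards its payload. A toy, single-process illustration of that wrapping pattern follows; ToyShardedParam is hypothetical and only mimics the attribute-attachment pattern, not ShardedParam's real sharding logic.

# Toy illustration of the wrapping loop; not the real ShardedParam/ShardedGradient.
import torch
import torch.nn as nn


class ToyShardedParam:

    def __init__(self, param: nn.Parameter):
        self.origin_shape = param.shape
        self.payload = param.data.flatten()

    def shard(self):
        # Keep only a dummy "first half" to mimic dropping most of the storage.
        half = max(1, self.payload.numel() // 2)
        self.payload = self.payload[:half].clone()


net = nn.Linear(4, 4)
for _, param in net.named_parameters():
    param.ca_attr = ToyShardedParam(param)   # arbitrary attributes can be attached to Parameters
    param.ca_attr.shard()

print(net.weight.ca_attr.payload.numel())    # 8 of the original 16 weight elements kept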
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None,
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = \
+            gradient_predivide_factor if gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
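The pre/post division above splits the usual 1/world_size gradient averaging into two smaller divisions, which keeps intermediate values in a safer range for low-precision reduction. A small sketch of the bookkeeping; the automatic factor mimics the power-of-two heuristic used in FairScale-style ZeRO code and is only an assumption about what get_gradient_predivide_factor returns.

# Assumption: the automatic factor is the largest power of two f such that
# f <= world_size / f, as in FairScale's FSDP; ColossalAI's helper may differ.
def auto_predivide_factor(world_size: int) -> float:
    factor = 1
    while world_size % factor == 0 and world_size // factor > factor:
        factor *= 2
    return float(factor)


world_size = 64
pre = auto_predivide_factor(world_size)    # 8.0
post = world_size / pre                    # 8.0
# The gradient is divided by `pre` before the reduce-scatter and by `post` afterwards,
# so overall it is divided by world_size, but never by more than 8 in one step.
assert pre * post == world_size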
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        full gradient for the local batch. The reduce-scatter op will save
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
 
            before reduce_scatter:
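The docstring describes standard reduce-scatter semantics: every rank enters with a full local gradient, and afterwards rank i holds only the i-th chunk of the element-wise sum. A single-process toy simulation of that contract, with no actual torch.distributed call:

# Simulate reduce-scatter across 2 "ranks" without torch.distributed.
import torch

world_size = 2
# Full local gradients on each rank (same shape, different values).
grads = [torch.tensor([1., 2., 3., 4.]), torch.tensor([5., 6., 7., 8.])]

summed = torch.stack(grads).sum(dim=0)    # [6., 8., 10., 12.]
chunks = summed.chunk(world_size)         # rank 0 keeps [6., 8.], rank 1 keeps [10., 12.]

for rank in range(world_size):
    print(f"rank {rank} keeps shard {chunks[rank].tolist()}")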
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
         orig_grad_data = new_grad.data
         if self.world_size > 1:
             grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-            self.reducer.reduce_scatter_async(
-                grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+            self.reducer.reduce_scatter_async(grad_chunks,
+                                              group=self.reduce_scatter_process_group,
+                                              callback_fn=functools.partial(self._reduce_scatter_callback, param))
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
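chunk_and_pad has to hand reduce_scatter exactly group.size() equally sized chunks, even when the flattened gradient length is not divisible by the group size. A hedged re-implementation of that idea; the real helper lives in ._zero3_utils and may differ in detail.

# Sketch of a chunk-and-pad helper; illustrative, not the ._zero3_utils implementation.
from typing import List

import torch
import torch.nn.functional as F


def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    """Split a tensor into num_chunks equal flat chunks, zero-padding as needed."""
    chunks = list(tensor.flatten().chunk(num_chunks))
    # torch.chunk may leave the last chunk short, or return fewer chunks than requested.
    pad = chunks[0].numel() - chunks[-1].numel()
    if pad > 0:
        chunks[-1] = F.pad(chunks[-1], (0, pad))
    if len(chunks) < num_chunks:
        chunks.extend(torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks)))
    return chunks


grad = torch.arange(10.0)                  # 10 elements, "group" of size 4
chunks = chunk_and_pad_sketch(grad, 4)
print([c.tolist() for c in chunks])        # 4 chunks of 3 elements; the last is zero-padded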