Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)

parent 4fbb8db586
commit 9b07ac81d4
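For orientation, here is a minimal usage sketch of the pieces this commit wires together, modeled on the test at the bottom of the diff. It assumes `colossalai.launch` has already initialized the process groups and a CUDA device is available; `MyNet` is a stand-in module, not part of the commit.

import torch
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model import ShardedModelV2


class MyNet(nn.Module):                      # placeholder model, not part of this commit
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)


model = ShardedModelV2(MyNet().cuda(), process_group=gpc.get_group(ParallelMode.DATA))
x = torch.rand(2, 5).cuda()
loss = model(x).sum().float()
model.backward(loss)      # grads are reduce-scattered; each rank keeps its shard in p.grad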
@@ -1,10 +1,13 @@
-from ._base_ophook import BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
-from ._shard_param_ophook import ShardParamHook
-import torch
 from typing import List
 
-__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
+import torch
+
+from ._base_ophook import BaseOpHook
+from ._memtracer_ophook import MemTracerOpHook
+from ._shard_grad_ophook import ShardGradHook
+from ._shard_param_ophook import ShardParamHook
+
+__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
 
 
 # apply torch.autograd.Function that calls a backward_function to tensors in output
colossalai/engine/ophooks/_shard_grad_ophook.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch
from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardGradHook(BaseOpHook):
    """
    A hook to process the sharded gradient before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, '_sharded_grad')
            param._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass
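The hook assumes every parameter already carries a `_sharded_grad` attribute, which `ShardedModelV2` attaches below. A small sketch of the wiring, using only calls that appear in this diff; `module` is assumed to be an nn.Module whose parameters have already been prepared:

# Sketch only: ShardedModelV2 performs this registration itself in __init__
# (see the sharded_model_v2 hunk below); shown here to make the hook
# lifecycle explicit.
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
                                       register_ophooks_recursively)

register_ophooks_recursively(module, [ShardParamHook(), ShardGradHook()])
# During backward, pre_bwd_exec() runs before each module's BWD op and calls
# param._sharded_grad.setup(), which moves the locally accumulated param.grad
# into the ShardedGradient buffer and resets param.grad to None.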
@@ -1,9 +1,10 @@
 from enum import Enum
 
 import torch
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-import torch.distributed as dist
+from colossalai.zero.sharded_model._zero3_utils import get_shard
 
+
 class TensorType(Enum):
@@ -27,9 +28,11 @@ class ShardParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
+        self._payload_shape = None
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
+        self._origin_dtype = param.dtype
         self.is_sharded = False
 
     def payload(self, target_device: torch.device):
@@ -65,3 +68,7 @@ class ShardParam(object):
                                async_op=False)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
         self.is_sharded = False
+
+    @property
+    def origin_dtype(self):
+        return self._origin_dtype
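`origin_dtype` records the dtype the parameter had when it was wrapped, so reduced gradients can later be cast back under mixed precision (see `_reduce_scatter_callback` below). A tiny hedged sketch; it assumes a distributed context is already initialized, since `ShardParam` queries the process group in its constructor:

import torch
import torch.nn as nn

from colossalai.zero.shard_param import ShardParam

p = nn.Parameter(torch.randn(4, 4))          # created in fp32
p.ca_attr = ShardParam(p)                    # origin dtype is captured here
p.data = p.data.half()                       # later cast to fp16 for compute
assert p.ca_attr.origin_dtype == torch.float32
# A reduced fp16 grad can be cast back with .to(dtype=p.ca_attr.origin_dtype).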
@@ -190,10 +190,6 @@ class ReduceScatterBucketer:
         return int(bucket_size // num_shards)
 
     def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
-        # TODO (Min): the `group` used here in the key is the object hash, not the content
-        # hash. That means if FSDP instances are initialized with different process groups,
-        # even when the group members are in fact the same, we end up creating different
-        # buckets here.
         key = (tensor.dtype, tensor.device, group)
         if key not in self.buckets:
             # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
colossalai/zero/sharded_model/sharded_grad.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None
                 ) -> None:
        assert hasattr(
            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with a sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros_like(param.ca_attr.payload(torch.device('cpu')), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the locally accumulated gradient to _gpu_grad.
        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad.

        :raises RuntimeError: Raised if grad and param are on different devices
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    f'grad and param are on different devices, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in the post-backward hook, so we cannot modify param.grad directly.

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in the final backward hook.
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, \
                f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad holds the CPU copy of _gpu_grad.
        # The GPU copy can be released here.
        self._gpu_grad = None
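A hedged sketch of one `ShardedGradient`'s lifecycle across a backward pass. It is illustration, not a test from the commit: `param` is assumed to have been prepared by `ShardedModelV2`, and `reduced_shard` stands in for the shard that `ReduceScatterBucketer` would normally deliver.

# Assumed pre-existing objects: `param` (with ca_attr and _sharded_grad
# attached by ShardedModelV2) and `reduced_shard` (this rank's reduced piece).
sg = param._sharded_grad                 # a ShardedGradient instance

# 1. pre-backward (ShardGradHook.pre_bwd_exec): stash any locally
#    accumulated grad and clear param.grad so autograd writes a fresh one.
sg.setup()

# 2. post-backward (reduce-scatter callback): accumulate this rank's reduced
#    shard in fp32, optionally copying it to a pinned CPU buffer.
sg.reduce_scatter_callback(reduced_shard)

# 3. final backward hook (ShardedModelV2._final_backward_hook):
#    expose the shard to the optimizer as param.grad.
sg.write_back()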
@@ -1,31 +1,37 @@
-import contextlib
-import copy
 import functools
-import os
-import traceback
-from collections import OrderedDict
-from enum import Enum, auto
-from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
-                    Set, Union)
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
+                                       register_ophooks_recursively)
+from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
-from torch.distributed import ProcessGroup
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
 from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+
+from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
 
 
 class ShardedModelV2(nn.Module):
     def __init__(self,
                  module: nn.Module,
                  process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None
-                 ):
+                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
+                 reduce_scatter_bucket_size_mb: int = 25,
+                 reshard_after_forward: bool = True,
+                 mixed_precision: bool = False,
+                 fp32_reduce_scatter: bool = False,
+                 offload_config: Optional[dict] = None,
+                 gradient_predivide_factor: Optional[float] = 1.0,
+                 ):
         r"""
         A demo to reconfigure the ZeRO-1 sharded model.
         Optimizer states are not considered yet.
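The constructor now exposes ZeRO-style knobs. A hedged construction example; the argument values are illustrative, and `my_module` is any nn.Module you provide (process groups are assumed to have been initialized via colossalai.launch):

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model import ShardedModelV2

zero_model = ShardedModelV2(
    my_module,
    process_group=gpc.get_group(ParallelMode.DATA),
    reduce_scatter_bucket_size_mb=25,    # bucket size handed to ReduceScatterBucketer
    reshard_after_forward=True,          # stored on the model (FSDP-style resharding is the assumption)
    mixed_precision=False,
    fp32_reduce_scatter=False,
    offload_config={'device': 'cpu'},    # turns on pinned CPU gradient buffers
    gradient_predivide_factor=1.0,       # 1.0 by default to avoid precision issues (see comment below)
)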
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):
         for _, param in self.module.named_parameters():
             param.ca_attr = ShardParam(param)
             param.ca_attr.shard()
+            param._sharded_grad = ShardedGradient(param, self, offload_config)
 
         # Register hooks
-        register_ophooks_recursively(self.module, [ShardParamHook()])
+        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
+
+        self.reshard_after_forward = reshard_after_forward
+        self.mixed_precision = mixed_precision
+        self.fp32_reduce_scatter = fp32_reduce_scatter
+        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        # We find that if gradient_predivide_factor != 1.0, there may be precision problems,
+        # so we use 1.0 as the default gradient_predivide_factor.
+        # However, if you set gradient_predivide_factor to None, we will set it to a value >= 1.0 automatically.
+        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+            get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
+
+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
+        self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)
         return outputs
 
 
     def backward(self, loss):
-        if self.loss_scaler:
-            self.loss_scaler.backward(loss)
-        else:
-            loss.backward()
+        loss.backward()
+        self._final_backward_hook()
+
+    @torch.no_grad()
+    def _final_backward_hook(self) -> None:
+        if self._require_backward_grad_sync:
+            # Flush any unreduced buckets in the post_backward stream.
+            with torch.cuda.stream(self.comm_stream):
+                self.reducer.flush()
+            torch.cuda.current_stream().wait_stream(self.comm_stream)
+            if self._cpu_offload:
+                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
+                torch.cuda.current_stream().synchronize()
+        self.reducer.free()
+        for p in self.module.parameters():
+            if not p.requires_grad:
+                continue
+            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
+            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
+            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
+            # sync passes, if desired.
+            if not self._require_backward_grad_sync:
+                continue
+            p._sharded_grad.write_back()
+
+    @torch.no_grad()
+    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
+        """
+        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
+        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient
+        across all GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
+
+            before reduce_scatter:
+                param.grad (GPU #0): [1, 2, 3, 4]
+                param.grad (GPU #1): [5, 6, 7, 8]
+
+            after reduce_scatter:
+                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
+                param.grad (GPU #1): [10, 12]  # 3+7, 4+8
+
+        The local GPU's ``optim.step`` is responsible for updating a single
+        shard of params, also corresponding to the current GPU's rank. This
+        alignment is created by `param._sharded_grad`, which ensures that
+        the local optimizer only sees the relevant parameter shard.
+        """
+        if grad is None:
+            return
+        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
+        if not self._require_backward_grad_sync:
+            return
+        self.comm_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.comm_stream):
+            new_grad = grad.clone()
+            if self.mixed_precision and self.fp32_reduce_scatter:
+                new_grad.data = new_grad.data.to(param.dtype)
+            if self.gradient_predivide_factor > 1.0:
+                # Average grad by world_size for consistency with PyTorch DDP.
+                new_grad.data.div_(self.gradient_predivide_factor)
+            orig_grad_data = new_grad.data
+            if self.world_size > 1:
+                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                self.reducer.reduce_scatter_async(
+                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+            else:
+                self._reduce_scatter_callback(param, new_grad)
+            orig_grad_data.record_stream(self.comm_stream)
+
+    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the cpu offload step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            orig_param_grad_data = reduced_grad.data
+            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
+            # Don't let this memory get reused until after the transfer.
+            orig_param_grad_data.record_stream(torch.cuda.current_stream())
+
+        param._sharded_grad.reduce_scatter_callback(reduced_grad)
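The predivide and postdivide factors bracket the reduce-scatter so the final shard matches DDP's mean reduction while keeping intermediate sums in range. A small worked sketch of the arithmetic in the docstring above, done on a single process for illustration; the real path goes through `ReduceScatterBucketer` and `dist.reduce_scatter`:

import torch

world_size = 2
predivide = 1.0                         # default in this commit
postdivide = world_size / predivide     # 2.0

# Full local gradients on two hypothetical ranks (values from the docstring).
g0 = torch.tensor([1., 2., 3., 4.])
g1 = torch.tensor([5., 6., 7., 8.])

# reduce_scatter == elementwise sum, then each rank keeps its own chunk.
summed = g0 / predivide + g1 / predivide
shard0, shard1 = summed.chunk(world_size)          # tensor([6., 8.]), tensor([10., 12.])

# The post-divide then recovers the DDP-style mean on each shard.
print(shard0 / postdivide, shard1 / postdivide)    # tensor([3., 4.]) tensor([5., 6.])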
@@ -1,9 +1,10 @@
 from functools import partial
-from operator import imod
-from colossalai.utils import checkpoint
-import torch.nn as nn
 
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import checkpoint
 
 LOGGER = get_dist_logger()
@@ -34,6 +35,7 @@ CONFIG = dict(
     )
 )
 
+
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
@@ -61,6 +63,7 @@ class Net(nn.Module):
             x = layer(x)
         return x
 
+
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):
         zero_grad = zero_p.grad.clone().to(p.device)
         assert p.grad.dtype == zero_grad.dtype
         assert allclose(p.grad, zero_grad, loose=loose)
-        LOGGER.info(torch.sum(p.grad-zero_grad))
+        LOGGER.info(torch.sum(p.grad - zero_grad))
 
+
 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
 
+
+def check_grads_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_grad = zero_p.grad.clone().to(p.device)
+        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        grad = chunks[rank]
+        if zero_grad.size(0) > grad.size(0):
+            zero_grad = zero_grad[:grad.size(0)]
+        assert grad.dtype == zero_grad.dtype
+        assert allclose(grad, zero_grad, loose=loose)
+
+
+def check_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.clone().to(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)
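These helpers compare each rank's gradient shard against the unsharded reference by flattening, chunking by world size, and trimming the padding the sharded model adds to the last shard. A standalone illustration of that indexing with hypothetical sizes, no distributed setup needed:

import torch

world_size = 4
full_grad = torch.arange(10, dtype=torch.float)       # 10 elements, uneven split

chunks = torch.flatten(full_grad).chunk(world_size)   # sizes [3, 3, 3, 1]
rank = 3
ref_shard = chunks[rank]                              # tensor([9.])

# The sharded model pads the flattened tensor so every rank gets an equal
# shard; the reference chunk is compared after truncating that padding.
padded_shard = torch.tensor([9., 0., 0.])             # hypothetical padded shard on rank 3
if padded_shard.size(0) > ref_shard.size(0):
    padded_shard = padded_shard[:ref_shard.size(0)]
assert torch.allclose(ref_shard, padded_shard)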
@@ -3,19 +3,18 @@
 import copy
 from functools import partial
-from operator import mod
-from pyexpat import model
 
 import colossalai
 import pytest
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
+from common import CONFIG, Net, check_grads, check_grads_padding
 
 
 def run_fwd_bwd(model, x, enable_autocast=False):
@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
-    loss.backward()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
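A `ShardedModelV2` has to be driven through its own `backward`, because `_final_backward_hook` is what flushes the reduce-scatter buckets and writes the gradient shards back to `param.grad`; a bare `loss.backward()` would leave them unwritten. A sketch reusing the variables from the test above:

# Sketch: `zero_model` is a ShardedModelV2, `model` the unsharded reference.
loss = zero_model(x).sum().float()
zero_model.backward(loss)      # loss.backward() + final hook (flush + write_back)

ref_loss = model(x).sum().float()
ref_loss.backward()            # plain autograd for the reference model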
@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):
                       host='localhost',
                       port=port,
                       backend='nccl')
 
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
     zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):
         x = torch.rand(2, 5).cuda()
         run_fwd_bwd(zero_model, x, False)
         run_fwd_bwd(model, x, False)
-        check_grads(model, zero_model)
+        if dist.get_world_size() > 1:
+            check_grads_padding(model, zero_model)
+        else:
+            check_grads(model, zero_model)
 
 
 @pytest.mark.dist
@@ -14,7 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import Net, allclose
+
+from common import Net, check_grads_padding, check_params_padding
 
 
 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
def check_grads_padding(model, zero_model, loose=False):
|
|
||||||
rank = dist.get_rank()
|
|
||||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
|
||||||
zero_grad = zero_p.grad.clone().to(p.device)
|
|
||||||
chunks = torch.flatten(p.grad).chunk(4)
|
|
||||||
if rank >= len(chunks):
|
|
||||||
continue
|
|
||||||
grad = chunks[rank]
|
|
||||||
if zero_p.zero_shard_padding > 0:
|
|
||||||
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
|
|
||||||
assert grad.dtype == zero_grad.dtype
|
|
||||||
assert allclose(grad, zero_grad, loose=loose)
|
|
||||||
|
|
||||||
|
|
||||||
def check_params_padding(model, zero_model, loose=False):
|
|
||||||
rank = dist.get_rank()
|
|
||||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
|
||||||
zero_shard_padding = zero_p.zero_shard_padding
|
|
||||||
zero_p = zero_p.clone().to(p.device)
|
|
||||||
chunks = torch.flatten(p).chunk(4)
|
|
||||||
if rank >= len(chunks):
|
|
||||||
continue
|
|
||||||
p = chunks[rank]
|
|
||||||
if zero_shard_padding > 0:
|
|
||||||
zero_p = zero_p[:-zero_shard_padding]
|
|
||||||
assert p.dtype == zero_p.dtype
|
|
||||||
assert allclose(p, zero_p, loose=loose)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_booleans(intval, bits):
|
def decode_booleans(intval, bits):
|
||||||
res = []
|
res = []
|
||||||