Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)
parent 4fbb8db586
commit 9b07ac81d4
@@ -1,10 +1,13 @@
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_param_ophook import ShardParamHook
import torch
from typing import List

all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
import torch

from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook

all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]


# apply torch.autograd.Function that calls a backward_function to tensors in output
colossalai/engine/ophooks/_shard_grad_ophook.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch
from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardGradHook(BaseOpHook):
    """
    A hook to process sharded params before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, '_sharded_grad')
            param._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass
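
For context, pre_bwd_exec only works if every parameter of the wrapped module already carries a _sharded_grad object exposing setup(); that attribute is attached in ShardedModelV2.__init__, shown further down in this diff. The snippet below is a minimal, self-contained sketch of that contract and is not part of the commit: _ToyShardedGrad and linear are hypothetical stand-ins.

import torch.nn as nn


class _ToyShardedGrad:
    """Stand-in for ShardedGradient: stashes the accumulated grad on setup()."""

    def __init__(self, param: nn.Parameter) -> None:
        self.param = param
        self.saved = None

    def setup(self) -> None:
        # Move any locally accumulated grad aside before the backward pass.
        if self.param.grad is not None:
            self.saved = self.param.grad.data
            self.param.grad = None


linear = nn.Linear(4, 2)
for p in linear.parameters():
    p._sharded_grad = _ToyShardedGrad(p)      # what ShardedModelV2 attaches per param

# What ShardGradHook.pre_bwd_exec does for each wrapped module:
for p in linear.parameters():
    assert hasattr(p, '_sharded_grad')
    p._sharded_grad.setup()
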
@@ -1,9 +1,10 @@
from enum import Enum

import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist
from colossalai.zero.sharded_model._zero3_utils import get_shard


class TensorType(Enum):
@@ -27,9 +28,11 @@ class ShardParam(object):
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
        self._payload_shape = None
        self._payload_numel = None
        self._origin_shape = param.shape
        self._origin_numel = param.numel()
        self._origin_dtype = param.dtype
        self.is_sharded = False

    def payload(self, target_device: torch.device):
@@ -65,3 +68,7 @@ class ShardParam(object):
                         async_op=False)
        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
        self.is_sharded = False

    @property
    def origin_dtype(self):
        return self._origin_dtype
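
The gather path above rebuilds the full parameter from per-rank shards by concatenating the gathered buffers, trimming any padding with torch.narrow, and restoring the original shape. The following single-process sketch only illustrates that arithmetic; no torch.distributed is involved, and the shard size and padding are assumptions for illustration.

import torch

origin_shape = (3, 5)
origin_numel = 3 * 5
world_size = 4
shard_numel = (origin_numel + world_size - 1) // world_size   # last shard is padded

flat = torch.arange(origin_numel, dtype=torch.float)
padded = torch.cat([flat, flat.new_zeros(shard_numel * world_size - origin_numel)])
buffer_list = list(padded.chunk(world_size))   # what all_gather would have filled in

# cat the shards, drop the padding, restore the original shape
full = torch.narrow(torch.cat(buffer_list), 0, 0, origin_numel).view(origin_shape)
assert torch.equal(full, flat.view(origin_shape))
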
@@ -190,10 +190,6 @@ class ReduceScatterBucketer:
        return int(bucket_size // num_shards)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        # hash. That means if FSDP instances are initialized with different process groups,
        # even when the group members are in fact the same, we end up creating different
        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
colossalai/zero/sharded_model/sharded_grad.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None
                 ) -> None:
        assert hasattr(
            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with a sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the locally accumulated gradient to _gpu_grad.
        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises RuntimeError: Raised if the grad and the param are on different devices
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    f'grad and param are on different devices, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in the post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in the final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad stores the CPU copy of _gpu_grad;
        # the GPU copy should be released here
        self._gpu_grad = None
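
To make the lifecycle above easier to follow, here is a hedged, single-process walkthrough of the three phases a ShardedGradient goes through (setup before backward, accumulation of the reduced result, write-back for the optimizer). Shapes are simplified: in the real code the reduced tensor is a per-rank shard rather than a full-size gradient, and the param and reduced_shard values below are made up.

import torch

param = torch.nn.Parameter(torch.randn(4))
param.grad = torch.ones(4)                # grad accumulated by autograd so far

# setup(): stash the accumulated grad and clear param.grad
gpu_grad = param.grad.data
param.grad = None

# reduce_scatter_callback(): fold the reduced (fp32) result into the stash
reduced_shard = torch.full((4,), 0.5)
gpu_grad = gpu_grad + reduced_shard.float()

# write_back(): hand the final gradient back to the optimizer via param.grad
param.grad = gpu_grad
print(param.grad)                         # tensor([1.5000, 1.5000, 1.5000, 1.5000])
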
@@ -1,31 +1,37 @@

import contextlib
import copy
import functools
import os
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
                    Set, Union)
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
                                       register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from torch.distributed import ProcessGroup
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
from colossalai.zero.shard_param import ShardParam
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor


class ShardedModelV2(nn.Module):
    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_process_group: Optional[ProcessGroup] = None
                 ):
                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_bucket_size_mb: int = 25,
                 reshard_after_forward: bool = True,
                 mixed_precision: bool = False,
                 fp32_reduce_scatter: bool = False,
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 ):
        r"""
        A demo to reconfigure the ZeRO-1 sharded model.
        Optimizer states are not considered for now.
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):
        for _, param in self.module.named_parameters():
            param.ca_attr = ShardParam(param)
            param.ca_attr.shard()
            param._sharded_grad = ShardedGradient(param, self, offload_config)

        # Register hooks
        register_ophooks_recursively(self.module, [ShardParamHook()])
        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
        # We find that if gradient_predivide_factor != 1.0, precision problems may occur,
        # so we use 1.0 as the default gradient_predivide_factor.
        # However, if you set gradient_predivide_factor to None, we will set it to a value >= 1.0 automatically
        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
            get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
        self._require_backward_grad_sync: bool = True

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        outputs = self.module(*args, **kwargs)
        return outputs


    def backward(self, loss):
        if self.loss_scaler:
            self.loss_scaler.backward(loss)
        else:
            loss.backward()


        loss.backward()
        self._final_backward_hook()

    @torch.no_grad()
    def _final_backward_hook(self) -> None:
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
                self.reducer.flush()
            torch.cuda.current_stream().wait_stream(self.comm_stream)
            if self._cpu_offload:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        self.reducer.free()
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
            # sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
            p._sharded_grad.write_back()

    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
        GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by `param._sharded_grad`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        if grad is None:
            return
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            new_grad = grad.clone()
            if self.mixed_precision and self.fp32_reduce_scatter:
                new_grad.data = new_grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1.0:
                # Average grad by world_size for consistency with PyTorch DDP.
                new_grad.data.div_(self.gradient_predivide_factor)
            orig_grad_data = new_grad.data
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
                self.reducer.reduce_scatter_async(
                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the cpu offload step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = reduced_grad.data
            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())

        param._sharded_grad.reduce_scatter_callback(reduced_grad)
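
One detail worth spelling out: _grad_post_backward_hook divides by gradient_predivide_factor before the reduce-scatter and _reduce_scatter_callback divides by gradient_postdivide_factor afterwards, and the two factors multiply out to world_size, so the end result matches DDP's gradient averaging. A small arithmetic sketch of that bookkeeping (all values made up):

world_size = 8
gradient_predivide_factor = 1.0      # the default this commit settles on
gradient_postdivide_factor = world_size / gradient_predivide_factor

local_grad = 4.0                     # same value on every rank, for simplicity
summed = (local_grad / gradient_predivide_factor) * world_size   # what reduce-scatter sums
averaged = summed / gradient_postdivide_factor
assert averaged == local_grad        # identical to an all-reduce average
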
@@ -1,9 +1,10 @@
from functools import partial
from operator import imod
from colossalai.utils import checkpoint
import torch.nn as nn

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import checkpoint

LOGGER = get_dist_logger()

@@ -34,6 +35,7 @@ CONFIG = dict(
    )
)


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
@@ -61,6 +63,7 @@ class Net(nn.Module):
            x = layer(x)
        return x


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):
        zero_grad = zero_p.grad.clone().to(p.device)
        assert p.grad.dtype == zero_grad.dtype
        assert allclose(p.grad, zero_grad, loose=loose)
        LOGGER.info(torch.sum(p.grad-zero_grad))
        LOGGER.info(torch.sum(p.grad - zero_grad))


def check_params(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_grad.size(0) > grad.size(0):
            zero_grad = zero_grad[:grad.size(0)]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_p.size(0) > p.size(0):
            zero_p = zero_p[:p.size(0)]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)
@@ -3,19 +3,18 @@

import copy
from functools import partial
from operator import mod
from pyexpat import model

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads

from common import CONFIG, Net, check_grads, check_grads_padding


def run_fwd_bwd(model, x, enable_autocast=False):
@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()

    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):
                      host='localhost',
                      port=port,
                      backend='nccl')


    model = Net(checkpoint=True).cuda()
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):
    x = torch.rand(2, 5).cuda()
    run_fwd_bwd(zero_model, x, False)
    run_fwd_bwd(model, x, False)
    check_grads(model, zero_model)
    if dist.get_world_size() > 1:
        check_grads_padding(model, zero_model)
    else:
        check_grads(model, zero_model)


@pytest.mark.dist
@@ -14,7 +14,9 @@ from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from common import Net, allclose

from common import Net, check_grads_padding, check_params_padding


def run_step(model, optimizer, x, enable_autocast=False):
    model.train()
@@ -26,34 +28,6 @@ def run_step(model, optimizer, x, enable_autocast=False):
    loss.backward()
    optimizer.step()

def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def decode_booleans(intval, bits):
    res = []