mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-21 23:02:07 +00:00
[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)
This commit is contained in:
85
colossalai/zero/sharded_model/sharded_grad.py
Normal file
85
colossalai/zero/sharded_model/sharded_grad.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class ShardedGradient:
|
||||
def __init__(self,
|
||||
param: Parameter,
|
||||
sharded_module: nn.Module,
|
||||
offload_config: Optional[dict] = None
|
||||
) -> None:
|
||||
assert hasattr(
|
||||
param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter'
|
||||
|
||||
self.param = param
|
||||
self.sharded_module = sharded_module
|
||||
self.offload_config = offload_config
|
||||
|
||||
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
|
||||
|
||||
# _gpu_grad is either sharded or not
|
||||
# all saved grads are fp32
|
||||
self._gpu_grad: Optional[torch.Tensor] = None
|
||||
self._cpu_grad: Optional[torch.Tensor] = None
|
||||
|
||||
if self._cpu_offload:
|
||||
# this buffer will be held and reused every iteration
|
||||
self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()
|
||||
|
||||
@torch.no_grad()
|
||||
def setup(self) -> None:
|
||||
"""This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
|
||||
When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad
|
||||
|
||||
:raises AssertionError: Raise if grad shape is wrong
|
||||
"""
|
||||
if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
|
||||
if self.param.grad.device != self.param.data.device:
|
||||
# TODO: offload?
|
||||
raise RuntimeError(
|
||||
'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
|
||||
else:
|
||||
self._gpu_grad = self.param.grad.data
|
||||
self.param.grad = None
|
||||
|
||||
def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
|
||||
"""This function will be called in post-backward hook, so we cannot modify param.grad directly
|
||||
|
||||
:param reduced_grad: the reduced grad
|
||||
:type reduced_grad: torch.Tensor
|
||||
"""
|
||||
# Make sure we store fp32 grad
|
||||
if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
|
||||
reduced_grad.data = reduced_grad.data.to(torch.float)
|
||||
|
||||
if self._gpu_grad is None:
|
||||
self._gpu_grad = reduced_grad.data
|
||||
else:
|
||||
self._gpu_grad += reduced_grad.data
|
||||
|
||||
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
|
||||
# backwards pass completes, we will set `.grad` to the CPU copy.
|
||||
if self._cpu_offload:
|
||||
self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
|
||||
# Don't let this memory get reused until after the transfer.
|
||||
self._gpu_grad.data.record_stream(torch.cuda.current_stream())
|
||||
|
||||
@torch.no_grad()
|
||||
def write_back(self) -> None:
|
||||
"""This function will be called in final backward hook
|
||||
"""
|
||||
if self._cpu_grad is not None:
|
||||
assert self.param.device == torch.device(
|
||||
'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
|
||||
self.param.grad.data = self._cpu_grad
|
||||
elif self._gpu_grad is not None:
|
||||
assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
|
||||
self.param.grad.data = self._gpu_grad
|
||||
else:
|
||||
raise RuntimeError('No grad to write back')
|
||||
# If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
|
||||
# They should be released here
|
||||
self._gpu_grad = None
|
Reference in New Issue
Block a user