Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[auto-parallel] add auto-offload feature (#3154)
* add auto-offload feature
* polish code
* fix sync offload runtime pass bug
* add offload example
* fix offload testing bug
* fix example testing bug
colossalai/auto_parallel/offload/base_offload_module.py (new file, 109 lines added)
@@ -0,0 +1,109 @@
from typing import Optional, Set
from functools import partial

import torch
import torch.nn as nn

from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo


class BaseOffloadModule:
    """
    BaseOffloadModule: A model wrapper for parameter offloading.

    Args:
        model (nn.Module): model to apply offloading.
        region_manager (RegionManager): a ``RegionManager`` instance.
        is_sync (bool): synchronous mode or not.
    """

    def __init__(self,
                 model: nn.Module,
                 region_manager: RegionManager,
                 is_sync=True):

        self.model = model
        self.region_manager = region_manager
        self.grad_hook_list = []
        self.overflow_counter = torch.cuda.IntTensor([0])

        # in synchronous mode grads are offloaded on the current stream,
        # otherwise on the dedicated device-to-host stream
        self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream

        self._cast_buffers()

    def register_grad_hook(self):
        for p in self.model.parameters():
            if p.requires_grad:
                self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))

    def remove_grad_hook(self):
        for hook in self.grad_hook_list:
            hook.remove()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _pre_forward(self):
        self.register_grad_hook()
        for region in self.region_manager.region_list:
            region.cpu_grad = None

    def forward(self, *args, **kwargs):
        # cast inputs to fp16 to match the offloaded fp16 parameters
        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
        self.model.zero_grad(set_to_none=True)
        self._pre_forward()
        outputs = self.model(*args, **kwargs)
        return outputs

    def backward(self, loss):
        loss.backward()
        self._post_backward()

    def _post_backward(self):
        torch.cuda.synchronize()
        self.remove_grad_hook()

        for p in self.model.parameters():
            p.grad = None

        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()

    def grad_handle(self, p, grad):
        # hand autograd a storage-freed placeholder so the real gradient
        # memory can be released once it has been copied into its region
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        with torch._C.DisableTorchFunction():
            region = self.region_manager.get_region(p)
            region.copy_grad_to_region_slice(p, grad)
            if region.can_release:
                self.overflow_counter += region.has_inf_or_nan
                # once the whole region's grads are ready, move them to CPU
                # on the offload stream, after it has caught up with compute
                master_stream = torch.cuda.current_stream()
                with torch.cuda.stream(self.grad_offload_stream):
                    GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
                    region.move_grad_to_cpu()
        return empty_grad

    def _cast_buffers(self):
        # keep buffers resident on the GPU
        for buffer in self.model.buffers():
            buffer.data = buffer.cuda()

    def parameters(self, recurse: bool = True):
        return self.model.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.model.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        return self.model.named_buffers(prefix, recurse)

    def named_children(self):
        return self.model.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        return self.model.named_modules(memo, prefix, remove_duplicate)