Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-19 09:04:50 +00:00)
Only process the module's own parameters in the ZeRO context: add ZeRO hooks for all modules that contain parameters, and gather only the parameters belonging to the module itself.
92 lines · 3.9 KiB · Python
from typing import Optional

import torch
import torch.distributed as dist

from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy

from ._base_ophook import BaseOpHook


@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded parameters for the ZeRO method.
    """

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstats_collector: Optional[MemStatsCollector],
                 process_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.shard_strategy = shard_strategy
        self.process_group = process_group
        # NOTE(jiaruifang) The computing device of FWD and BWD is currently always the GPU
        self.computing_device = torch.device(f'cuda:{get_current_device()}')

        self._memstats_collector = memstats_collector

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        # Gather only the parameters owned by this module itself (recurse=False)
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'col_attr')
            tensor_list.append(param.col_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
            param.data = param.col_attr.sharded_data_tensor.payload

        if self._memstats_collector:
            self._memstats_collector.sample_memstats()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'col_attr')
            tensor_list.append(param.col_attr.sharded_data_tensor)
        self.shard_strategy.shard(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            param.col_attr.remove_torch_payload()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'col_attr')
            tensor_list.append(param.col_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
            param.data = param.col_attr.sharded_data_tensor.payload
            # Store the local accumulated grad shard
            if param.grad is not None:
                if param.col_attr.bwd_count == 0:
                    # We haven't stored the local accumulated grad yet
                    assert param.col_attr.fp32_grad is None
                    param.col_attr.fp32_grad = param.grad.data
                    param.grad = None
                else:
                    # We have already stored the local accumulated grad;
                    # the grad here must be the locally computed full grad of this backward pass
                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
            param.col_attr.bwd_count += 1
        if self._memstats_collector:
            self._memstats_collector.sample_memstats()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'col_attr')
            tensor_list.append(param.col_attr.sharded_data_tensor)
        self.shard_strategy.shard(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            param.col_attr.remove_torch_payload()

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
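
For context, the sketch below shows roughly how the hook's callbacks bracket a single submodule's forward pass: gather the module's own parameter shards before compute, then re-shard and drop the gathered payloads afterwards. It is a minimal illustration, not code from this file: the import path for ZeroHook, the helper name forward_with_zero_hook, and the choice of TensorShardStrategy are assumptions, and the model must already have been initialized under ColossalAI's ZeRO machinery so that every parameter carries a col_attr with a sharded data tensor.

# Minimal usage sketch (assumptions: import path, helper name, TensorShardStrategy,
# and a model already set up by ColossalAI's ZeRO initialization so params have col_attr).
import torch
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.engine.ophooks.zero_hook import ZeroHook  # hypothetical import path

def forward_with_zero_hook(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    hook = ZeroHook(shard_strategy=TensorShardStrategy(), memstats_collector=None)
    hook.pre_fwd_exec(module, x)    # gather this module's own parameter shards onto the GPU
    out = module(x)                 # run the forward pass with the full parameters in place
    hook.post_fwd_exec(module, x)   # re-shard the parameters and free the gathered payloads
    return out

In practice these callbacks are not invoked by hand; per the change description above, ColossalAI registers a ZeroHook on every module that contains parameters, so the gather/shard cycle runs automatically around each module's forward and backward passes.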