mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[zero] Update initialize for ZeRO (#458)
* polish code * shard strategy receive pg in shard() / gather() * update zero engine * polish code
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.registry import OPHOOKS
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
@@ -17,9 +18,13 @@ class ZeroHook(BaseOpHook):
|
||||
A hook to process sharded param for ZeRO method.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
|
||||
def __init__(self,
|
||||
shard_strategy: BaseShardStrategy,
|
||||
memstarts_collector: Optional[MemStatsCollector],
|
||||
process_group: Optional[dist.ProcessGroup] = None):
|
||||
super().__init__()
|
||||
self.shard_strategy = shard_strategy
|
||||
self.process_group = process_group
|
||||
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
|
||||
self.computing_device = torch.device(f'cuda:{get_current_device()}')
|
||||
|
||||
@@ -30,7 +35,7 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
self.shard_strategy.gather(tensor_list)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
@@ -45,7 +50,7 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
self.shard_strategy.shard(tensor_list)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
param.col_attr.remove_torch_payload()
|
||||
|
||||
@@ -54,7 +59,7 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
self.shard_strategy.gather(tensor_list)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
@@ -80,7 +85,7 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
self.shard_strategy.shard(tensor_list)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
param.col_attr.remove_torch_payload()
|
||||
|
||||
|
@@ -278,7 +278,10 @@ def initialize(model: nn.Module,
|
||||
cfg_ = {}
|
||||
optimizer_config = zero_cfg.get('optimizer_config', None)
|
||||
model_config = zero_cfg.get('model_config', None)
|
||||
model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
|
||||
model, optimizer = convert_to_zero_v2(model,
|
||||
optimizer,
|
||||
model_config=model_config,
|
||||
optimizer_config=optimizer_config)
|
||||
|
||||
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
|
||||
# FIXME() throw a warning if using zero with MP
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.logging import get_dist_logger
|
||||
@@ -11,7 +12,8 @@ from .sharded_model import ShardedModel
|
||||
from .sharded_optim import ShardedOptimizer
|
||||
|
||||
|
||||
def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
|
||||
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
|
||||
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
|
||||
"""
|
||||
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
|
||||
|
||||
@@ -34,7 +36,7 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
|
||||
model_config = dict()
|
||||
|
||||
zero_model = ShardedModelV2(model, **model_config)
|
||||
zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
|
||||
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
|
||||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
|
@@ -1,26 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
import torch.distributed as dist
|
||||
from typing import List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class BaseShardStrategy(ABC):
|
||||
|
||||
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
def __init__(self) -> None:
|
||||
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
|
||||
|
||||
Args:
|
||||
process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
|
||||
"""
|
||||
self.process_group = process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def shard(self, tensor_list: List[ShardedTensor]):
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gather(self, tensor_list: List[ShardedTensor]):
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
||||
|
@@ -1,18 +1,17 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
||||
class BucketTensorShardStrategy(TensorShardStrategy):
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor]):
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
|
||||
if len(tensor_list) == 0:
|
||||
return
|
||||
@@ -21,15 +20,17 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
||||
buffer_list: List[torch.Tensor] = []
|
||||
tensor_numels = [t.payload.numel() for t in tensor_list]
|
||||
buffer_size = sum(tensor_numels)
|
||||
for i in range(self.world_size):
|
||||
if i == self.local_rank:
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||
# Release payload here, to decrease peak memory usage
|
||||
for t in tensor_list:
|
||||
t.reset_payload(None)
|
||||
else:
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||
# Move to target device before splitting buffer
|
||||
# Ensure we utilize maximum PCIE bandwidth
|
||||
buffer_list = [buffer.to(target_device) for buffer in buffer_list]
|
||||
|
@@ -2,49 +2,44 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
|
||||
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
super().__init__(process_group)
|
||||
|
||||
def shard(self, tensor_list: List[ShardedTensor]):
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._shard_tensor(t)
|
||||
self._shard_tensor(t, process_group)
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor]):
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._gather_tensor(t)
|
||||
self._gather_tensor(t, process_group)
|
||||
|
||||
def _shard_tensor(self, t: ShardedTensor):
|
||||
def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
if t.is_sharded:
|
||||
return
|
||||
sharded_payload, _ = get_shard(t.payload, self.local_rank, self.world_size)
|
||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||
t.reset_payload(sharded_payload)
|
||||
t.is_sharded = True
|
||||
|
||||
def _gather_tensor(self, t: ShardedTensor):
|
||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
if not t.is_sharded:
|
||||
return
|
||||
target_device = t.device
|
||||
buffer_list = []
|
||||
payload_numel = t.payload.numel()
|
||||
for i in range(self.world_size):
|
||||
if i == self.local_rank:
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(t.payload.cuda(get_current_device()))
|
||||
else:
|
||||
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
|
||||
|
||||
torch.distributed.all_gather(buffer_list,
|
||||
buffer_list[self.local_rank],
|
||||
group=self.process_group,
|
||||
async_op=False)
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
||||
gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||
t.reset_payload(gathered_payload)
|
||||
t.to(target_device)
|
||||
|
@@ -70,7 +70,8 @@ class ShardedModelV2(nn.Module):
|
||||
self._iter_cnter = 0
|
||||
|
||||
# Register hooks
|
||||
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
|
||||
register_ophooks_recursively(self.module,
|
||||
[ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
|
||||
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
|
||||
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
|
||||
|
||||
@@ -145,7 +146,7 @@ class ShardedModelV2(nn.Module):
|
||||
if self.shard_param:
|
||||
for p in self.module.parameters():
|
||||
if not p.col_attr.param_is_sharded:
|
||||
self.shard_strategy.shard([p.col_attr.data])
|
||||
self.shard_strategy.shard([p.col_attr.data], self.process_group)
|
||||
for p in self.module.parameters():
|
||||
p.col_attr.bwd_count = 0
|
||||
if not p.requires_grad:
|
||||
@@ -229,13 +230,13 @@ class ShardedModelV2(nn.Module):
|
||||
param.col_attr.fp16_grad = reduced_grad.data
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
|
||||
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
|
||||
prev_params = {}
|
||||
for p in self.module.parameters():
|
||||
prev_params[p] = p.data
|
||||
p.data = p.col_attr.data.payload
|
||||
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
||||
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
|
||||
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
|
||||
for p in self.module.parameters():
|
||||
p.data = prev_params[p]
|
||||
return gathered_state_dict
|
||||
|
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
|
||||
@@ -101,6 +102,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
# Store fp32 param shards
|
||||
self.master_params: Dict[Parameter, Tensor] = {}
|
||||
@@ -113,12 +115,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
# TODO (ver217): we may not use shard / gather here
|
||||
# Param is no sharded, which means we use ZeRO-2 here
|
||||
# As we only store param shard, we shard it here
|
||||
self.shard_strategy.shard([p.col_attr.data])
|
||||
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
|
||||
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
|
||||
if not is_param_sharded:
|
||||
# In this branch, there's no need to shard param
|
||||
# So we gather here
|
||||
self.shard_strategy.gather([p.col_attr.data])
|
||||
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
# unscale grads if scaled
|
||||
@@ -155,7 +157,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
# But we only have updated fp32 param shard here
|
||||
# So we first shard full fp16 param and copy fp32 param shard to it
|
||||
# Then we will gather them
|
||||
self.shard_strategy.shard([p.col_attr.data])
|
||||
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
|
||||
# We have to use `copy_payload` instead of `reset_payload`
|
||||
# Since p.data is fp32 and p.col_attr.data is fp16
|
||||
|
||||
@@ -164,7 +166,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
|
||||
if not is_param_sharded:
|
||||
# We gather full fp16 param here
|
||||
self.shard_strategy.gather([p.col_attr.data])
|
||||
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
|
||||
p.data = p.col_attr.data.payload
|
||||
return ret
|
||||
|
||||
|
Reference in New Issue
Block a user