[zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
Hongxin Liu 2024-07-11 18:59:59 +08:00 committed by GitHub
parent dd9e1cdafe
commit c068ef0fa0
7 changed files with 119 additions and 25 deletions

View File

@@ -677,6 +677,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
+            overlap_allgather=False,
         )

     def sync_dp_grads(self):

View File

@@ -2,6 +2,7 @@ import enum
 import logging
 import os
 import warnings
+from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
 from types import MethodType
@@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -72,13 +76,26 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+        self.overlap_communication = overlap_communication
+        if overlap_communication:
+            self.op_hook = ZeroOpHook()
+            for p in module.parameters():
+                if p.requires_grad and type(p) is not ColoParameter:
+                    p.__class__ = ColoParameter
+                    p.__init__(p, requires_grad=True)

     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
+        with ctx:
+            return super().forward(*args, **kwargs)
+
+    def _force_wait_all_gather(self):
+        for p in self.module.parameters():
+            wait_all_gather_handle(p)


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_unsharded_model(model, checkpoint, strict)
         model.update_master_params()
@@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         load_sub_module: bool = True,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        max_shard_size: int = 1024,
+        use_safetensors: bool = False,
+    ):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_sharded_model(
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+        )
+
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         peft_model = model.unwrap()
         assert isinstance(
             peft_model, PeftModel
@@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         reduce_bucket_size_in_m: int = 12,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
+        overlap_allgather: bool = False,
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
@@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
+        self.overlap_allgather = overlap_allgather
         self.lora_enabled = False
         self.verbose = verbose
@@ -431,11 +473,11 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision)
+            model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)

         # TODO: Support Galore + ZeRO
         zero_stage = self.stage
-        zero_optim_kwargs = {**self.zero_optim_kwargs}
+        zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
         dp_size = dist.get_world_size()

         # Replace with the distributed implementation if exists

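Taken together, the plugin-side changes make overlapped all-gather strictly opt-in: overlap_allgather defaults to False, is stored on the plugin, and is threaded into both the wrapped LowLevelZeroModel and the LowLevelZeroOptimizer kwargs. A minimal usage sketch under the usual Booster workflow follows; the toy model, optimizer, and launch boilerplate are illustrative placeholders, not part of this commit.

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Distributed init; exact launch arguments depend on the ColossalAI version in use.
colossalai.launch_from_torch()

# Hypothetical toy model/optimizer, only here to show the flag being threaded through.
model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

plugin = LowLevelZeroPlugin(
    stage=1,
    precision="bf16",
    overlap_communication=True,  # existing flag: overlap gradient reduction with backward
    overlap_allgather=True,      # new flag: launch param all-gather asynchronously in step()
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

With the flag enabled, LowLevelZeroModel converts trainable parameters to ColoParameter and runs forward under ZeroOpHook, so a parameter's pending all-gather is only waited on right before that parameter is first used; the checkpoint IO paths call _force_wait_all_gather() first so no save or load ever observes a half-gathered tensor.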
View File

@@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger

 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
+from .zero_hook import set_all_gather_handle, wait_all_gather_handle


 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        overlap_allgather: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # communication params
         self._overlap_communication = overlap_communication
+        self._overlap_allgather = overlap_allgather
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
@@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # record the padding size of each param
         self._padding_map = dict()
+        # padded working param is all-gather buffer and it shares the same memory with working param
+        self._working_param_to_padded_working_param = dict()

         # mapping working param and master param
         self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                    # reset working params' ptr when no master weights
-                    if self._master_weights == False:
-                        param.data = padding_param[: param.numel()].view(param.shape)
+                    # # reset working params' ptr when no master weights
+                    # if self._master_weights == False:
+                    param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
+                self._working_param_to_padded_working_param[param] = padding_param

                 splited_params = padding_param.split(
                     padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

                 # use fp32 when master_weights is True
                 if self._master_weights is True:
-                    splited_param_current_rank = splited_params.detach().float().to(device)
+                    splited_param_current_rank = splited_params.detach().clone().float().to(device)
                 else:
                     splited_param_current_rank = splited_params
@@ -549,12 +555,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = real_working_params[group_id][idx]
                 param_to_gather = master_param.to(device).to(self._dtype)
                 pg = self.param_to_pg[working_param]
-                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
-                    buffer_tensor = torch.empty_like(
-                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
-                    )
-                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
-                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
-                    continue
-                try:
-                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                padded_working_param = self._working_param_to_padded_working_param[working_param]
+                if self._overlap_allgather:
+                    handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                    set_all_gather_handle(working_param, handle)
+                else:
+                    if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                        dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                        continue
+                    try:
+                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@@ -562,6 +569,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         self.pg_to_tensor_bucket[pg].all_gather(pg)
                         self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
-            if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+        if not self._overlap_allgather:
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    tensor_bucket.all_gather(pg)
@@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         grad_store = self.pid_to_grad_store[param_id]
         return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
+
+    def _force_wait_all_gather(self):
+        for param in self._working_param_to_padded_working_param.keys():
+            wait_all_gather_handle(param)

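On the optimizer side, step() now skips the synchronous tensor-bucket all-gather when _overlap_allgather is set: it launches dist.all_gather_into_tensor(..., async_op=True) straight into the padded working-param buffer, which shares storage with the working parameter after the param.data rebinding above, and parks the returned handle on the parameter via set_all_gather_handle. A small self-contained sketch of that launch-now, wait-later pattern is below, assuming torch.distributed is already initialised; lazy_all_gather and wait_if_pending are illustrative helper names, not APIs from this commit.

import torch
import torch.distributed as dist

_HANDLE = "_all_gather_handle"  # same attribute-name convention as zero_hook.py

def lazy_all_gather(shard: torch.Tensor, padded_full: torch.Tensor, param: torch.nn.Parameter, pg=None):
    # Kick off a non-blocking all-gather of this rank's shard into the padded
    # full buffer and remember the work handle on the parameter object.
    handle = dist.all_gather_into_tensor(padded_full, shard, group=pg, async_op=True)
    setattr(param, _HANDLE, handle)

def wait_if_pending(param: torch.nn.Parameter):
    # Block until the parameter's outstanding all-gather (if any) has completed.
    handle = getattr(param, _HANDLE, None)
    if handle is not None:
        handle.wait()
        delattr(param, _HANDLE)

Because the padded buffer aliases param.data, the completed all-gather updates the working parameter in place with no extra copy; _force_wait_all_gather() simply drains every outstanding handle, which is why the checkpoint IO methods and the updated tests call it before reading parameter values.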
View File

@@ -0,0 +1,33 @@
+from typing import List
+
+from torch._tensor import Tensor
+
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+_ALL_GATHER_HANDLE = "_all_gather_handle"
+
+
+def wait_all_gather_handle(p):
+    if hasattr(p, _ALL_GATHER_HANDLE):
+        handle = getattr(p, _ALL_GATHER_HANDLE)
+        handle.wait()
+        delattr(p, _ALL_GATHER_HANDLE)
+
+
+def set_all_gather_handle(p, handle):
+    setattr(p, _ALL_GATHER_HANDLE, handle)
+
+
+class ZeroOpHook(ColoParamOpHook):
+    def pre_forward(self, params: List[Tensor]) -> None:
+        for p in params:
+            wait_all_gather_handle(p)
+
+    def post_forward(self, params: List[Tensor]) -> None:
+        pass
+
+    def pre_backward(self, params: List[Tensor]) -> None:
+        pass
+
+    def post_backward(self, params: List[Tensor]) -> None:
+        pass

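The new zero_hook module keeps the mechanism deliberately small: the async work handle is stashed as an attribute on the parameter, and ZeroOpHook.pre_forward waits on it the moment the parameter participates in an op. The hook only fires for ColoParameter instances, which is why the plugin rewrites p.__class__ on trainable parameters. Below is a minimal sketch of wiring the hook up by hand, mirroring what LowLevelZeroModel.forward does; the toy Linear layer is illustrative, and nothing here should require a distributed launch since the hook is a no-op when no handle is attached.

from contextlib import nullcontext

import torch

from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level.zero_hook import ZeroOpHook

# Hypothetical toy module; any nn.Module whose trainable params were converted
# to ColoParameter (as LowLevelZeroModel does) behaves the same way.
layer = torch.nn.Linear(16, 16)
for p in layer.parameters():
    if p.requires_grad and type(p) is not ColoParameter:
        p.__class__ = ColoParameter
        p.__init__(p, requires_grad=True)

overlap = True
ctx = ColoParamOpHookManager.use_hooks(ZeroOpHook()) if overlap else nullcontext()
with ctx:
    # pre_forward fires when the weights are first touched, waiting on any
    # pending all-gather handle that optimizer.step() attached to them.
    out = layer(torch.randn(4, 16))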
View File

@@ -113,13 +113,13 @@ class PerformanceEvaluator:
         self.disable = self.ignore_steps > 0 and step < self.ignore_steps
         if self.disable:
             return
-        get_accelerator().synchronize()
+        # get_accelerator().synchronize()
         self.timer.start()

     def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         if self.disable:
             return
-        get_accelerator().synchronize()
+        # get_accelerator().synchronize()
         self.timer.end()

         batch_size, seq_len = input_ids.shape

View File

@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
     zero1_optimizer.step()
     zero2_optimizer.step()

+    zero1_optimizer._force_wait_all_gather()
+    zero2_optimizer._force_wait_all_gather()
+
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
+        assert not hasattr(z1p, "_all_gather_handle")
         assert torch.equal(z1p.data, z2p.data)

View File

@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     # torch ddp step
     torch_optimizer.step()

+    zero_optimizer._force_wait_all_gather()
+
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         loose_close(p, z1p, dtype=dtype)