[zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
Hongxin Liu (committed by GitHub), 2024-07-11 18:59:59 +08:00
parent dd9e1cdafe, commit c068ef0fa0
7 changed files with 119 additions and 25 deletions
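For context, the flag added by this commit is consumed at plugin construction time. A minimal usage sketch (assumes a torchrun-style launch; the toy model and optimizer are placeholders, not part of this commit):

import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # reads the rank/world-size env vars set by torchrun

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# overlap_allgather=True enables the new hook-based overlap of parameter
# all-gather with forward compute; it defaults to False.
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", overlap_allgather=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)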

colossalai/booster/plugin/low_level_zero_plugin.py

@@ -2,6 +2,7 @@ import enum
 import logging
 import os
 import warnings
+from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
 from types import MethodType
@@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+        self.overlap_communication = overlap_communication
+        if overlap_communication:
+            self.op_hook = ZeroOpHook()
+            for p in module.parameters():
+                if p.requires_grad and type(p) is not ColoParameter:
+                    p.__class__ = ColoParameter
+                    p.__init__(p, requires_grad=True)
 
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
+        with ctx:
+            return super().forward(*args, **kwargs)
+
+    def _force_wait_all_gather(self):
+        for p in self.module.parameters():
+            wait_all_gather_handle(p)
 
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
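In short: re-wrapping each trainable parameter as a ColoParameter lets ColoParamOpHookManager intercept every op that consumes it, and ZeroOpHook then waits on that parameter's pending async all-gather just before first use, so the collective overlaps with earlier compute. A toy, self-contained sketch of that lazy-wait pattern (the handle attribute and classes below are illustrative, not ColossalAI internals):

import torch

class FakeHandle:
    # Stands in for an async all-gather work handle from torch.distributed.
    def wait(self):
        pass  # a real handle would block until the collective completes

def wait_pending_all_gather(p: torch.nn.Parameter) -> None:
    # Pre-op step: finish the in-flight all-gather (if any) before an
    # operator consumes p, then drop the handle so later ops skip the wait.
    handle = getattr(p, "_gather_handle", None)
    if handle is not None:
        handle.wait()
        p._gather_handle = None

p = torch.nn.Parameter(torch.zeros(4))
p._gather_handle = FakeHandle()  # pretend an async all-gather was launched
wait_pending_all_gather(p)       # first use pays the (overlapped) wait
assert p._gather_handle is None  # subsequent uses skip the wait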
@@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_unsharded_model(model, checkpoint, strict)
         model.update_master_params()
@@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         load_sub_module: bool = True,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()
 
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        max_shard_size: int = 1024,
+        use_safetensors: bool = False,
+    ):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_sharded_model(
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+        )
+
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         from peft import PeftModel
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         peft_model = model.unwrap()
         assert isinstance(
             peft_model, PeftModel
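Every checkpoint entry point above now calls model._force_wait_all_gather() first, so no parameter is serialized while its all-gather is still in flight. From the user side nothing changes; saving and loading still go through the booster, e.g. (path is a placeholder):

# Assumes `booster` and `model` from an earlier boost() call.
booster.save_model(model, "./ckpt", shard=True)  # sharded save; the wait happens internally
booster.load_model(model, "./ckpt")              # reload; master params are re-synced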
@@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         reduce_bucket_size_in_m: int = 12,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
+        overlap_allgather: bool = False,
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
@@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
+        self.overlap_allgather = overlap_allgather
         self.lora_enabled = False
         self.verbose = verbose
 
@@ -431,11 +473,11 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision)
+            model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
 
         # TODO: Support Galore + ZeRO
         zero_stage = self.stage
-        zero_optim_kwargs = {**self.zero_optim_kwargs}
+        zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
         dp_size = dist.get_world_size()
 
         # Replace with the distributed implementation if exists
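For intuition, the optimizer-side half of the overlap rests on the standard async-collective pattern sketched below (generic torch.distributed usage, not ColossalAI's internal code): launch the all-gather without blocking, stash the work handle on the parameter, and let the hook wait on it at first use in the next forward pass.

import torch
import torch.distributed as dist

def launch_all_gather_async(full_param: torch.Tensor, shard: torch.Tensor, group=None):
    # async_op=True returns immediately with a work handle; compute proceeds
    # while the collective runs, and handle.wait() synchronizes at first use.
    handle = dist.all_gather_into_tensor(full_param, shard, group=group, async_op=True)
    return handle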