mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api
This commit is contained in:
@@ -2,6 +2,7 @@ import enum
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
@@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
||||
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
|
||||
|
||||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(self, module: nn.Module, precision: str) -> None:
|
||||
def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
if precision == "fp16":
|
||||
@@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
self.convert_fn = None
|
||||
if self.dtype is not None:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_communication = overlap_communication
|
||||
if overlap_communication:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
|
||||
with ctx:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
@@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_unsharded_model(model, checkpoint, strict)
|
||||
model.update_master_params()
|
||||
|
||||
@@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
load_sub_module: bool = True,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
model.update_master_params()
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_sharded_model(
|
||||
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
@@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
@@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
reduce_bucket_size_in_m: int = 12,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
@@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
cpu_offload=cpu_offload,
|
||||
master_weights=master_weights,
|
||||
)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
|
||||
@@ -431,11 +473,11 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
self.add_lora_params_to_optimizer(model, optimizer)
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
model = LowLevelZeroModel(model, self.precision)
|
||||
model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
zero_stage = self.stage
|
||||
zero_optim_kwargs = {**self.zero_optim_kwargs}
|
||||
zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
|
||||
dp_size = dist.get_world_size()
|
||||
|
||||
# Replace with the distributed implementation if exists
|
||||
|
Reference in New Issue
Block a user