[plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel
Author: Hongxin Liu
Date: 2024-07-18 15:33:03 +08:00 (committed by GitHub)
parent 73494de577
commit e86127925a
4 changed files with 42 additions and 12 deletions


@@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,8 +76,8 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-        self.overlap_communication = overlap_communication
-        if overlap_communication:
+        self.overlap_allgather = overlap_allgather
+        if overlap_allgather:
             self.op_hook = ZeroOpHook()
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
@@ -88,7 +88,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
         with ctx:
             return super().forward(*args, **kwargs)
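
For context, this is the mechanism the rename is cleaning up: while `overlap_allgather` is on, the forward pass runs inside `ColoParamOpHookManager.use_hooks(self.op_hook)`, and the installed `ZeroOpHook` fires just before each parameter is used, so an all-gather launched earlier only has to complete at the moment its parameter is actually needed. A simplified, self-contained sketch of that pattern (the names `AllGatherHandle`, `ZeroOpHookSketch`, `pre_op`, and `use_hooks` here are illustrative, not the ColossalAI API):

```python
from contextlib import contextmanager


class AllGatherHandle:
    """Illustrative stand-in for an async all-gather work handle, e.g. the
    object returned by torch.distributed collectives with async_op=True."""

    def wait(self):
        # Block until the collective has finished.
        pass


class ZeroOpHookSketch:
    """Invoked just before a parameter participates in an op, so the
    all-gather launched earlier is awaited as late as possible."""

    def pre_op(self, params):
        for p in params:
            handle = getattr(p, "_pending_all_gather", None)
            if handle is not None:
                handle.wait()                 # overlap ends only here
                p._pending_all_gather = None


@contextmanager
def use_hooks(hook):
    # The real ColoParamOpHookManager.use_hooks patches parameter ops
    # process-wide; this sketch only models the enter/exit structure.
    yield hook
```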
@@ -356,8 +356,8 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            overlap_allgather=overlap_allgather,
         )
-        self.overlap_allgather = overlap_allgather
         self.lora_enabled = False
         self.verbose = verbose
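
The effect of this hunk is to keep `overlap_allgather` in exactly one place, `self.zero_optim_kwargs`, instead of duplicating it as a plugin attribute; `configure` below then reads the flag back from that dict. A hypothetical, trimmed-down illustration of the pattern:

```python
class PluginSketch:
    """Hypothetical miniature of the config flow after this change."""

    def __init__(self, overlap_allgather: bool = False):
        # Single source of truth: the flag lives only in the kwargs dict
        # that will later be handed to the ZeRO optimizer.
        self.zero_optim_kwargs = dict(overlap_allgather=overlap_allgather)

    def configure(self):
        # Every consumer (model wrapper, optimizer) reads the same entry,
        # so the two can no longer disagree.
        return self.zero_optim_kwargs["overlap_allgather"]
```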
@@ -473,11 +473,13 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
+            model = LowLevelZeroModel(
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+            )
 
         # TODO: Support Galore + ZeRO
         zero_stage = self.stage
-        zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()
         # Replace with the distributed implementation if exists
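
For reference, enabling the overlap from user code would look roughly like this. This is a minimal sketch assuming the standard Booster workflow; `overlap_allgather` is the plugin keyword touched by this change, and launcher and optimizer details may differ by ColossalAI version:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Initialize the distributed environment (run under torchrun).
colossalai.launch_from_torch()

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# overlap_allgather=True installs ZeroOpHook on the wrapped model and is
# forwarded to the ZeRO optimizer through zero_optim_kwargs (see diff above).
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", overlap_allgather=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```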