Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
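For orientation, a minimal usage sketch of the flag this commit threads through the low-level ZeRO plugin. The overlap_allgather constructor argument is assumed from the diff below rather than taken from documented API, so treat the snippet as illustrative only:

# Hedged sketch: enabling all-gather overlap when building the ZeRO plugin.
# overlap_allgather=True is assumed from this diff; stage/precision are the
# plugin's usual options.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=2, precision="fp16", overlap_allgather=True)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)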
@@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,8 +76,8 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-        self.overlap_communication = overlap_communication
-        if overlap_communication:
+        self.overlap_allgather = overlap_allgather
+        if overlap_allgather:
             self.op_hook = ZeroOpHook()
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
@@ -88,7 +88,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
         with ctx:
             return super().forward(*args, **kwargs)

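The forward path above now gates the parameter-op hook on the renamed flag. A standalone sketch of that conditional-context pattern, using a hypothetical make_hook_ctx callable in place of ColoParamOpHookManager.use_hooks(...):

# Illustrative sketch, not the plugin's actual code: enter the hook context
# only when all-gather overlap is enabled, otherwise use a no-op context.
from contextlib import nullcontext

def run_forward(module, overlap_allgather, make_hook_ctx, *args, **kwargs):
    # make_hook_ctx() stands in for ColoParamOpHookManager.use_hooks(op_hook).
    ctx = make_hook_ctx() if overlap_allgather else nullcontext()
    with ctx:
        return module(*args, **kwargs)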
@@ -356,8 +356,8 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            overlap_allgather=overlap_allgather,
         )
-        self.overlap_allgather = overlap_allgather
         self.lora_enabled = False
         self.verbose = verbose

@@ -473,11 +473,13 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
+            model = LowLevelZeroModel(
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+            )

         # TODO: Support Galore + ZeRO
         zero_stage = self.stage
-        zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()

         # Replace with the distributed implementation if exists
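At the plugin level, the flag now travels inside zero_optim_kwargs instead of being stored on the plugin and merged back in at configure() time. A reduced sketch of that forwarding style; build_zero_optim_kwargs and the literal argument values are illustrative, not the plugin's real helper:

# Hedged sketch of the kwargs-forwarding change.
def build_zero_optim_kwargs(stage, cpu_offload, master_weights, overlap_allgather):
    return dict(
        partition_grad=(stage == 2),
        cpu_offload=cpu_offload,
        master_weights=master_weights,
        overlap_allgather=overlap_allgather,  # carried here after this commit
    )

zero_optim_kwargs = build_zero_optim_kwargs(
    stage=2, cpu_offload=False, master_weights=True, overlap_allgather=True
)
# configure() then only needs a shallow copy: {**zero_optim_kwargs}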