[fp8] add fp8 comm for low level zero

This commit is contained in:
ver217
2024-08-02 11:12:12 +08:00
parent 5fd0592767
commit ae486ce005
3 changed files with 32 additions and 9 deletions

View File

@@ -293,6 +293,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -315,6 +316,7 @@ class LowLevelZeroPlugin(DPPluginBase):
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
self.verbose = verbose