support fp8 communication in pipeline parallelism

This commit is contained in:
BurkeHulk
2024-07-12 15:25:25 +08:00
parent 1e1959467e
commit e88190184a
4 changed files with 126 additions and 1 deletions

View File

@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert (
@@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
@@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
else:
raise NotImplementedError()