[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2, gptj, bert, falcon, blip2

* mistral, opt, sam, t5, vit, whisper

* fix

Author: Wang Binluo
Committed by: GitHub
Date: 2024-08-12 18:17:05 +08:00
Parent: f1a3a326c4
Commit: b2483c8e31
27 changed files with 633 additions and 83 deletions
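For context, a minimal usage sketch of how the fp8 path is switched on from the plugin. Treating fp8_communication as a HybridParallelPlugin argument follows this PR's title rather than a documented API, and the parallelism sizes are placeholders, so read it as an assumption:

# Hypothetical usage sketch: enable fp8 communication when constructing the
# plugin. tp_size/pp_size values are placeholders, not part of this commit.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,               # tensor parallel degree (placeholder)
    pp_size=1,               # pipeline parallel degree (placeholder)
    fp8_communication=True,  # cast collective payloads to fp8
)
booster = Booster(plugin=plugin)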

@@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
         gather_output: bool = True,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
+        self.fp8_communication = fp8_communication

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
         if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
             return output
         else:
             return output_parallel
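The new flag makes gather_forward_split_backward compress the all-gather payload to fp8 in the forward pass. Below is a simplified sketch of that idea, assuming per-tensor e4m3 scaling; gather_fp8 is a hypothetical name, and the commit's actual helper in ColossalAI's fp8 quantization utilities differs in detail (notably, it is wrapped in an autograd function whose backward splits the gradient, which this sketch omits):

# Illustrative sketch of an fp8-compressed all-gather (hypothetical helper,
# not the commit's implementation). Each rank quantizes its shard to fp8
# with a per-tensor scale, gathers the compressed bytes plus scales, then
# dequantizes and concatenates.
import torch
import torch.distributed as dist

def gather_fp8(x: torch.Tensor, dim: int = -1, group=None) -> torch.Tensor:
    # Per-tensor scale so the fp8 e4m3 range (max ~448) covers x's magnitude.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().max().clamp(min=1e-12).div(fp8_max).reshape(1)
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)

    # Gather the compressed payloads, viewed as uint8 since collective
    # backends may not accept fp8 dtypes directly, plus each rank's scale.
    world_size = dist.get_world_size(group)
    payload = x_fp8.view(torch.uint8)
    bufs = [torch.empty_like(payload) for _ in range(world_size)]
    dist.all_gather(bufs, payload, group=group)
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    # Dequantize each shard back to the original dtype and concatenate,
    # mirroring the gather in Embedding1D.forward (dim=-1 there).
    chunks = [b.view(torch.float8_e4m3fn).to(x.dtype) * s for b, s in zip(bufs, scales)]
    return torch.cat(chunks, dim=dim)

The trade-off is the usual one for compressed collectives: roughly half the all-gather traffic versus bf16/fp16 at the cost of fp8 rounding error in the gathered activations.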