Mirror of https://github.com/hpcaitech/ColossalAI.git
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opy, sam, t5, vit, whisper
* fix
* fix
* fix
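For context, the "fp8 comm" in these commits means compressing collective-communication payloads to 8-bit floating point before they cross ranks. Below is a minimal sketch of an FP8-compressed all-gather, assuming PyTorch >= 2.1 (for the float8_e4m3fn dtype) and an already-initialized process group; per-tensor scaling, which production FP8 schemes usually add, is omitted for brevity. The helper name all_gather_fp8 is ours for illustration, not ColossalAI's.

import torch
import torch.distributed as dist

def all_gather_fp8(x: torch.Tensor, group=None) -> torch.Tensor:
    # All-gather `x` along the last dim with an FP8 (e4m3) payload.
    world_size = dist.get_world_size(group)
    orig_dtype = x.dtype
    # Quantize to FP8, then reinterpret the bytes as uint8 for transport,
    # since not every communication backend accepts float8 tensors directly.
    payload = x.to(torch.float8_e4m3fn).contiguous().view(torch.uint8)
    buckets = [torch.empty_like(payload) for _ in range(world_size)]
    dist.all_gather(buckets, payload, group=group)
    # Reinterpret each received chunk as FP8 and dequantize back to the working dtype.
    chunks = [b.view(torch.float8_e4m3fn).to(orig_dtype) for b in buckets]
    return torch.cat(chunks, dim=-1)

Relative to bf16/fp16, this roughly halves the bytes moved per collective at the cost of precision, which is why the diff below exposes it as an opt-in flag (default False) rather than the default behavior.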
@@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
         gather_output: bool = True,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
+        self.fp8_communication = fp8_communication

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
         if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
             return output
         else:
             return output_parallel
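For orientation, here is a hedged stand-in for the gather that the new flag is threaded into. The name gather_along_last_dim and its body are assumptions for illustration; ColossalAI's gather_forward_split_backward is an autograd op that also splits the gradient on the backward pass, which this sketch omits. It reuses the all_gather_fp8 helper sketched above.

import torch
import torch.distributed as dist

def gather_along_last_dim(x: torch.Tensor, group=None, fp8_communication: bool = False) -> torch.Tensor:
    if fp8_communication:
        # FP8-compressed path: roughly halves the all-gather volume vs bf16/fp16.
        return all_gather_fp8(x, group=group)
    # Full-precision path: plain all-gather in the tensor's working dtype.
    world_size = dist.get_world_size(group)
    chunks = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(chunks, x.contiguous(), group=group)
    return torch.cat(chunks, dim=-1)

# Mirrors the updated call site in Embedding1D.forward:
#   output = gather_along_last_dim(output_parallel, group=self.process_group,
#                                  fp8_communication=self.fp8_communication)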