mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix
This commit is contained in:
@@ -572,6 +572,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# create weight and bias
|
||||
@@ -602,6 +603,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
||||
**kwargs,
|
||||
new_num_embeddings=new_out_features,
|
||||
old_num_embeddings=out_features,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
# get the length of valid embeddings
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
|
Reference in New Issue
Block a user