Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
* [shardformer] supported flash attention test dependency (#4158) * [shardformer] fix flash attention utils test (#4180) * [shardformer] opt support flash attention (#4163) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] add performance benchmark of shardformer (#4175) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] benchmark fix * [shardformer] benchmark fix * [shardformer] llama support flash attention (#4185) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] llama support flash attention * [shardformer] llama support flash attention * [shardformer] Move the import statement for xformer outside the forward function. * [shardformer] gpt2 support flash attention. (#4191) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] bloom support flash attention (#4188) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom suport flash attention * [shardformer] add assert to sequence length * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] bert support flash attention. (#4206) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bert support flash attention * [shardformer] t5 support flash attention. (#4216) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 support flash attention * [shardformer] t5 support flash attention * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * [shardformer] support 'paddedcausal' type of attention mask in Coloattention. (#4215) * added padded causal attn mask type for ColoAttention * [shardformer]t5 flash attention fix (#4239) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 flash attention fix * [shardformer] update gpt2 to use coloattention. 
(#4234) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 * [shardformer] update opt and llama to use coloattention. (#4226) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt * [shardformer] shardformer support jit fused operator. (#4236) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add type hint to 'self' param of forward * [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. 
(#4290) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] whisper support flash attention (#4301) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] whisper support flash attention * [shardformer] whisper support flash attention * [shardformer]whisper support jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] sam support flash attention (#4316) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * 
[shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] sam support flash attention --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] merge blip2/chatglm (#4321) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] blip2 support flash attention and jit operator (#4325) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * 
[shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] chatglm support flash attention and jit operator (#4330) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * 
[shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] vit support flash attention and jit operator (#4334) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] vit support flash attention and jit operator * [shardformer] vit support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] fix conflict * [pipeline] fix conflict * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * fix flash attention tests * gemini ignore whisper * fix vit * fix xformers import handle --------- Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: 
klhhhhh <1412841649@qq.com>
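The net effect of this merge is that ShardConfig gains two switches, enable_flash_attention and enable_jit_fused, which the test helpers in the diff below thread through every supported model. A minimal sketch of how they are used follows; the import path and the small BERT config are illustrative assumptions, while the ShardConfig/ShardFormer/colossalai.launch calls mirror the test code in this diff:

    import copy

    import colossalai
    import transformers
    from colossalai.shardformer import ShardConfig, ShardFormer    # import path assumed


    def shard_with_flash_attention(rank, world_size, port):
        # the shardformer tests below always run inside a launched process group
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        org_model = transformers.BertModel(transformers.BertConfig(num_hidden_layers=2))
        shard_config = ShardConfig(enable_fused_normalization=True,
                                   enable_tensor_parallelism=True,
                                   enable_flash_attention=True,    # new switch from this PR
                                   enable_jit_fused=True)          # new switch from this PR
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model, shared_params = shard_former.optimize(model_copy)
        return org_model.cuda(), sharded_model.cuda()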
@@ -20,7 +20,7 @@ def data_gen():
     # token_type_ids = tokenized_input['token_type_ids']
     input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
     token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
@@ -69,19 +69,21 @@ def data_gen_for_mcq():
     # data['labels'] = torch.tensor([0], dtype=torch.int64)
     input_ids = torch.tensor([[[
         101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
-        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
+        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
     ],
     [
         101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
         4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
-        2218, 1999, 1996, 2192, 1012, 102, 0
+        2218, 1999, 1996, 2192, 1012, 102, 0, 0
     ]]])
     token_type_ids = torch.tensor(
-        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+          0]]])
     attention_mask = torch.tensor(
-        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+          0]]])
     labels = torch.tensor([0], dtype=torch.int64)

     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
@@ -38,6 +38,7 @@ output_transform_fn = lambda x: x
 loss_fn_blip2_model = lambda x: x.loss

 config = transformers.Blip2Config()
+config.vision_config.patch_size = 14
 config.text_config.num_hidden_layers = 1
 config.qformer_config.num_hidden_layers = 1
 config.vision_config.num_hidden_layers = 1
@@ -16,8 +16,8 @@ def data_gen():
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data
@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
     # inputs = tokenizer(question, text, return_tensors="pt")

     input_ids = torch.tensor(
-        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     start_positions = torch.tensor([1], dtype=torch.int64)
     end_positions = torch.tensor([10], dtype=torch.int64)
     return dict(input_ids=input_ids,
@@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM

 from ..registry import ModelAttribute, model_zoo

-
 # ================================
 # Register single-sentence ChatGLM
 # ================================
@@ -1,58 +0,0 @@
-from transformers import PretrainedConfig
-
-
-class ChatGLMConfig(PretrainedConfig):
-    model_type = "chatglm"
-
-    def __init__(self,
-                 num_layers=28,
-                 padded_vocab_size=65024,
-                 hidden_size=4096,
-                 ffn_hidden_size=13696,
-                 kv_channels=128,
-                 num_attention_heads=32,
-                 seq_length=2048,
-                 hidden_dropout=0.0,
-                 attention_dropout=0.0,
-                 layernorm_epsilon=1e-5,
-                 rmsnorm=True,
-                 apply_residual_connection_post_layernorm=False,
-                 post_layer_norm=True,
-                 add_bias_linear=False,
-                 add_qkv_bias=False,
-                 bias_dropout_fusion=True,
-                 multi_query_attention=False,
-                 multi_query_group_num=1,
-                 apply_query_key_layer_scaling=True,
-                 attention_softmax_in_fp32=True,
-                 fp32_residual_connection=False,
-                 quantization_bit=0,
-                 pre_seq_len=None,
-                 prefix_projection=False,
-                 **kwargs):
-        self.num_layers = num_layers
-        self.vocab_size = padded_vocab_size
-        self.padded_vocab_size = padded_vocab_size
-        self.hidden_size = hidden_size
-        self.ffn_hidden_size = ffn_hidden_size
-        self.kv_channels = kv_channels
-        self.num_attention_heads = num_attention_heads
-        self.seq_length = seq_length
-        self.hidden_dropout = hidden_dropout
-        self.attention_dropout = attention_dropout
-        self.layernorm_epsilon = layernorm_epsilon
-        self.rmsnorm = rmsnorm
-        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
-        self.post_layer_norm = post_layer_norm
-        self.add_bias_linear = add_bias_linear
-        self.add_qkv_bias = add_qkv_bias
-        self.bias_dropout_fusion = bias_dropout_fusion
-        self.multi_query_attention = multi_query_attention
-        self.multi_query_group_num = multi_query_group_num
-        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
-        self.fp32_residual_connection = fp32_residual_connection
-        self.quantization_bit = quantization_bit
-        self.pre_seq_len = pre_seq_len
-        self.prefix_projection = prefix_projection
-        super().__init__(**kwargs)
File diff suppressed because it is too large
@@ -18,8 +18,8 @@ def data_gen():
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -46,7 +46,7 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
     return data
@@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
     # config = T5Config(decoder_start_token_id=0)
     # tokenizer = T5Tokenizer.from_pretrained("t5-small")
    # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
-    return dict(input_ids=input_ids)
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+    return dict(input_ids=input_ids, attention_mask=attention_mask)


 def data_gen_for_conditional_generation():
@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
     #
     # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
     data = data_gen_for_encoder_only()
-    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
     data['labels'] = labels
     return data


 def data_gen_for_t5_model():
     # decoder_inputs_ids is obtained with the following code
     #
     # decoder_input_ids = model._shift_right(input_ids)
     data = data_gen_for_encoder_only()
-    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
     data['decoder_input_ids'] = decoder_input_ids
     return data
@@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper',
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))

-model_zoo.register(name='transformers_whisperForConditionalGeneration',
+model_zoo.register(name='transformers_whisper_for_conditional_generation',
                    model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                    data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_attr,
                    model_attribute=ModelAttribute(has_control_flow=True))

-model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
+model_zoo.register(name='transformers_whisper_for_audio_classification',
                    model_fn=lambda: transformers.WhisperForAudioClassification(config),
                    data_gen_fn=data_gen_for_audio_classification,
                    output_transform_fn=output_transform_fn,
@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
                 'transformers_vit_for_image_classification', 'transformers_chatglm',
                 'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
                 'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
-                'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification'
+                'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
         ]:
             continue
@@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+def build_model(model_fn,
+                enable_fused_normalization=True,
+                enable_tensor_parallelism=True,
+                enable_flash_attention=False,
+                enable_jit_fused=False,
+                use_lazy_init: bool = False):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
     with ctx:
         # create new model
@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
         ctx.materialize(org_model)
     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
-                               enable_tensor_parallelism=enable_tensor_parallelism)
+                               enable_tensor_parallelism=enable_tensor_parallelism,
+                               enable_flash_attention=enable_flash_attention,
+                               enable_jit_fused=enable_jit_fused)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
     return org_model.cuda(), sharded_model.cuda()
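For orientation, this is how the per-model tests in the following hunks exercise the extended helper; the keyword-style call below is an illustrative variant of the positional calls used in the actual tests:

    # illustrative use of the extended test helper with both new switches enabled
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn,
                                               enable_fused_normalization=True,
                                               enable_tensor_parallelism=True,
                                               enable_flash_attention=True,
                                               enable_jit_fused=True,
                                               use_lazy_init=False)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)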
@@ -46,14 +46,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)


-@parameterize('enable_fused_normalization', [False, True])
-@parameterize('enable_tensor_parallelism', [False, True])
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
+                  use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               use_lazy_init)
+                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
@@ -47,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
@@ -44,13 +44,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
+                   use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               use_lazy_init)
-        check_state_dict(org_model, sharded_model, name=name)
+                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create new model
@@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):

         # shard model
         shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
-                                   enable_tensor_parallelism=enable_tensor_parallelism)
+                                   enable_tensor_parallelism=enable_tensor_parallelism,
+                                   enable_flash_attention=enable_flash_attention,
+                                   enable_jit_fused=enable_jit_fused)
         model_copy = copy.deepcopy(org_model)
         shard_former = ShardFormer(shard_config=shard_config)
         if name == "transformers_chatglm":
@@ -68,7 +68,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     torch.cuda.empty_cache()

-
 @parameterize('test_config', [{
     'tp_size': 1,
     'pp_size': 2,
@@ -49,12 +49,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               use_lazy_init)
+                                               enable_flash_attention, use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -42,18 +42,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check grad
     col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
     row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
-    check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
+    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
+    check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)


+@parameterize('use_lazy_init', [False, True])
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
+                 enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               use_lazy_init)
+                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -62,7 +65,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_t5_test()
+    run_opt_test()


 @pytest.mark.dist
@@ -41,10 +41,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
@@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check grad
     col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
     row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
-    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False)
-    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False)
+    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)

     # check weights are tied
     if hasattr(org_model, 'lm_head'):
@@ -45,11 +45,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
+                enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               use_lazy_init)
+                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -20,7 +20,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                  output_transform_fn, loss_fn)

     assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)

     # do backward
     org_loss.backward()
     shard_loss.backward()
@@ -45,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -48,12 +48,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn,
                                                enable_fused_normalization=enable_fused_normalization,
-                                               enable_tensor_parallelism=enable_tensor_parallelism)
+                                               enable_tensor_parallelism=enable_tensor_parallelism,
+                                               enable_flash_attention=enable_flash_attention,
+                                               enable_jit_fused=enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
@@ -167,4 +167,4 @@ def test_cross_attention(proj_shape, dtype, dropout):
     torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}"
     torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
     torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
-    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
+    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"