[pipeline] add bloom model pipeline (#4210)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache
This commit is contained in:
Jianghai
2023-07-13 12:47:26 +08:00
committed by Hongxin Liu
parent 31bcf867ae
commit 37d22f6878
4 changed files with 322 additions and 10 deletions

View File

@@ -70,6 +70,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
@clear_cache_before_run()
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():