[pipeline] Llama pipeline (#4205)

* bloom policy

* llama pipeline forward and tests (see the sketch after this message)

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* remove unneeded inputs

* gpt
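
The "llama pipeline forward" item refers to splitting the Llama model's forward across pipeline stages. Below is a minimal, hypothetical sketch of that per-stage pattern; it is not the code added by this PR, and `pipeline_forward`, `embed_tokens`, `decoder_layers`, `final_norm`, and the `stage_manager` interface are illustrative assumptions.

    # Illustrative sketch only, not the code from this commit: the general shape of a
    # pipeline-parallel Llama forward, where intermediate stages hand hidden states
    # to the next stage and only the last stage produces the model output.
    def pipeline_forward(embed_tokens, decoder_layers, final_norm, stage_manager,
                         input_ids=None, hidden_states=None, attention_mask=None):
        """Hypothetical per-stage forward.

        `stage_manager` is assumed to expose `is_first_stage()` / `is_last_stage()`;
        `decoder_layers` holds only the layers assigned to the current stage.
        """
        if stage_manager.is_first_stage():
            # The first stage owns the embedding and starts from token ids.
            hidden_states = embed_tokens(input_ids)
        for layer in decoder_layers:
            # Each decoder layer returns a tuple whose first element is the new hidden states.
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
        if stage_manager.is_last_stage():
            # The last stage applies the final norm and returns the real model output.
            return final_norm(hidden_states)
        # Intermediate stages return hidden states for the next stage to consume.
        return {'hidden_states': hidden_states}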
Author: Jianghai
Date: 2023-07-11 11:37:26 +08:00
Committed by: Hongxin Liu
Parent: 1094e0f0d3
Commit: 1622031058
6 changed files with 516 additions and 4 deletions


@@ -39,6 +39,7 @@ def build_pipeline_model(model_fn,
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               pipeline_stage_manager=stage_manager)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()
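
For context, a hedged usage sketch of the test helper shown above. Only `ShardConfig`, `ShardFormer`, and `shard_former.optimize` appear in this diff; the import paths, the `ProcessGroupMesh`/`PipelineStageManager` setup, and the exact argument list of `build_pipeline_model` are assumptions, and the code would need to run inside a launched colossalai distributed environment with at least two ranks.

    # Hypothetical usage sketch; assumptions are marked in the comments.
    from transformers import LlamaConfig, LlamaModel

    from colossalai.cluster import ProcessGroupMesh                     # assumed import path
    from colossalai.pipeline.stage_manager import PipelineStageManager  # assumed import path
    from colossalai.shardformer import ShardConfig, ShardFormer         # used by the helper above

    def model_fn():
        # Tiny Llama config so the test stays cheap.
        config = LlamaConfig(hidden_size=128, intermediate_size=256,
                             num_hidden_layers=4, num_attention_heads=4)
        return LlamaModel(config)

    PP_SIZE = 2
    pg_mesh = ProcessGroupMesh(PP_SIZE)            # single mesh axis used for pipeline parallelism
    stage_manager = PipelineStageManager(pg_mesh, 0)  # second argument: pipeline axis in the mesh (assumed)

    # Argument names and order are inferred from the snippet above, not from the full file.
    org_model, sharded_model = build_pipeline_model(model_fn,
                                                    stage_manager=stage_manager,
                                                    enable_fused_normalization=True,
                                                    enable_tensor_parallelism=False)

Each rank then holds only its pipeline stage's layers, and intermediate stages exchange hidden states rather than returning the final output.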