[pipeline] add bloom model pipeline (#4210)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache
This commit is contained in:
Jianghai
2023-07-13 12:47:26 +08:00
committed by Hongxin Liu
parent 31bcf867ae
commit 37d22f6878
4 changed files with 322 additions and 10 deletions

View File

@@ -51,15 +51,17 @@ output_transform_fn = lambda x: x
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss
config = transformers.GPT2Config(n_layer=2,
n_head=4,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification")
config = transformers.GPT2Config(
n_layer=2,
n_head=4,
#n_embd=128,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification")
# register the following models
model_zoo.register(name='transformers_gpt',