Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 05:49:55 +00:00)
[shardformer] update t5 tests to use all optimizations (#4407)

* [shardformer] gpt2 tests fix; test all optimizations (#4399)
* [shardformer] update t5 to use all optimizations
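For reference, "all optimizations" refers to the optional shardformer kernel paths the test matrix now turns on together (fused normalization, flash attention, JIT-fused ops). A minimal sketch of how these flags are typically toggled, assuming the ShardConfig/ShardFormer names of this release; the flag names and the optimize() signature are assumptions, not taken from this diff:

import torch
from transformers import T5Config, T5ForConditionalGeneration

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumed flag names; check colossalai.shardformer.ShardConfig for your version.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,   # single-process sketch; the real tests shard across ranks
    enable_fused_normalization=True,   # fused LayerNorm kernels
    enable_flash_attention=True,       # FlashAttention for self-/cross-attention
    enable_jit_fused=True,             # JIT-fused elementwise ops
)

model = T5ForConditionalGeneration(T5Config(decoder_start_token_id=0))
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)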
Committed by Hongxin Liu
parent 1edc9b5fb3
commit 108e54a0b4
@@ -16,8 +16,8 @@ def data_gen_for_encoder_only():
     # config = T5Config(decoder_start_token_id=0)
     # tokenizer = T5Tokenizer.from_pretrained("t5-small")
     # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
-    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
     return dict(input_ids=input_ids, attention_mask=attention_mask)
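The commented-out lines in this hunk show where the hardcoded ids come from. A small sketch of how the fixture could be regenerated; it assumes the t5-small tokenizer is available, whereas the test hardcodes the result so no tokenizer download is needed at test time:

import torch
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
input_ids = enc.input_ids            # the leading ids of the fixture, ending with </s> (id 1)
attention_mask = enc.attention_mask
# The fixture then extends the sequence to 16 tokens, repeating the tail and masking the
# final position; presumably so the longer length exercises the optimized attention paths.
# That padding choice is an assumption, not documented in the diff.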
@@ -26,7 +26,7 @@ def data_gen_for_conditional_generation():
     #
     # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
     data = data_gen_for_encoder_only()
-    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
     data['labels'] = labels
     return data
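Similarly, the label ids correspond to the German target sentence in the comment above. A sketch, again assuming the t5-small tokenizer:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
# Yields the leading [644, 4598, 229, 19250, 5, 1] block; the fixture repeats these ids
# to match the 16-token encoder inputs instead of tokenizing at test time.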
@@ -35,7 +35,7 @@ def data_gen_for_t5_model():
     # decoder_inputs_ids is obtained with the following code
     # decoder_input_ids = model._shift_right(input_ids)
     data = data_gen_for_encoder_only()
-    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
     data['decoder_input_ids'] = decoder_input_ids
     return data
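As the comment in the last hunk notes, the decoder inputs follow T5's convention of shifting the target sequence right and prepending decoder_start_token_id (0). A minimal equivalent of what model._shift_right does for these inputs; note the fixture hardcodes a hand-extended variant rather than the exact shifted ids:

import torch

def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int = 0) -> torch.Tensor:
    # Prepend the decoder start token and drop the last position, mirroring transformers'
    # T5 _shift_right for these inputs (it additionally replaces -100 labels with the pad
    # id, which is irrelevant here since the fixture contains no -100 entries).
    start = torch.full_like(input_ids[:, :1], decoder_start_token_id)
    return torch.cat([start, input_ids[:, :-1]], dim=-1)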