mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 03:26:48 +00:00
[fx] added testing for all albert variants (#1211)
This commit is contained in:
parent
2d13a45a3b
commit
5da87ce35d
@ -39,7 +39,7 @@ class ColoProxy(Proxy):
|
|||||||
self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
|
self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
|
||||||
|
|
||||||
def _assert_has_meta_data(self):
|
def _assert_has_meta_data(self):
|
||||||
assert self._meta_data, f'Meta data is not set for {self.node.name}'
|
assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
@ -63,7 +63,7 @@ class ColoProxy(Proxy):
|
|||||||
|
|
||||||
def size(self, dim: int = None):
|
def size(self, dim: int = None):
|
||||||
self._assert_meta_data_is_tensor()
|
self._assert_meta_data_is_tensor()
|
||||||
if dim:
|
if dim is not None:
|
||||||
return self.meta_data.size(dim=dim)
|
return self.meta_data.size(dim=dim)
|
||||||
else:
|
else:
|
||||||
# size(dim=None) will trigger runtime error for meta tensor
|
# size(dim=None) will trigger runtime error for meta tensor
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from curses import meta
|
||||||
import operator
|
import operator
|
||||||
import torch
|
import torch
|
||||||
from .registry import meta_patched_function
|
from .registry import meta_patched_function
|
||||||
@ -89,3 +90,55 @@ def torch_where(condition, x, y):
|
|||||||
# torch.where returns the broadcasted tensor of condition, x, and y,
|
# torch.where returns the broadcasted tensor of condition, x, and y,
|
||||||
# so hack it by using addition
|
# so hack it by using addition
|
||||||
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
|
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.abs)
|
||||||
|
def torch_abs(input, *, out=None):
|
||||||
|
assert out is None, 'out is not supported yet'
|
||||||
|
return torch.empty(input.shape, device='meta')
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.nn.functional.relu)
|
||||||
|
def torch_nn_func_relu(input, inplace=False):
|
||||||
|
assert not inplace, 'inplace is not supported yet'
|
||||||
|
return torch.empty(input.shape, device='meta')
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.Tensor.repeat)
|
||||||
|
def torch_tensor_repeat(self, *sizes):
|
||||||
|
shape = list(self.shape)
|
||||||
|
for i, x in enumerate(sizes):
|
||||||
|
shape[i] *= x
|
||||||
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.index_select)
|
||||||
|
def torch_index_select(input, dim, index, *, out=None):
|
||||||
|
shape = list(input.shape)
|
||||||
|
shape[dim] = len(index)
|
||||||
|
return torch.empty(*shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.Tensor.index_select)
|
||||||
|
def torch_tensor_index_select(self, dim, index):
|
||||||
|
return torch_index_select(self, dim, index)
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.nn.functional.embedding)
|
||||||
|
def torch_nn_functional_embedding(input,
|
||||||
|
weight,
|
||||||
|
padding_idx=None,
|
||||||
|
max_norm=None,
|
||||||
|
norm_type=2.0,
|
||||||
|
scale_grad_by_freq=False,
|
||||||
|
sparse=False):
|
||||||
|
return torch.empty(*input.shape, weight.shape[-1], device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.bmm)
|
||||||
|
def torch_bmm(input, mat2, *, out=None):
|
||||||
|
if out is not None:
|
||||||
|
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
||||||
|
batch_size, n, m = input.shape
|
||||||
|
_, _, p = mat2.shape
|
||||||
|
return torch.empty(batch_size, n, p, device="meta")
|
||||||
|
@ -116,3 +116,9 @@ def torch_nn_maxpool3d(self, input):
|
|||||||
w_out,
|
w_out,
|
||||||
)
|
)
|
||||||
return torch.empty(result_shape, device='meta')
|
return torch.empty(result_shape, device='meta')
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_module.register(torch.nn.ReLU)
|
||||||
|
def torch_nn_func_relu(self, input):
|
||||||
|
assert not self.inplace, 'inplace is not supported yet'
|
||||||
|
return input.clone()
|
||||||
|
65
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
Normal file
65
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from utils import trace_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_sentence_albert():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.AlbertModel,
|
||||||
|
transformers.AlbertForPreTraining,
|
||||||
|
transformers.AlbertForMaskedLM,
|
||||||
|
transformers.AlbertForSequenceClassification,
|
||||||
|
transformers.AlbertForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.AlbertConfig(embedding_size=128,
|
||||||
|
hidden_size=128,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=256)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||||
|
return meta_args
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
trace_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_sentence_albert():
|
||||||
|
config = transformers.AlbertConfig(hidden_size=128,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=256)
|
||||||
|
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
def data_gen_for_qa():
|
||||||
|
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||||
|
inputs = tokenizer(question, text, return_tensors="pt")
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
model = transformers.AlbertForQuestionAnswering(config)
|
||||||
|
trace_model_and_compare_output(model, data_gen_for_qa)
|
||||||
|
|
||||||
|
def data_gen_for_mcq():
|
||||||
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
choice0 = "It is eaten with a fork and a knife."
|
||||||
|
choice1 = "It is eaten while held in the hand."
|
||||||
|
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
|
||||||
|
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
model = transformers.AlbertForMultipleChoice(config)
|
||||||
|
trace_model_and_compare_output(model, data_gen_for_mcq)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_single_sentence_albert()
|
||||||
|
test_multi_sentence_albert()
|
31
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
Normal file
31
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from utils import trace_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 1
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('value is not aligned yet')
|
||||||
|
def test_opt():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.OPTModel,
|
||||||
|
transformers.OPTForCausalLM,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
trace_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_opt()
|
32
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
Normal file
32
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from utils import trace_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 1
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('value is not aligned yet')
|
||||||
|
def test_t5():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.T5Model,
|
||||||
|
transformers.T5ForConditionalGeneration,
|
||||||
|
transformers.T5EncoderModel,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.T5Config(d_model=128, num_layers=2)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
trace_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_t5()
|
@ -30,4 +30,6 @@ def trace_model_and_compare_output(model, data_gen):
|
|||||||
|
|
||||||
for k in non_fx_out.keys():
|
for k in non_fx_out.keys():
|
||||||
if torch.is_tensor(fx_out[k]):
|
if torch.is_tensor(fx_out[k]):
|
||||||
assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
|
assert torch.equal(
|
||||||
|
fx_out[k], non_fx_out[k]
|
||||||
|
), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}'
|
||||||
|
Loading…
Reference in New Issue
Block a user