mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-13 21:55:46 +00:00
[fx] added testing for all bert variants (#1207)
* [fx] added testing for all bert variants * polish code
This commit is contained in:
parent
b5f25eb32a
commit
426a279ce7
@ -1,6 +1,8 @@
|
|||||||
import operator
|
import operator
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.proxy import Proxy, Attribute
|
from torch.fx.proxy import Proxy, Attribute
|
||||||
|
from typing import List, Union
|
||||||
|
from torch.utils._pytree import PyTree
|
||||||
|
|
||||||
__all__ = ['ColoProxy']
|
__all__ = ['ColoProxy']
|
||||||
|
|
||||||
@ -26,8 +28,12 @@ class ColoProxy(Proxy):
|
|||||||
return self._meta_tensor
|
return self._meta_tensor
|
||||||
|
|
||||||
@meta_tensor.setter
|
@meta_tensor.setter
|
||||||
def meta_tensor(self, tensor: torch.Tensor):
|
def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
|
||||||
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
|
||||||
|
def _is_meta(item):
|
||||||
|
assert torch.is_tensor(item) and item.is_meta
|
||||||
|
|
||||||
|
torch.fx.node.map_aggregate(tensor, _is_meta)
|
||||||
self._meta_tensor = tensor
|
self._meta_tensor = tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -83,6 +89,14 @@ class ColoProxy(Proxy):
|
|||||||
def __setitem__(self, indices, values):
|
def __setitem__(self, indices, values):
|
||||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
if self.node.op == "placeholder":
|
||||||
|
# this is used to handle like
|
||||||
|
# if x in kwargs
|
||||||
|
# we don't handle this case for now
|
||||||
|
return False
|
||||||
|
return super().__contains__(key)
|
||||||
|
|
||||||
|
|
||||||
class ColoAttribute(ColoProxy):
|
class ColoAttribute(ColoProxy):
|
||||||
|
|
||||||
|
@ -7,36 +7,90 @@ BATCH_SIZE = 2
|
|||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
def test_bert():
|
def trace_bert_and_compare_output(model, data_gen):
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
config = transformers.BertConfig()
|
|
||||||
model = transformers.BertModel(config=config)
|
|
||||||
|
|
||||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
|
||||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
|
||||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
|
||||||
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
# make sure that the model is traceable
|
# make sure that the model is traceable
|
||||||
|
try:
|
||||||
|
kwargs = data_gen()
|
||||||
|
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
# check output
|
# check output
|
||||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
inputs = data_gen()
|
||||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
|
||||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
|
||||||
|
|
||||||
# must turn on eval mode to ensure the output is consistent
|
# must turn on eval mode to ensure the output is consistent
|
||||||
gm.eval()
|
gm.eval()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# run forward
|
# run forward
|
||||||
fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
non_fx_out = model(**inputs)
|
||||||
non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
fx_out = gm(**inputs)
|
||||||
assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape
|
|
||||||
assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state'])
|
for k in non_fx_out.keys():
|
||||||
|
assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_sentence_bert():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.BertModel,
|
||||||
|
transformers.BertForPreTraining,
|
||||||
|
transformers.BertLMHeadModel,
|
||||||
|
transformers.BertForMaskedLM,
|
||||||
|
transformers.BertForSequenceClassification,
|
||||||
|
transformers.BertForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||||
|
return meta_args
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
trace_bert_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_sentence_bert():
|
||||||
|
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
|
||||||
|
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
def data_gen_for_next_sentence():
|
||||||
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||||
|
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
model = transformers.BertForNextSentencePrediction(config)
|
||||||
|
trace_bert_and_compare_output(model, data_gen_for_next_sentence)
|
||||||
|
|
||||||
|
def data_gen_for_qa():
|
||||||
|
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||||
|
inputs = tokenizer(question, text, return_tensors="pt")
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
model = transformers.BertForQuestionAnswering(config)
|
||||||
|
trace_bert_and_compare_output(model, data_gen_for_qa)
|
||||||
|
|
||||||
|
def data_gen_for_mcq():
|
||||||
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
choice0 = "It is eaten with a fork and a knife."
|
||||||
|
choice1 = "It is eaten while held in the hand."
|
||||||
|
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
|
||||||
|
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
model = transformers.BertForMultipleChoice(config)
|
||||||
|
trace_bert_and_compare_output(model, data_gen_for_mcq)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_bert()
|
test_single_sentence_bert()
|
||||||
|
test_multi_sentence_bert()
|
||||||
|
Loading…
Reference in New Issue
Block a user