mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,9 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from numpy import isin
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
# from colossalai.fx import symbolic_trace
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
@@ -20,7 +17,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non
|
||||
inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
|
||||
|
||||
try:
|
||||
meta_args = {k: v.to('meta') for k, v in inputs.items()}
|
||||
meta_args = {k: v.to("meta") for k, v in inputs.items()}
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
|
||||
except Exception as e:
|
||||
@@ -35,4 +32,4 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non
|
||||
if torch.is_tensor(fx_out[k]):
|
||||
assert torch.equal(
|
||||
fx_out[k], non_fx_out[k]
|
||||
), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}'
|
||||
), f"{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}"
|
||||
|
@@ -10,15 +10,15 @@ BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 16
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_albert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_albert')
|
||||
sub_registry = model_zoo.get_sub_registry("transformers_albert")
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_albert()
|
||||
|
@@ -7,17 +7,17 @@ from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_bert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_bert')
|
||||
sub_registry = model_zoo.get_sub_registry("transformers_bert")
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
if model.__class__.__name__ == "BertForQuestionAnswering":
|
||||
continue
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_bert()
|
||||
|
@@ -22,7 +22,7 @@ def trace_and_compare(model_cls, data, output_fn):
|
||||
model.eval()
|
||||
|
||||
concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)}
|
||||
meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)}
|
||||
meta_args = {k: v.to("meta") for k, v in data.items() if torch.is_tensor(v)}
|
||||
gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
|
||||
|
||||
# run forward
|
||||
@@ -40,12 +40,12 @@ def trace_and_compare(model_cls, data, output_fn):
|
||||
assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='cannot pass this test yet')
|
||||
@pytest.mark.skip(reason="cannot pass this test yet")
|
||||
@clear_cache_before_run()
|
||||
def test_diffusers():
|
||||
seed_all(9091, cuda_deterministic=True)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
sub_model_zoo = model_zoo.get_sub_registry("diffusers")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
@@ -58,12 +58,12 @@ def test_diffusers():
|
||||
def test_torch_diffusers():
|
||||
seed_all(65535, cuda_deterministic=True)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
sub_model_zoo = model_zoo.get_sub_registry("diffusers")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
model = model_fn()
|
||||
output = model(**data)
|
||||
model(**data)
|
||||
torch.cuda.synchronize()
|
||||
print(f"{name:40s} √")
|
||||
|
||||
|
@@ -7,10 +7,10 @@ from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_gpt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
|
||||
sub_registry = model_zoo.get_sub_registry("transformers_gpt")
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
@@ -18,11 +18,11 @@ def test_gpt():
|
||||
# TODO(ver217): support the following models
|
||||
# 1. GPT2DoubleHeadsModel
|
||||
# as they are not supported, let's skip them
|
||||
if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']:
|
||||
if model.__class__.__name__ in ["GPT2DoubleHeadsModel", "GPT2ForQuestionAnswering"]:
|
||||
continue
|
||||
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_gpt()
|
||||
|
@@ -7,14 +7,14 @@ from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_opt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_opt')
|
||||
sub_registry = model_zoo.get_sub_registry("transformers_opt")
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions'])
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "start_positions", "end_positions"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_opt()
|
||||
|
@@ -7,20 +7,20 @@ from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_t5():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_t5')
|
||||
sub_registry = model_zoo.get_sub_registry("transformers_t5")
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
if name == "transformers_t5_for_conditional_generation":
|
||||
# cannot trace for loss function yet
|
||||
# so we use a data gen which does not produce labels
|
||||
data_gen_fn = sub_registry.get('transformers_t5')[1]
|
||||
data_gen_fn = sub_registry.get("transformers_t5")[1]
|
||||
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_t5()
|
||||
|
Reference in New Issue
Block a user