[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,11 +1,12 @@
-import torch
-from torch.fx import symbolic_trace
-from torch.fx import GraphModule
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
-from colossalai.fx import ColoTracer
-import inspect
-import random
+import inspect
+import random
+
+import numpy as np
+import torch
+from torch.fx import GraphModule
+
+from colossalai.fx import ColoTracer
+from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
 
 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
@@ -26,7 +27,7 @@ def split_model_and_compare_output(model, data_gen):
     # tracing model
     tracer = ColoTracer()
     try:
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
         graph = tracer.trace(root=model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
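For readers skimming the diff, the hunk above sits inside a helper that symbolically traces a HuggingFace model before splitting it into pipeline stages. Below is a minimal sketch of that tracing step, assuming a small BERT model as input; the config values, the dummy batch, and the GraphModule construction follow the usual torch.fx pattern and are illustrative, not lines copied from this commit.

import torch
import transformers
from torch.fx import GraphModule

from colossalai.fx import ColoTracer

# illustrative model and dummy batch (shapes and config values are assumptions)
config = transformers.BertConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256)
model = transformers.BertModel(config)
model.eval()
kwargs = {"input_ids": torch.zeros((2, 16), dtype=torch.int64)}

# trace with meta tensors so tracing does not execute a real forward pass
tracer = ColoTracer()
meta_args = {k: v.to("meta") for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()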
@@ -49,16 +50,16 @@ def split_model_and_compare_output(model, data_gen):
         output_part1 = model_part1(output_part0)
     else:
         if len(output_part0) > len(sig.parameters):
-            output_part0 = output_part0[:len(sig.parameters)]
+            output_part0 = output_part0[: len(sig.parameters)]
         output_part1 = model_part1(*output_part0)
 
     # get output tensor from HFOutput datastructure
-    if 'logits' in output:
-        output_to_compare = output['logits']
-    elif 'prediction_logits' in output:
-        output_to_compare = output['prediction_logits']
+    if "logits" in output:
+        output_to_compare = output["logits"]
+    elif "prediction_logits" in output:
+        output_to_compare = output["prediction_logits"]
     else:
-        output_to_compare = output['last_hidden_state']
+        output_to_compare = output["last_hidden_state"]
 
     # compare output
     if isinstance(output_part1, torch.Tensor):
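Continuing the sketch above: the remainder of split_model_and_compare_output (only partly visible in this hunk) splits the traced module into two stages and checks that chaining them reproduces the original output. A rough outline of that flow follows; the pass signatures and the final comparison are assumptions based on the imports and the logic shown here, not text from the commit.

import inspect

from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

# `gm`, `model`, and `kwargs` come from the tracing sketch above
annotated_gm = balanced_split_pass(gm, 2)  # assumed: annotate split points for 2 pipeline stages
split_gm, split_submodules = split_with_split_nodes_pass(annotated_gm)
model_part0, model_part1 = list(split_gm.children())[:2]

# run the two stages back to back, forwarding only as many outputs as stage 1 accepts
output_part0 = model_part0(**kwargs)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
    output_part1 = model_part1(output_part0)
else:
    if len(output_part0) > len(sig.parameters):
        output_part0 = output_part0[: len(sig.parameters)]
    output_part1 = model_part1(*output_part0)

# compare against the un-split model (the tolerance-based check is an assumption;
# the helper branches on whether the final stage returns a tensor or a tuple)
reference = model(**kwargs)["last_hidden_state"]
if isinstance(output_part1, torch.Tensor):
    assert torch.allclose(reference, output_part1, atol=1e-5)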
@@ -7,7 +7,7 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_single_sentence_albert():
     MODEL_LIST = [
         transformers.AlbertModel,
@@ -17,12 +17,14 @@ def test_single_sentence_albert():
         transformers.AlbertForTokenClassification,
     ]
 
-    config = transformers.AlbertConfig(vocab_size=100,
-                                       embedding_size=128,
-                                       hidden_size=128,
-                                       num_hidden_layers=2,
-                                       num_attention_heads=4,
-                                       intermediate_size=256)
+    config = transformers.AlbertConfig(
+        vocab_size=100,
+        embedding_size=128,
+        hidden_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        intermediate_size=256,
+    )
 
     def data_gen():
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
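The data_gen defined at the end of this hunk is cut off by the diff context. A plausible completion of the pattern is sketched below; the token_type_ids/attention_mask fields and the returned dict are assumptions about what split_model_and_compare_output unpacks as keyword arguments, using the BATCH_SIZE and SEQ_LENGHT constants shown in the hunk headers.

def data_gen():
    # dummy single-sentence batch; the helper calls model(**data_gen())
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)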
@@ -36,5 +38,5 @@ def test_single_sentence_albert():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_single_sentence_albert()
@@ -7,7 +7,7 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
@@ -18,11 +18,9 @@ def test_single_sentence_bert():
         transformers.BertForTokenClassification,
     ]
 
-    config = transformers.BertConfig(vocab_size=100,
-                                     hidden_size=128,
-                                     num_hidden_layers=4,
-                                     num_attention_heads=4,
-                                     intermediate_size=256)
+    config = transformers.BertConfig(
+        vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256
+    )
 
     def data_gen():
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
@@ -36,5 +34,5 @@ def test_single_sentence_bert():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_single_sentence_bert()
@@ -9,14 +9,14 @@ NUM_EPOCHS = 2
 NUM_CHUNKS = 1
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
         transformers.GPT2LMHeadModel,
         transformers.GPT2DoubleHeadsModel,
         transformers.GPT2ForTokenClassification,
-    # transformers.GPT2ForSequenceClassification, # not supported yet
+        # transformers.GPT2ForSequenceClassification, # not supported yet
     ]
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8)
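Between this hunk and the next one for the same file, the body of test_gpt() is elided. It presumably just loops over MODEL_LIST, builds each architecture from the shared config, and hands it to the helper; the loop below is a hedged reconstruction of that elided body (the model_cls(config=config) call is an assumption), not text from the commit.

    # inside test_gpt(), after the config and data_gen setup:
    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)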
@@ -32,5 +32,5 @@ def test_gpt():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_gpt()
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
@@ -27,5 +27,5 @@ def test_opt():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_opt()
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
@@ -39,5 +39,5 @@ def test_t5():
         split_model_and_compare_output(model, data_gen_func)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_t5()