Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 18:09:06 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
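The diff below is the mechanical result of re-running the updated hooks over every tracked file. As a minimal sketch of that workflow (assuming the stock pre-commit CLI; the repository's exact .pre-commit-config.yaml hook list is not shown in this diff):

    pre-commit install          # register the hooks in .git/hooks
    pre-commit run --all-files  # apply every hook to every tracked file, as this commit does

The "ignore cuda for clang-format" item presumably corresponds to an exclude pattern on the clang-format hook (for example `exclude: \.cu$`); that pattern is an assumption, not taken from this diff.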
@@ -1,11 +1,12 @@
-import torch
-from torch.fx import symbolic_trace
-from torch.fx import GraphModule
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
-from colossalai.fx import ColoTracer
-import inspect
-import random
-import numpy as np
+import inspect
+import random
+
+import numpy as np
+import torch
+from torch.fx import GraphModule
+
+from colossalai.fx import ColoTracer
+from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
 
 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
@@ -26,7 +27,7 @@ def split_model_and_compare_output(model, data_gen):
     # tracing model
     tracer = ColoTracer()
     try:
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
         graph = tracer.trace(root=model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
@@ -49,16 +50,16 @@ def split_model_and_compare_output(model, data_gen):
         output_part1 = model_part1(output_part0)
     else:
         if len(output_part0) > len(sig.parameters):
-            output_part0 = output_part0[:len(sig.parameters)]
+            output_part0 = output_part0[: len(sig.parameters)]
         output_part1 = model_part1(*output_part0)
 
     # get output tensor from HFOutput datastructure
-    if 'logits' in output:
-        output_to_compare = output['logits']
-    elif 'prediction_logits' in output:
-        output_to_compare = output['prediction_logits']
+    if "logits" in output:
+        output_to_compare = output["logits"]
+    elif "prediction_logits" in output:
+        output_to_compare = output["prediction_logits"]
     else:
-        output_to_compare = output['last_hidden_state']
+        output_to_compare = output["last_hidden_state"]
 
     # compare output
     if isinstance(output_part1, torch.Tensor):
@@ -7,7 +7,7 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_single_sentence_albert():
     MODEL_LIST = [
         transformers.AlbertModel,
@@ -17,12 +17,14 @@ def test_single_sentence_albert():
         transformers.AlbertForTokenClassification,
     ]
 
-    config = transformers.AlbertConfig(vocab_size=100,
-                                       embedding_size=128,
-                                       hidden_size=128,
-                                       num_hidden_layers=2,
-                                       num_attention_heads=4,
-                                       intermediate_size=256)
+    config = transformers.AlbertConfig(
+        vocab_size=100,
+        embedding_size=128,
+        hidden_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        intermediate_size=256,
+    )
 
     def data_gen():
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
@@ -36,5 +38,5 @@ def test_single_sentence_albert():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_single_sentence_albert()
@@ -7,7 +7,7 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
@@ -18,11 +18,9 @@ def test_single_sentence_bert():
         transformers.BertForTokenClassification,
     ]
 
-    config = transformers.BertConfig(vocab_size=100,
-                                     hidden_size=128,
-                                     num_hidden_layers=4,
-                                     num_attention_heads=4,
-                                     intermediate_size=256)
+    config = transformers.BertConfig(
+        vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256
+    )
 
     def data_gen():
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
@@ -36,5 +34,5 @@ def test_single_sentence_bert():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_single_sentence_bert()
@@ -9,14 +9,14 @@ NUM_EPOCHS = 2
 NUM_CHUNKS = 1
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
         transformers.GPT2LMHeadModel,
         transformers.GPT2DoubleHeadsModel,
         transformers.GPT2ForTokenClassification,
-    # transformers.GPT2ForSequenceClassification, # not supported yet
+        # transformers.GPT2ForSequenceClassification, # not supported yet
     ]
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8)
 
@@ -32,5 +32,5 @@ def test_gpt():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_gpt()
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
@@ -27,5 +27,5 @@ def test_opt():
         split_model_and_compare_output(model, data_gen)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_opt()
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
@@ -39,5 +39,5 @@ def test_t5():
         split_model_and_compare_output(model, data_gen_func)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_t5()
@@ -4,9 +4,8 @@ import torch
 from timm_utils import split_model_and_compare_output
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_timm_models_without_control_flow():
-
     MODEL_LIST = [
         tm.resnest.resnest50d,
         tm.beit.beit_base_patch16_224,
@@ -25,24 +24,28 @@ def test_timm_models_without_control_flow():
         split_model_and_compare_output(model, data)
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True
 
     MODEL_LIST_WITH_CONTROL_FLOW = [
-        tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
-        tm.swin_transformer.swin_base_patch4_window7_224
+        tm.convnext.convnext_base,
+        tm.vgg.vgg11,
+        tm.dpn.dpn68,
+        tm.densenet.densenet121,
+        tm.rexnet.rexnet_100,
+        tm.swin_transformer.swin_base_patch4_window7_224,
     ]
 
     data = torch.rand(2, 3, 224, 224)
 
-    meta_args = {'x': data.to('meta')}
+    meta_args = {"x": data.to("meta")}
 
     for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
         model = model_cls()
         split_model_and_compare_output(model, data, meta_args)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_timm_models_without_control_flow()
     test_timm_models_with_control_flow()
@@ -1,11 +1,12 @@
-import torch
-from torch.fx import symbolic_trace
-from torch.fx import GraphModule
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
-from colossalai.fx import ColoTracer
-import inspect
-import random
-import numpy as np
+import inspect
+import random
+
+import numpy as np
+import torch
+from torch.fx import GraphModule
+
+from colossalai.fx import ColoTracer
+from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
 
 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
@@ -46,6 +47,6 @@ def split_model_and_compare_output(model, data, meta_args=None):
         output_part1 = model_part1(output_part0)
     else:
         if len(output_part0) > len(sig.parameters):
-            output_part0 = output_part0[:len(sig.parameters)]
+            output_part0 = output_part0[: len(sig.parameters)]
         output_part1 = model_part1(*output_part0)
     assert output.equal(output_part1)
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0")
 def test_opt():
     MODEL_LIST = [
         MLP,
@@ -15,10 +15,7 @@ def test_opt():
     ]
 
     CONFIGS = [
-        {
-            'dim': 10,
-            'layers': 12
-        },
+        {"dim": 10, "layers": 12},
         transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
     ]
 
@@ -45,5 +42,5 @@ def test_opt():
     check_topo(top_mod, topo)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_opt()
@@ -6,7 +6,7 @@ from torch.fx import GraphModule
 
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.legacy.pipeline.middleware import Partition, Topo
 from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
 
 MANUAL_SEED = 0
@@ -16,11 +16,10 @@ torch.manual_seed(MANUAL_SEED)
 
 
 class MLP(torch.nn.Module):
-
     def __init__(self, config={}):
         super().__init__()
-        dim = config['dim']
-        layers = config['layers']
+        dim = config["dim"]
+        layers = config["layers"]
         self.layers = torch.nn.ModuleList()
 
         for _ in range(layers):
@@ -41,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen):
     # tracing model
     tracer = ColoTracer()
     try:
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
         graph = tracer.trace(root=model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
@@ -55,7 +54,7 @@ def split_model_and_get_DAG(model, data_gen):
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
         if isinstance(submodule, torch.fx.GraphModule):
-            setattr(submodule, '_topo', topo)
+            setattr(submodule, "_topo", topo)
 
     return top_module, split_submodules[0]._topo
 
@@ -64,7 +63,7 @@ def check_input(top_module, input_partition: Partition):
     partition_output = input_partition.get_output_vals()
     arg_pos = 0
     for node in top_module.graph.nodes:
-        if node.op == 'placeholder':
+        if node.op == "placeholder":
             cur_checkee = partition_output[arg_pos]
             to_partition_and_offset = cur_checkee.get()
             assert len(to_partition_and_offset) == len(node.users.keys())
@@ -80,7 +79,7 @@ def check_submod(top_module, part_id, mid_partition: Partition):
     cnt = 1
     cur_node = None
     for node in top_module.graph.nodes:
-        if node.name.startswith('submod'):
+        if node.name.startswith("submod"):
             cnt += 1
             if cnt == part_id:
                 cur_node = node
@@ -19,14 +19,21 @@ torch.manual_seed(MANUAL_SEED)
 torch.backends.cudnn.deterministic = True
 
 
-@pytest.mark.skip('balance split v2 is not ready')
+@pytest.mark.skip("balance split v2 is not ready")
 def test_torchvision_models():
     MODEL_LIST = [
-        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
+        tm.vgg11,
+        tm.resnet18,
+        tm.densenet121,
+        tm.mobilenet_v3_small,
+        tm.resnext50_32x4d,
+        tm.wide_resnet50_2,
+        tm.regnet_x_16gf,
+        tm.efficientnet_b0,
+        tm.mnasnet0_5,
     ]
 
-    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+    if version.parse(torchvision.__version__) >= version.parse("0.12.0"):
         MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
 
     tracer = ColoTracer()
@@ -57,10 +64,10 @@ def test_torchvision_models():
         output_part1 = model_part1(output_part0)
     else:
         if len(output_part0) > len(sig.parameters):
-            output_part0 = output_part0[:len(sig.parameters)]
+            output_part0 = output_part0[: len(sig.parameters)]
         output_part1 = model_part1(*output_part0)
     assert output.equal(output_part1)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_torchvision_models()