[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -7,7 +7,7 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0")
def test_opt():
MODEL_LIST = [
MLP,
@@ -15,10 +15,7 @@ def test_opt():
]
CONFIGS = [
{
'dim': 10,
'layers': 12
},
{"dim": 10, "layers": 12},
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
]
@@ -45,5 +42,5 @@ def test_opt():
check_topo(top_mod, topo)
if __name__ == '__main__':
if __name__ == "__main__":
test_opt()

View File

@@ -6,7 +6,7 @@ from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.legacy.pipeline.middleware import Partition, Topo
from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
MANUAL_SEED = 0
@@ -16,11 +16,10 @@ torch.manual_seed(MANUAL_SEED)
class MLP(torch.nn.Module):
def __init__(self, config={}):
super().__init__()
dim = config['dim']
layers = config['layers']
dim = config["dim"]
layers = config["layers"]
self.layers = torch.nn.ModuleList()
for _ in range(layers):
@@ -41,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen):
# tracing model
tracer = ColoTracer()
try:
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
meta_args = {k: v.to("meta") for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
@@ -55,7 +54,7 @@ def split_model_and_get_DAG(model, data_gen):
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
setattr(submodule, "_topo", topo)
return top_module, split_submodules[0]._topo
@@ -64,7 +63,7 @@ def check_input(top_module, input_partition: Partition):
partition_output = input_partition.get_output_vals()
arg_pos = 0
for node in top_module.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
cur_checkee = partition_output[arg_pos]
to_partition_and_offset = cur_checkee.get()
assert len(to_partition_and_offset) == len(node.users.keys())
@@ -80,7 +79,7 @@ def check_submod(top_module, part_id, mid_partition: Partition):
cnt = 1
cur_node = None
for node in top_module.graph.nodes:
if node.name.startswith('submod'):
if node.name.startswith("submod"):
cnt += 1
if cnt == part_id:
cur_node = node