Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
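Most of the changes below are mechanical reformatting from running the updated hooks over the tree. As a hedged illustration (the repository's actual pre-commit configuration is not shown in this diff), a black-style formatter normalizes string quotes and collapses literals that fit within the line length:

# Before the formatter runs (single quotes, exploded dict literal):
config = {
    'dim': 10,
    'layers': 12
}

# After a black-style formatter (double quotes, collapsed onto one line):
config = {"dim": 10, "layers": 12}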
@@ -7,7 +7,7 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0")
 def test_opt():
     MODEL_LIST = [
         MLP,
@@ -15,10 +15,7 @@ def test_opt():
     ]
 
     CONFIGS = [
-        {
-            'dim': 10,
-            'layers': 12
-        },
+        {"dim": 10, "layers": 12},
         transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
     ]
 
@@ -45,5 +42,5 @@ def test_opt():
     check_topo(top_mod, topo)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_opt()
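The skip above is unconditional. A hedged variant (not part of this commit; the cutoff version is an assumption) would gate the test on the installed PyTorch version with pytest.mark.skipif, so it runs again on compatible builds:

import pytest
import torch
from packaging import version

# Hypothetical version-gated skip; "1.12.0" is an assumed cutoff.
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse("1.12.0"),
    reason="ShapeProp is not compatible with PyTorch 1.11.0",
)
def test_opt_gated():
    ...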
@@ -6,7 +6,7 @@ from torch.fx import GraphModule
 
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.legacy.pipeline.middleware import Partition, Topo
 from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
 
 MANUAL_SEED = 0
@@ -16,11 +16,10 @@ torch.manual_seed(MANUAL_SEED)
 
 
 class MLP(torch.nn.Module):
-
     def __init__(self, config={}):
         super().__init__()
-        dim = config['dim']
-        layers = config['layers']
+        dim = config["dim"]
+        layers = config["layers"]
         self.layers = torch.nn.ModuleList()
 
         for _ in range(layers):
@@ -41,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen):
     # tracing model
     tracer = ColoTracer()
     try:
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to("meta") for k, v in kwargs.items()}
         graph = tracer.trace(root=model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
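The meta_args line is worth unpacking: moving example inputs to the "meta" device keeps shape and dtype but drops storage, so the tracer can propagate shapes without allocating memory or touching real weights. A minimal sketch with plain PyTorch (the input name is an assumption):

import torch

# A meta tensor carries shape/dtype only; no storage is allocated.
kwargs = {"input_ids": torch.zeros(1, 16, dtype=torch.long)}
meta_args = {k: v.to("meta") for k, v in kwargs.items()}
assert all(v.is_meta for v in meta_args.values())
print(meta_args["input_ids"].shape)  # torch.Size([1, 16])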
@@ -55,7 +54,7 @@ def split_model_and_get_DAG(model, data_gen):
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
         if isinstance(submodule, torch.fx.GraphModule):
-            setattr(submodule, '_topo', topo)
+            setattr(submodule, "_topo", topo)
 
     return top_module, split_submodules[0]._topo
 
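Since a torch.fx.GraphModule is an ordinary nn.Module, arbitrary metadata such as the pipeline Topo can be hung off each split stage with setattr and read back later. A minimal sketch (the dict stands in for the real Topo object):

import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
setattr(gm, "_topo", {"partitions": 2})  # stand-in for a real Topo
assert gm._topo == {"partitions": 2}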
@@ -64,7 +63,7 @@ def check_input(top_module, input_partition: Partition):
     partition_output = input_partition.get_output_vals()
     arg_pos = 0
     for node in top_module.graph.nodes:
-        if node.op == 'placeholder':
+        if node.op == "placeholder":
             cur_checkee = partition_output[arg_pos]
             to_partition_and_offset = cur_checkee.get()
             assert len(to_partition_and_offset) == len(node.users.keys())
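check_input walks the graph's "placeholder" nodes, which represent the module's positional inputs in order; that ordering is what lets arg_pos index into the partition's output values. A minimal sketch:

import torch
from torch.fx import symbolic_trace

def f(x, y):
    return x + y

gm = symbolic_trace(f)
# Placeholders are the graph inputs, in positional order.
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x', 'y']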
@@ -80,7 +79,7 @@ def check_submod(top_module, part_id, mid_partition: Partition):
     cnt = 1
     cur_node = None
     for node in top_module.graph.nodes:
-        if node.name.startswith('submod'):
+        if node.name.startswith("submod"):
             cnt += 1
             if cnt == part_id:
                 cur_node = node
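The name check works because fx graph splitting names each partition's call node "submod_<id>". A minimal sketch using plain torch.fx's split_module (standing in for ColossalAI's balanced_split_pass, which follows the same convention):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class TwoLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(4, 4)
        self.b = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.b(self.a(x))

m = TwoLinear()
gm = symbolic_trace(m)
# Assign each call node to a partition; split_module names the resulting
# call_module nodes "submod_0", "submod_1", ...
split = split_module(gm, m, lambda node: 0 if node.target == "a" else 1)
print([n.name for n in split.graph.nodes if n.name.startswith("submod")])
# ['submod_0', 'submod_1']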