mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,15 +1,11 @@
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
import colossalai
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(dim, dim)
|
||||
@@ -43,11 +39,11 @@ def test_graph_manipulation():
|
||||
assert leaf_nodes == set([l4, l5])
|
||||
assert top_nodes == set([l1, l2])
|
||||
for node in graph.nodes:
|
||||
if node.op in ('placeholder', 'output'):
|
||||
assert not hasattr(node, 'bfs_level')
|
||||
if node.op in ("placeholder", "output"):
|
||||
assert not hasattr(node, "bfs_level")
|
||||
else:
|
||||
assert node.bfs_level == compare_dict[node]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_graph_manipulation()
|
||||
|
Reference in New Issue
Block a user