mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,7 +7,6 @@ from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class ControlFlowModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(10, 10)
|
||||
@@ -27,16 +26,12 @@ class ControlFlowModel(nn.Module):
|
||||
def test_control_flow():
|
||||
model = ControlFlowModel()
|
||||
tracer = Tracer()
|
||||
graph_branch_true = tracer.trace(model,
|
||||
meta_args={
|
||||
'x': torch.rand(4, 10, device='meta'),
|
||||
'y': torch.rand(4, 10, device='meta')
|
||||
})
|
||||
graph_branch_false = tracer.trace(model,
|
||||
meta_args={
|
||||
'x': torch.rand(10, device='meta'),
|
||||
'y': torch.rand(4, 10, device='meta')
|
||||
})
|
||||
graph_branch_true = tracer.trace(
|
||||
model, meta_args={"x": torch.rand(4, 10, device="meta"), "y": torch.rand(4, 10, device="meta")}
|
||||
)
|
||||
graph_branch_false = tracer.trace(
|
||||
model, meta_args={"x": torch.rand(10, device="meta"), "y": torch.rand(4, 10, device="meta")}
|
||||
)
|
||||
|
||||
gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
|
||||
gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
|
||||
@@ -56,5 +51,5 @@ def test_control_flow():
|
||||
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_control_flow()
|
||||
|
Reference in New Issue
Block a user