mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
@@ -8,7 +7,6 @@ from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(4, 4)
|
||||
@@ -22,7 +20,6 @@ class MLP(torch.nn.Module):
|
||||
|
||||
# Simple module for demonstration
|
||||
class MyModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mlp_1 = MLP()
|
||||
@@ -46,20 +43,20 @@ def test_activation_checkpoint_annotation():
|
||||
gm = GraphModule(module, graph)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
|
||||
assert node.meta.get('activation_checkpoint', -1) == 0
|
||||
if node.name in ["mlp_1_linear1", "mlp_1_linear2"]:
|
||||
assert node.meta.get("activation_checkpoint", -1) == 0
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
|
||||
assert node.meta.get('activation_checkpoint', -1) == 1
|
||||
if node.name in ["mlp_2_linear1", "mlp_2_linear2"]:
|
||||
assert node.meta.get("activation_checkpoint", -1) == 1
|
||||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
graph = tracer.trace(module)
|
||||
gm = GraphModule(module, graph)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
assert not hasattr(node, 'activation_checkpoint')
|
||||
assert not hasattr(node, "activation_checkpoint")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_activation_checkpoint_annotation()
|
||||
|
Reference in New Issue
Block a user