Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 21:22:49 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
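All of the churn in the diff below comes from pre-commit hooks. As a minimal sketch (assuming a typical hook set; the pinned revisions and the exact `exclude` pattern are illustrative guesses, not the configuration actually merged here), a `.pre-commit-config.yaml` for this kind of cleanup might look like:

```yaml
# Hypothetical sketch only: hooks and revs are assumptions for illustration,
# not the exact configuration changed by this commit.
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6
    hooks:
      - id: clang-format
        exclude: \.cu$  # one way "ignore cuda for clang-format" could be expressed
```

Running `pre-commit run --all-files` then applies every hook to the entire tree, which is what produces the mechanical reformatting shown in the hunks below.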
@@ -5,6 +5,7 @@ import torchvision.models as tm
 from colossalai.fx import ColoTracer
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.graph_module import ColoGraphModule
+
 # from colossalai.fx.passes.algorithms import linearize, solver_rotor
 # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@@ -15,14 +16,16 @@ if is_compatible_with_meta():
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
+
     with_codegen = True
 except:
     # fall back to older pytorch version
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
+
     with_codegen = False
 
 
-@pytest.mark.skip(reason='TODO: modify the logger')
+@pytest.mark.skip(reason="TODO: modify the logger")
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 @clear_cache_before_run()
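The only substantive edit in this hunk is black normalizing the skip reason from single to double quotes; the new blank lines after the in-block imports are likewise purely stylistic. A one-line sketch of why the quote change is behavior-preserving:

```python
# Single- and double-quoted literals are the same string at runtime;
# black merely standardizes on double quotes.
assert 'TODO: modify the logger' == "TODO: modify the logger"
```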
@@ -35,12 +38,12 @@ def test_linearize():
         graph = tracer.trace(model)
         graph.set_codegen(ActivationCheckpointCodeGen())
         gm = ColoGraphModule(model, graph, model.__class__.__name__)
-        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
+        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device="cpu"))
         node_list = linearize(gm)
         gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
         op_list = gm.__sequence__.list_operations()
         loss_op = next(op for op in op_list if isinstance(op, Loss))
-        op_list = op_list[:op_list.index(loss_op)]
+        op_list = op_list[: op_list.index(loss_op)]
         in_ckpt = False
         ckpt_idx = 0
         for idx, op in enumerate(op_list):
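For readers unfamiliar with the `device="meta"` idiom used here: meta tensors track shape and dtype but allocate no storage, which is what lets `MetaInfoProp` propagate shapes for a 128x3x224x224 batch without real memory. A minimal plain-PyTorch sketch (no ColossalAI needed):

```python
import torch

# A meta tensor carries metadata only: shape/dtype exist, storage does not.
x = torch.rand(128, 3, 224, 224, device="meta")
print(x.shape)    # torch.Size([128, 3, 224, 224])
print(x.is_meta)  # True

# Ops on meta tensors stay on the meta device and never touch real memory.
y = x.relu()
print(y.is_meta)  # True
```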
@@ -48,8 +51,9 @@ def test_linearize():
             if isinstance(op, ForwardNograd):
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint[
-                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
+                    assert (
+                        n.activation_checkpoint[0] == ckpt_idx
+                    ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
                 continue
 
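The same assert rewrap repeats in the next two hunks; it is formatting only. Both layouts are the identical statement, as this runnable mini-example shows (the names `value` and `expected` are illustrative, not from the test):

```python
value, expected = [3], 3

# old wrapping: condition split inside the subscript
assert value[
    0] == expected, f"mismatch: {value[0]} != {expected}"

# black's wrapping: parenthesized condition, message on its own line
assert (
    value[0] == expected
), f"mismatch: {value[0]} != {expected}"
```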
@@ -65,8 +69,9 @@ def test_linearize():
                 ckpt_idx += 1
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint[
-                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
+                    assert (
+                        n.activation_checkpoint[0] == ckpt_idx
+                    ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
                 continue
 
@@ -75,8 +80,9 @@ def test_linearize():
                 in_ckpt = True
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint[
-                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
+                    assert (
+                        n.activation_checkpoint[0] == ckpt_idx
+                    ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
         del model
         del gm
@@ -100,7 +106,7 @@ def test_linearize_torch11():
         gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
         op_list = gm.__sequence__.list_operations()
         loss_op = next(op for op in op_list if isinstance(op, Loss))
-        op_list = op_list[:op_list.index(loss_op)]
+        op_list = op_list[: op_list.index(loss_op)]
         in_ckpt = False
         ckpt_idx = 0
         for idx, op in enumerate(op_list):
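The recurring slice change (`op_list[:op_list.index(loss_op)]` becoming `op_list[: op_list.index(loss_op)]`) follows the PEP 8 convention, enforced by black, of spacing the slice colon like a binary operator when an operand is a complex expression. The two spellings are equivalent, as a toy check (the list contents are made up):

```python
op_list = ["fwd", "fwd", "loss", "bwd"]
assert op_list[:op_list.index("loss")] == op_list[: op_list.index("loss")]
```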