[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)

* [fx] fix wrong variable name in solver rotor

* [fx] fix wrong variable name in solver rotor

* [fx] fix the discretize bug

* [fx] fix the first op in activation checkpoint codegen

* [fx] fix some bugs of ckpt solver

* [fx] modify test_ckpt_torchvision

* [fx] set sequence to __sequence__ attr of GraphModule

* [fx] docstring modification

* [fx] remove performance test
This commit is contained in:
Boyuan Yao
2022-08-31 18:10:48 +08:00
committed by GitHub
parent 07f5a4e054
commit b231430bcb
2 changed files with 27 additions and 11 deletions

View File

@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta')
data = torch.rand(8, 3, 224, 224, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
@@ -95,13 +95,13 @@ def test_ckpt_solver():
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta')
data = torch.rand(8, 3, 32, 32, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)