mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)
* [fx] fix wrong variable name in solver rotor * [fx] fix wrong variable name in solver rotor * [fx] fix the discretize bug * [fx] fix the first op in activation checkpoint codegen * [fx] fix some bugs of ckpt solver * [fx] modify test_ckpt_torchvision * [fx] set sequence to __sequence__ attr of GraphModule * [fx] docstring modification * [fx] remove performance test
This commit is contained in:
@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
|
||||
|
||||
def _run_ckpt_solver(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
|
||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||
data = torch.rand(8, 3, 224, 224, device='meta')
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
@@ -95,13 +95,13 @@ def test_ckpt_solver():
|
||||
|
||||
def _run_ckpt_solver_torch11(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
|
||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||
data = torch.rand(8, 3, 32, 32, device='meta')
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
|
Reference in New Issue
Block a user