basic memory

This commit is contained in:
oahzxl
2022-11-07 18:26:13 +08:00
parent c35718e8db
commit d95cfe2622
2 changed files with 90 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ import copy
import torch
import torch.nn.functional as F
import pytest
import torch.fx
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
pair = torch.randn(1, 32, 32, 128).cuda()
# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
gm_prop = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm_prop)
graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
# propagate a second time so the meta info is also recorded in this graph module (not strictly necessary)
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
# annotate the chunk part
# for node in graph.nodes:
# if node.name == "linear0":
# setattr(node, "activation_offload", [0, True, False])
# if node.name == "linear1":
# setattr(node, "activation_offload", [0, True, False])
codegen = ChunkCodeGen(gm_prop)
graph.set_codegen(codegen)