basic memory
@@ -2,6 +2,7 @@ import copy
import torch
import torch.nn.functional as F
import pytest
import torch.fx
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
    pair = torch.randn(1, 32, 32, 128).cuda()

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    gm_prop = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm_prop)
    # re-trace with meta_args so input shapes are known without materializing real tensors
    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
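
    # After propagation, meta information is attached to the traced nodes; a
    # hypothetical way to inspect it (assuming MetaInfoProp follows the usual
    # torch.fx convention of filling node.meta):
    # for n in gm_prop.graph.nodes:
    #     print(n.name, n.meta.get('tensor_meta'))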

    # propagate a second time so the rebuilt graph module also carries meta info (not strictly necessary)
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))

    # annotate the chunk part
    # for node in graph.nodes:
    #     if node.name == "linear0":
    #         setattr(node, "activation_offload", [0, True, False])
    #     if node.name == "linear1":
    #         setattr(node, "activation_offload", [0, True, False])

    codegen = ChunkCodeGen(gm_prop)
    graph.set_codegen(codegen)
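
A minimal follow-on sketch (not part of this commit; it assumes the model, graph, node, and pair objects from the test above): once ChunkCodeGen is installed with graph.set_codegen, rebuilding the GraphModule makes torch.fx emit the chunked forward(), which can then be checked against the eager model:

    gm = torch.fx.GraphModule(model, graph)
    gm.recompile()                    # regenerate forward() through ChunkCodeGen
    non_fx_out = model(node, pair)    # reference output from the eager model
    fx_out = gm(node, pair)           # output from the chunked graph module
    assert torch.allclose(non_fx_out, fx_out), "chunk codegen altered the numerics"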