mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-25 17:22:12 +00:00
add meta
This commit is contained in:
@@ -9,6 +9,8 @@ import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from evoformer.evoformer import evoformer_base
|
||||
from chunk_codegen import ChunkCodeGen
|
||||
with_codegen = True
|
||||
@@ -56,9 +58,10 @@ def _run_offload_codegen(rank):
|
||||
# trace the module and replace codegen
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
graph = tracer.trace(model)
|
||||
# codegen = ChunkCodeGen()
|
||||
# graph.set_codegen(codegen)
|
||||
|
||||
gm_prop = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
|
||||
|
||||
# annotate the chunk part
|
||||
# for node in graph.nodes:
|
||||
# if node.name == "linear0":
|
||||
@@ -66,7 +69,9 @@ def _run_offload_codegen(rank):
|
||||
# if node.name == "linear1":
|
||||
# setattr(node, "activation_offload", [0, True, False])
|
||||
|
||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||
codegen = ChunkCodeGen(gm_prop)
|
||||
# graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have all the components
|
||||
|
||||
Reference in New Issue
Block a user