From 5db3a5bf42a7f8c5fa00141d95fbac633bce4b37 Mon Sep 17 00:00:00 2001 From: oahzxl <43881818+oahzxl@users.noreply.github.com> Date: Wed, 18 Jan 2023 17:02:46 +0800 Subject: [PATCH] [fx] allow control of ckpt_codegen init (#2498) * [fx] allow control of ckpt_codegen init Currently in ColoGraphModule, ActivationCheckpointCodeGen will be set automatically in __init__. But other codegen can't be set if so. So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__. * code style --- colossalai/fx/graph_module.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 2d6a71f19..ebb9975f2 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -22,8 +22,13 @@ if COLOGM: class ColoGraphModule(GraphModule): - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): - graph.set_codegen(ActivationCheckpointCodeGen()) + def __init__(self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = 'GraphModule', + ckpt_codegen: bool = True): + if ckpt_codegen: + graph.set_codegen(ActivationCheckpointCodeGen()) super().__init__(root, graph, class_name) def bind(self, ckpt_def, globals):