[fx] Add offload codegen (#1598)

* [fx] add input activation offload to codegen

* [fx] modify unit test

* [fx] remove two skips in torch11

* [fx] use all_input_nodes instead of _input_nodes
This commit is contained in:
Boyuan Yao
2022-09-14 15:49:06 +08:00
committed by GitHub
parent c8e9b2ad78
commit a7cda6f57d
4 changed files with 264 additions and 20 deletions

View File

@@ -22,14 +22,20 @@ if COLOGM:
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
"""Bind checkpoint functions to ColoGraphModule
We need to bind our checkpoint functions to the GraphModule so
that we could correctly use self.checkpoint for GraphModule forward
"""Bind function needed for correctly execute gm forward
We need to bind checkpoint functions and saved_tensor_hooks functions
to gm so that we could correctly execute gm forward
Args:
ckpt_def (_type_): definition before the forward function
globals (_type_): global variables
"""
ckpt_code = "\n".join(ckpt_def)
globals_copy = globals.copy()
_exec_with_source(ckpt_code, globals_copy)
func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
for func in func_list:
tmp_func = globals_copy[func]
setattr(self, func, tmp_func.__get__(self, self.__class__))