mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[fx] Add offload codegen (#1598)
* [fx] add input activation offload to codegen
* [fx] modify unit test
* [fx] remove two skips in torch11
* [fx] use all_input_nodes instead of _input_nodes
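Background for the diff below: the offload codegen emits pack/unpack functions for PyTorch's saved-tensor hooks, which is why bind() now also picks up generated names containing "pack". A minimal sketch of what such hooks do, assuming only the public torch.autograd.graph.saved_tensors_hooks API; the hook names here are illustrative, not the ones the codegen emits:

import torch

def pack_to_cpu(tensor):
    # offload the activation saved for backward to CPU, remembering its device
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    # move the activation back to its original device when backward needs it
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()    # the saved copy of x passes through the hooks
y.backward()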
@@ -22,14 +22,20 @@ if COLOGM:
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
-            """Bind checkpoint functions to ColoGraphModule
-            We need to bind our checkpoint functions to the GraphModule so
-            that we could correctly use self.checkpoint for GraphModule forward
+            """Bind function needed for correctly execute gm forward
+
+            We need to bind checkpoint functions and saved_tensor_hooks functions
+            to gm so that we could correctly execute gm forward
+
+            Args:
+                ckpt_def (_type_): definition before the forward function
+                globals (_type_): global variables
             """
+
             ckpt_code = "\n".join(ckpt_def)
             globals_copy = globals.copy()
             _exec_with_source(ckpt_code, globals_copy)
-            func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
+            func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
             for func in func_list:
                 tmp_func = globals_copy[func]
                 setattr(self, func, tmp_func.__get__(self, self.__class__))
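For reference, a minimal sketch of the binding trick bind() relies on above: a function created by exec'ing generated source is attached to the instance as a bound method via the descriptor protocol (__get__). Dummy, checkpoint_0 and the plain exec call are hypothetical stand-ins for the GraphModule and _exec_with_source:

src = '''
def checkpoint_0(self, x):
    return x * 2
'''

class Dummy:
    pass

obj = Dummy()
globals_copy = {}
exec(src, globals_copy)    # stands in for _exec_with_source
func_list = [name for name in globals_copy if "checkpoint" in name or "pack" in name]
for name in func_list:
    tmp_func = globals_copy[name]
    # __get__ binds the free function to obj, so it receives self like a method
    setattr(obj, name, tmp_func.__get__(obj, obj.__class__))

print(obj.checkpoint_0(3))    # -> 6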