[fx] allow native ckpt trace and codegen. (#2438)

This commit is contained in:
Super Daniel
2023-01-11 13:49:59 +08:00
committed by GitHub
parent 41429b9b28
commit c41e59e5ad
3 changed files with 37 additions and 23 deletions

View File

@@ -13,6 +13,7 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule:
"""
Symbolic tracing API
@@ -49,6 +50,6 @@ def symbolic_trace(
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)