mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[fx] allow native ckpt trace and codegen. (#2438)
This commit is contained in:
@@ -13,6 +13,7 @@ def symbolic_trace(
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
meta_args: Optional[Dict[str, Any]] = None,
|
||||
trace_act_ckpt=False,
|
||||
) -> ColoGraphModule:
|
||||
"""
|
||||
Symbolic tracing API
|
||||
@@ -49,6 +50,6 @@ def symbolic_trace(
|
||||
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
||||
|
||||
"""
|
||||
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
return ColoGraphModule(root, graph, name)
|
||||
|
Reference in New Issue
Block a user