[fx] add more meta_registry for MetaTensor execution. (#2000)

* [sc] add examples for auto checkpoint.

* merge upstream

* [fx] add more meta_registry for MetaTensor execution.
This commit is contained in:
Super Daniel
2022-11-23 10:55:46 +08:00
committed by GitHub
parent d00d905b86
commit 2edbef13cc
4 changed files with 70 additions and 21 deletions

View File

@@ -20,28 +20,25 @@ def symbolic_trace(
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
If specified using ``meta_args`` only, the tracing can be done ahead of time.
With ``meta_args``, we can trace the model that are untraceable subject to control flow. If specified using
``meta_args`` only, the tracing can be done ahead of time.
Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
and the value of the argument's values.
Note that ``meta_args`` are kwargs, which contains the key of the argument's names and the value of the
argument's values.
Uses:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model)
>>> gm = symbolic_trace(model, concrete_args=concrete_args)
# else try this
>>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
# else try this
>>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
>>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
@@ -52,7 +49,6 @@ def symbolic_trace(
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
tracer = ColoTracer()
graph = tracer.trace(root, concrete_args, meta_args)
name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
return ColoGraphModule(tracer.root, graph, name)
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)