mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[fx] refactor tracer to trace complete graph (#1342)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] refactor tracer to trace complete graph
* add comments and solve conflicts.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from typing import List, Union, Any
|
||||
from ..proxy import ColoProxy, ColoAttribute
|
||||
import torch
|
||||
from .meta_patch import meta_patched_function, meta_patched_module
|
||||
|
||||
__all__ = ['is_element_in_list', 'extract_meta']
|
||||
|
||||
@@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
|
||||
new_args = [_convert(val) for val in args]
|
||||
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
||||
return new_args, new_kwargs
|
||||
|
||||
|
||||
def compute_meta_data_for_functions_proxy(target, args, kwargs):
|
||||
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
|
||||
|
||||
# fetch patched function
|
||||
if meta_patched_function.has(target):
|
||||
meta_target = meta_patched_function.get(target)
|
||||
elif meta_patched_function.has(target.__name__):
|
||||
meta_target = meta_patched_function.get(target.__name__)
|
||||
else:
|
||||
meta_target = target
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
if isinstance(meta_out, torch.Tensor):
|
||||
meta_out = meta_out.to(device="meta")
|
||||
|
||||
return meta_out
|
||||
|
Reference in New Issue
Block a user