mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[fx] supported data-dependent control flow in model tracing (#1185)
* [fx] supported data-dependent control flow in model tracing * polish code
This commit is contained in:
31
colossalai/fx/tracer/_tracer_utils.py
Normal file
31
colossalai/fx/tracer/_tracer_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import List, Union, Any
|
||||
from ..proxy import ColoProxy, MetaDeviceAttribute
|
||||
|
||||
__all__ = ['is_element_in_list', 'extract_meta']
|
||||
|
||||
|
||||
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
|
||||
if isinstance(elements, (tuple, list, set)):
|
||||
for ele in elements:
|
||||
if ele not in list_:
|
||||
return False, ele
|
||||
else:
|
||||
if elements not in list_:
|
||||
return False, elements
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def extract_meta(*args, **kwargs):
|
||||
|
||||
def _convert(val):
|
||||
if isinstance(val, MetaDeviceAttribute):
|
||||
return 'meta'
|
||||
elif isinstance(val, ColoProxy):
|
||||
assert val.meta_tensor is not None
|
||||
return val.meta_tensor
|
||||
return val
|
||||
|
||||
new_args = [_convert(val) for val in args]
|
||||
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
||||
return new_args, new_kwargs
|
Reference in New Issue
Block a user