mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[fx]refactor tracer (#1335)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import List, Union, Any
|
||||
from ..proxy import ColoProxy, MetaDeviceAttribute
|
||||
from ..proxy import ColoProxy, ColoAttribute
|
||||
|
||||
__all__ = ['is_element_in_list', 'extract_meta']
|
||||
|
||||
@@ -19,10 +19,11 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
|
||||
def extract_meta(*args, **kwargs):
|
||||
|
||||
def _convert(val):
|
||||
if isinstance(val, MetaDeviceAttribute):
|
||||
return 'meta'
|
||||
elif isinstance(val, ColoProxy):
|
||||
if isinstance(val, ColoProxy):
|
||||
return val.meta_data
|
||||
elif isinstance(val, (list, tuple)):
|
||||
return type(val)([_convert(ele) for ele in val])
|
||||
|
||||
return val
|
||||
|
||||
new_args = [_convert(val) for val in args]
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import operator
|
||||
import torch
|
||||
from ..registry import meta_patched_function
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
|
||||
|
||||
@meta_patched_function.register(operator.getitem)
|
||||
@@ -14,6 +15,30 @@ def operator_getitem(a, b):
|
||||
return concrete
|
||||
return t
|
||||
|
||||
def _slice_convert(slice_obj):
|
||||
attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step}
|
||||
new_attrs = _slice_attr_convert(attrs)
|
||||
attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step'])
|
||||
return slice(*attr_dict_to_tuple)
|
||||
|
||||
def _slice_attr_convert(attrs):
|
||||
new_attrs = {}
|
||||
for key, value in attrs.items():
|
||||
if isinstance(value, ColoProxy):
|
||||
new_attrs[key] = value.meta_data
|
||||
else:
|
||||
new_attrs[key] = value
|
||||
return new_attrs
|
||||
|
||||
if isinstance(b, tuple):
|
||||
b = list(b)
|
||||
for index, element in enumerate(b):
|
||||
if isinstance(element, slice):
|
||||
b[index] = _slice_convert(element)
|
||||
b = tuple(b)
|
||||
elif isinstance(b, slice):
|
||||
b = _slice_convert(b)
|
||||
|
||||
if isinstance(a, torch.Tensor):
|
||||
# TODO: infer shape without performing the computation.
|
||||
if isinstance(b, tuple):
|
||||
@@ -21,4 +46,12 @@ def operator_getitem(a, b):
|
||||
else:
|
||||
b = to_concrete(b)
|
||||
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
||||
|
||||
if isinstance(a, ColoProxy):
|
||||
# TODO: infer shape without performing the computation.
|
||||
if isinstance(b, tuple):
|
||||
b = tuple(map(to_concrete, b))
|
||||
else:
|
||||
b = to_concrete(b)
|
||||
return operator.getitem(torch.empty_like(a.meta_data, device="cpu"), b).to("meta")
|
||||
return operator.getitem(a, b)
|
||||
|
Reference in New Issue
Block a user