mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI (multiple iterations)
* skip tests
* use packaging version
* polish
This commit is contained in:
@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
with self.global_hook:
|
||||
r = getattr(self, n.op)(n.target, args, kwargs)
|
||||
|
||||
unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
|
||||
def unwrap_fn(elem):
    """Unwrap ``elem`` to a plain tensor on the ``meta`` device.

    - ``MetaTensor`` -> its wrapped ``_tensor``, moved to ``meta`` if needed.
    - plain ``torch.Tensor`` -> moved to ``meta`` if needed.
    - anything else -> returned unchanged.
    """

    def _convert_meta(t: torch.Tensor) -> torch.Tensor:
        # ``t.is_meta`` is the canonical check. The previous form,
        # ``t.device == 'meta'``, compares a torch.device to a raw string,
        # which is not reliable across torch versions (it can evaluate to
        # False even for meta tensors, causing a redundant ``.to('meta')``).
        if t.is_meta:
            return t
        return t.to('meta')

    if isinstance(elem, MetaTensor):
        # MetaTensor keeps the real tensor in its ``_tensor`` attribute.
        return _convert_meta(elem._tensor)
    elif isinstance(elem, torch.Tensor):
        return _convert_meta(elem)
    else:
        return elem
|
||||
|
||||
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
|
||||
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
|
||||
n_info = MetaInfo(n)
|
||||
n_info.outputs = _normalize_tuple(r)
|
||||
|
Reference in New Issue
Block a user