[fx] supported model tracing for huggingface bert (#1201)

* [fx] supported model tracing for huggingface bert

* polish test
This commit is contained in:
Frank Lee
2022-07-05 13:19:57 +08:00
committed by GitHub
parent 060b917daf
commit f7878f465c
5 changed files with 126 additions and 4 deletions

View File

@@ -198,6 +198,16 @@ class ColoTracer(Tracer):
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
if k in non_meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
@@ -213,8 +223,12 @@ class ColoTracer(Tracer):
# assign as attributed for late reference
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
assert v.is_meta == should_be_meta, \
f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
if not should_be_meta:
assert not torch.is_tensor(v) or not v.is_meta, \
f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
else:
assert v.is_meta == should_be_meta, \
f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)