mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[fx] supported model tracing for huggingface bert (#1201)
* [fx] supported model tracing for huggingface bert * polish test
This commit is contained in:
@@ -198,6 +198,16 @@ class ColoTracer(Tracer):
|
||||
sig = inspect.signature(root.forward)
|
||||
sig_names = set(sig.parameters.keys())
|
||||
meta_arg_names = set(meta_args.keys())
|
||||
|
||||
# update concrete args with default values
|
||||
non_meta_arg_names = sig_names - meta_arg_names
|
||||
for k, v in sig.parameters.items():
|
||||
if k in non_meta_arg_names and \
|
||||
k not in concrete_args and \
|
||||
v.default is not inspect.Parameter.empty:
|
||||
concrete_args[k] = v.default
|
||||
|
||||
# get non concrete arg names
|
||||
concrete_arg_names = set(concrete_args.keys())
|
||||
non_concrete_arg_names = sig_names - concrete_arg_names
|
||||
|
||||
@@ -213,8 +223,12 @@ class ColoTracer(Tracer):
|
||||
# assign as attributed for late reference
|
||||
def _check_kwargs(kwargs, should_be_meta: bool):
|
||||
for k, v in kwargs.items():
|
||||
assert v.is_meta == should_be_meta, \
|
||||
f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
|
||||
if not should_be_meta:
|
||||
assert not torch.is_tensor(v) or not v.is_meta, \
|
||||
f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
|
||||
else:
|
||||
assert v.is_meta == should_be_meta, \
|
||||
f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
|
||||
|
||||
_check_kwargs(concrete_args, should_be_meta=False)
|
||||
_check_kwargs(meta_args, should_be_meta=True)
|
||||
|
Reference in New Issue
Block a user