[autoparallel] refactor and add rotorc. (#1789)

* [autoparallel] refactor and add rotorc.

* [autoparallel] refactor and add rotorc.
Authored by Super Daniel on 2022-11-03 12:32:51 +08:00; committed by GitHub
parent 4d6e1284cb
commit e8a9bebc87
5 changed files with 333 additions and 129 deletions


@@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             kwargs['inplace'] = True
+            meta.bwd_mem_tmp = 0
+            meta.bwd_mem_out = 0
             do_not_cache = False
         meta.bwd_mem_out -= param_size
@@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             module.inplace = True
+            meta.bwd_mem_tmp = 0
+            meta.bwd_mem_out = 0
             do_not_cache = False
         # grad for param will not be counted
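
Both hunks apply the same change to the profiler: when a function or module runs inplace, the newly added lines zero meta.bwd_mem_tmp and meta.bwd_mem_out so that no extra backward-pass memory is attributed to it. A minimal sketch of that pattern follows; MetaInfo and profile_call are hypothetical stand-ins for the profiler's real types, not the library's API:

from dataclasses import dataclass

@dataclass
class MetaInfo:                  # hypothetical stand-in for the profiler's metadata record
    bwd_mem_tmp: int = 0         # temporary memory held during the backward pass
    bwd_mem_out: int = 0         # memory of the outputs produced by the backward pass

def profile_call(measure, *args, inplace: bool = False, **kwargs):
    # `measure` runs the op and returns (output, MetaInfo); any callable works here.
    out, meta = measure(*args, **kwargs)
    if inplace:
        # Mirror the diff above: an inplace op reuses its existing buffer,
        # so no separate backward memory is charged to it.
        meta.bwd_mem_tmp = 0
        meta.bwd_mem_out = 0
    return out, meta

# Usage: an op measured at 64 bytes of backward memory reports 0 when run inplace.
_, meta = profile_call(lambda x: (x, MetaInfo(bwd_mem_tmp=64, bwd_mem_out=64)), None, inplace=True)
assert meta.bwd_mem_tmp == 0 and meta.bwd_mem_out == 0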