Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[autoparallel] refactor and add rotorc. (#1789)
@@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             kwargs['inplace'] = True
             meta.bwd_mem_tmp = 0
             meta.bwd_mem_out = 0
+        do_not_cache = False
+
         meta.bwd_mem_out -= param_size
@@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             module.inplace = True
             meta.bwd_mem_tmp = 0
             meta.bwd_mem_out = 0
+        do_not_cache = False
+
         # grad for param will not be counted
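Both hunks append a `do_not_cache = False` after the profiled call, i.e. the flag is reset once a call has been measured. Below is a minimal, self-contained sketch of that flag-reset pattern, not ColossalAI's actual implementation: every name except `do_not_cache` (and the idea of `_profile_concrete` as the measuring call) is hypothetical, and the caching logic is only assumed to illustrate why the reset matters.

```python
from typing import Any, Callable, Dict

# Module-level flag, as in the diff: raised for calls whose results should not
# be memoized (e.g. inplace ops), reset to False after every profiled call.
do_not_cache: bool = False
_cache: Dict[str, Any] = {}


def profile_call(name: str, func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run `func` and cache its profiling result unless `do_not_cache` was set."""
    global do_not_cache

    if kwargs.get('inplace', False):
        # Inplace execution would make the cached memory estimate wrong for
        # the out-of-place case, so skip caching for this call.
        do_not_cache = True

    if name in _cache and not do_not_cache:
        return _cache[name]

    out = func(*args, **kwargs)    # stand-in for _profile_concrete / _profile_meta

    if not do_not_cache:
        _cache[name] = out

    # Mirror the lines added in the diff: always reset the flag afterwards,
    # otherwise a single inplace call would disable caching for all later calls.
    do_not_cache = False
    return out
```

The design point the reset captures: because the flag is module-level state shared across calls, forgetting to clear it would let one "do not cache" decision leak into unrelated subsequent profiling calls.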