[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26

1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,3 +1,3 @@
from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModelV2']
__all__ = ["ShardedModelV2"]

View File

@@ -25,7 +25,7 @@ def free_storage(data: torch.Tensor) -> None:
@torch.no_grad()
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
"""Allocate storage for a tensor."""
if data.storage().size() == size.numel(): # no need to reallocate
if data.storage().size() == size.numel(): # no need to reallocate
return
assert data.storage().size() == 0
data.storage().resize_(size.numel())
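
For reference, a minimal sketch of the storage trick these helpers rely on (simplified; the real free_storage/alloc_storage add extra assertions, as the hunk above shows):

import torch

x = torch.empty(4, 4)
x.storage().resize_(0)                   # free_storage: drop the backing memory
assert x.storage().size() == 0           # tensor metadata (shape, dtype) is untouched
x.storage().resize_(x.numel())           # alloc_storage: reallocate room for 16 elements
assert x.storage().size() == x.numel()
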

View File

@@ -20,7 +20,6 @@ else:
class Bucket:
def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
self.group = group
@@ -35,18 +34,18 @@ class Bucket:
return
# reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
dist._reduce_scatter_base(self.output_shard[:self.offset],
self.buffer[:, :self.offset].contiguous(),
group=self.group)
dist._reduce_scatter_base(
self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
)
else:
dist.reduce_scatter(self.output_shard[:self.offset],
list(self.buffer[:, :self.offset].unbind(0)),
group=self.group)
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.buffer[:, :self.offset].zero_()
self.buffer[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.buffer[0])
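
A single-process illustration of what the bucket's reduce-scatter achieves (simulated ranks, no torch.distributed; the real flush uses _reduce_scatter_base or dist.reduce_scatter as shown above):

import torch

world_size, shard_size = 4, 6
# Each simulated rank owns a (world_size, shard_size) bucket buffer.
buffers = [torch.randn(world_size, shard_size) for _ in range(world_size)]
# After reduce-scatter, rank r holds the element-wise sum of row r of every rank's buffer.
output_shards = [sum(buf[r] for buf in buffers) for r in range(world_size)]
assert output_shards[0].shape == (shard_size,)
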
@@ -74,12 +73,12 @@ class Bucket:
tensor_size = tensor_list[0].numel()
stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
offset = self.offset
self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
self.buffer[:, offset : offset + tensor_size].copy_(stacked_input)
self.offset += tensor_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
result_view = self.output_shard[offset : offset + tensor_size].view_as(tensor_list[0])
self.callbacks.append(functools.partial(callback_fn, result_view))
@@ -142,8 +141,9 @@ class ReduceScatterBucketer:
"""
world_size = group.size()
assert (len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
first_input = input_list[0]
first_input_size = first_input.numel()
@@ -183,7 +183,7 @@ class ReduceScatterBucketer:
@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_size_mb * MB / element_size
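
As a rough illustration of the arithmetic above (the fp16 element size, the shard count, and the final division by num_shards are assumptions for illustration; the return statement is not shown in this hunk):

MB = 1024 * 1024
bucket_size_mb = 25      # default reduce_scatter_bucket_size_mb used later in this commit
element_size = 2         # bytes per element for fp16 gradients (assumed)
num_shards = 4           # e.g. a 4-rank process group (assumed)
bucket_elems = bucket_size_mb * MB / element_size    # 13_107_200 elements fit in one bucket
shard_size = int(bucket_elems // num_shards)         # 3_276_800 elements per rank (assumed split)
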

View File

@@ -2,7 +2,6 @@
import functools
import itertools
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Iterator, Optional, Tuple
import torch
@@ -40,7 +39,7 @@ from .zero_hook import ZeroHook
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class ShardedModelV2(nn.Module):
@@ -78,20 +77,22 @@ class ShardedModelV2(nn.Module):
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""
def __init__(self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
def __init__(
self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
tensor_placement_policy: str = "cuda",
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs,
):
assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported."
super().__init__()
self.logger = get_dist_logger()
self.bf16 = bf16
@@ -101,13 +102,13 @@ class ShardedModelV2(nn.Module):
sharded_cnt = 0
unshard_cnt = 0
for param in submodule.parameters(recurse=False):
assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first."
if param.colo_attr.param_is_sharded:
sharded_cnt += 1
else:
unshard_cnt += 1
assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
submodule.param_is_sharded = (sharded_cnt > 0)
assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param"
submodule.param_is_sharded = sharded_cnt > 0
self.sharded_params = []
self.unshard_params = []
@@ -124,7 +125,7 @@ class ShardedModelV2(nn.Module):
self.rank = dist.get_rank(self.process_group)
self.shard_strategy = shard_strategy
self._use_memory_tracer = tensor_placement_policy == 'auto'
self._use_memory_tracer = tensor_placement_policy == "auto"
if self._use_memory_tracer:
self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
@@ -132,18 +133,19 @@ class ShardedModelV2(nn.Module):
else:
self._memstats_collector = None
self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
tensor_placement_policy
)(mem_stats_collector=self._memstats_collector)
if 'warmup_non_model_data_ratio' in kwargs:
if tensor_placement_policy != 'auto':
self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
if "warmup_non_model_data_ratio" in kwargs:
if tensor_placement_policy != "auto":
self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement")
else:
ratio = kwargs['warmup_non_model_data_ratio']
ratio = kwargs["warmup_non_model_data_ratio"]
self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement")
self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")]
self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
# Register hooks
@@ -155,7 +157,7 @@ class ShardedModelV2(nn.Module):
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = tensor_placement_policy != 'cuda'
self._cpu_offload: bool = tensor_placement_policy != "cuda"
for param in module.parameters():
# Init `offload_grad`
param.colo_attr.offload_grad = self._cpu_offload
@@ -164,9 +166,11 @@ class ShardedModelV2(nn.Module):
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set
# gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if \
gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_predivide_factor: float = (
gradient_predivide_factor
if gradient_predivide_factor is not None
else get_gradient_predivide_factor(self.world_size)
)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
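
The comment above describes a two-stage division: gradients are divided by gradient_predivide_factor before the reduce-scatter and by gradient_postdivide_factor afterwards, so the combined scaling equals world_size (a plain mean) while intermediate sums stay smaller. A numeric sketch with illustrative values:

world_size = 8
gradient_predivide_factor = 2.0                                       # illustrative choice
gradient_postdivide_factor = world_size / gradient_predivide_factor   # 4.0

local_grad = 16.0                              # same gradient value on every rank (toy case)
pre = local_grad / gradient_predivide_factor   # applied before the reduce-scatter
summed = pre * world_size                      # stand-in for the cross-rank sum
result = summed / gradient_postdivide_factor   # applied in _reduce_scatter_callback
assert result == local_grad                    # equals the mean of identical per-rank grads
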
@@ -194,7 +198,7 @@ class ShardedModelV2(nn.Module):
def cpu_offload(self):
return self._cpu_offload
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None:
"""
dummy memory tracer collected information to a file.
try:
@@ -205,18 +209,18 @@ class ShardedModelV2(nn.Module):
exit(0)
"""
if self._use_memory_tracer:
self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write('\n')
f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
f.write('CPU non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
f.write('\n')
with open(filename, "w+") as f:
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
f.write("CUDA model data (GB)\n")
f.write("\n")
f.write("CUDA non model data (GB)\n")
f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda")))
f.write("CPU non model data (GB)\n")
f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu")))
f.write("\n")
def _pre_forward_operations(self, *args):
# the operation will affect the memory tracer behavior in ZeroHook
@@ -224,14 +228,14 @@ class ShardedModelV2(nn.Module):
self._start_collect_memstats()
for p in self.module.parameters():
if hasattr(p, 'colo_attr'):
if hasattr(p, "colo_attr"):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
self._stateful_tensor_mgr.start_iter()
def _post_forward_operations(self):
for p in self.module.parameters():
if hasattr(p, 'colo_attr'):
if hasattr(p, "colo_attr"):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
@@ -261,8 +265,9 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_device_memory_capacity(
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
self._cuda_margin_space = (
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
)
@torch.no_grad()
def _post_backward_operations(self) -> None:
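
A toy version of the margin-space computation above, with made-up capacity and usage numbers:

device_capacity = 80 * 1024**3       # bytes; stand-in for colo_device_memory_capacity(...)
max_overall_cuda = 50 * 1024**3      # bytes; stand-in for _memstats.max_overall_cuda
cuda_margin_space = device_capacity - max_overall_cuda
# ~30 GiB of headroom left on the device for optimizer states ("OS" in the comment above)
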
@@ -330,7 +335,7 @@ class ShardedModelV2(nn.Module):
"""
if grad is None:
return
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients"
if not self._require_backward_grad_sync:
return
# used to cheat Pytorch, since we can't return None
@@ -354,16 +359,19 @@ class ShardedModelV2(nn.Module):
grad.data.div_(self.gradient_predivide_factor)
if self.world_size > 1:
grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
self.reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param))
self.reducer.reduce_scatter_async(
grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param),
)
else:
self._reduce_scatter_callback(param, grad)
torch.cuda.current_stream().wait_stream(self.comm_stream)
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
assert isinstance(
reduced_grad, torch.Tensor
), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
reduced_grad.data = reduced_grad.data.contiguous().view(-1)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
@@ -372,7 +380,6 @@ class ShardedModelV2(nn.Module):
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
# record whether we have overflow
self.overflow_counter += torch.isinf(grad).any().item()
self.overflow_counter += torch.isnan(grad).any().item()
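
The two lines above accumulate inf/nan counts; a self-contained sketch of the detection itself (the counter is presumably consumed later by the fp16 overflow check):

import torch

overflow_counter = 0
grad = torch.tensor([1.0, float("inf"), 3.0])
overflow_counter += torch.isinf(grad).any().item()   # True -> adds 1
overflow_counter += torch.isnan(grad).any().item()   # False -> adds 0
has_overflow = overflow_counter > 0                  # non-zero flags the step as unsafe
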
@@ -384,8 +391,9 @@ class ShardedModelV2(nn.Module):
if self.reuse_fp16_shard:
# make parameters point to gradient
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
assert (
param.colo_attr.saved_grad.is_null()
), "Gradient accumulation is not supported when reuse_fp16_shard=True"
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
@@ -396,7 +404,6 @@ class ShardedModelV2(nn.Module):
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
else:
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
@@ -410,39 +417,44 @@ class ShardedModelV2(nn.Module):
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
return self.module.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
return self.module.named_parameters(prefix, recurse)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
return self._colo_state_dict(destination,
prefix,
keep_vars,
shard_strategy=self.shard_strategy,
state_dict_func=nn.Module.state_dict,
module_to_load=self.module,
sharded_params=self.sharded_params,
process_group=self.process_group)
def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]":
return self._colo_state_dict(
destination,
prefix,
keep_vars,
shard_strategy=self.shard_strategy,
state_dict_func=nn.Module.state_dict,
module_to_load=self.module,
sharded_params=self.sharded_params,
process_group=self.process_group,
)
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None:
for name, p in self.named_parameters():
if name in state_dict:
p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
device=p.colo_attr.data_payload.device))
p.colo_attr.data_payload_reset(
state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)
)
# Force re-shard
p.colo_attr.sharded_data_tensor.is_sharded = False
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
elif strict:
raise RuntimeError(f'Missing key in state_dict: {name}')
raise RuntimeError(f"Missing key in state_dict: {name}")
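
Taken together, state_dict gathers shards into full tensors and load_state_dict copies the payloads back and then forces a re-shard. A hypothetical checkpoint round-trip, assuming an already-initialized ShardedModelV2 instance (sketch only, not a complete distributed setup):

import torch

def save_and_restore(model, path="ckpt.pt"):
    torch.save(model.state_dict(), path)            # shards are gathered into full tensors
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)                    # payloads copied, then re-sharded
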
def _colo_state_dict(self,
destination=None,
prefix='',
keep_vars=False,
shard_strategy: Optional[BaseShardStrategy] = None,
state_dict_func=None,
module_to_load=None,
sharded_params=[],
process_group=None) -> 'OrderedDict[str, torch.Tensor]':
def _colo_state_dict(
self,
destination=None,
prefix="",
keep_vars=False,
shard_strategy: Optional[BaseShardStrategy] = None,
state_dict_func=None,
module_to_load=None,
sharded_params=[],
process_group=None,
) -> "OrderedDict[str, torch.Tensor]":
if len(sharded_params) == 0:
for param in self.parameters():
if param.colo_attr.param_is_sharded:
@@ -460,15 +472,9 @@ class ShardedModelV2(nn.Module):
p.colo_attr.set_data_none()
return gathered_state_dict
def _colo_load_from_state_dict(self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
shard_strategy=None):
def _colo_load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -512,10 +518,12 @@ class ShardedModelV2(nn.Module):
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if hasattr(param, 'colo_attr'):
if hasattr(param, "colo_attr"):
param.colo_attr.data_payload_reset(
input_param.to(dtype=param.colo_attr.data_payload.dtype,
device=param.colo_attr.data_payload.device))
input_param.to(
dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device
)
)
if shard_strategy is not None:
# Force re-shard
param.colo_attr.sharded_data_tensor.is_sharded = False
@@ -531,19 +539,21 @@ class ShardedModelV2(nn.Module):
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(
key, input_param.shape, param.shape))
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
)
continue
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
ex.args))
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
elif strict:
missing_keys.append(key)
@@ -559,8 +569,8 @@ class ShardedModelV2(nn.Module):
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)

View File

@@ -11,7 +11,7 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
Note the other_model has to be the same as self.
"""
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
assert hasattr(zero_param, 'colo_attr')
assert hasattr(zero_param, "colo_attr")
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])

View File

@@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook):
Warning: this class has been deprecated after version 0.1.12
"""
def __init__(self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector] = None,
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
process_group: Optional[dist.ProcessGroup] = None):
def __init__(
self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector] = None,
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
process_group: Optional[dist.ProcessGroup] = None,
):
super().__init__()
self.logger = get_dist_logger("ZeROHook")
self.shard_strategy = shard_strategy
@@ -41,7 +43,7 @@ class ZeroHook(BaseOpHook):
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
assert hasattr(param, "colo_attr")
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
@@ -50,7 +52,7 @@ class ZeroHook(BaseOpHook):
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
assert hasattr(param, "colo_attr")
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
@@ -74,10 +76,9 @@ class ZeroHook(BaseOpHook):
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
# change tensor state to HOLD_AFTER_FWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
@@ -93,10 +94,9 @@ class ZeroHook(BaseOpHook):
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
# change tensor state to HOLD_AFTER_BWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
if self._stateful_tensor_mgr:
self.logger.debug(
f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
ranks=[0])
ranks=[0],
)
self._stateful_tensor_mgr.finish_iter()