[Gemini] use MemStats to store the tracing data. Separate it from Collector. (#2084)

Jiarui Fang
2022-12-06 16:43:06 +08:00
committed by GitHub
parent 1f99205827
commit 33f4412102
5 changed files with 193 additions and 139 deletions
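
This commit splits memory tracing into two roles: a MemStats object that only stores the traced numbers, and a MemStatsCollector that only drives the sampling. Below is a minimal sketch of the data-container side, assuming only the accessor names visible in the changed lines (model_data_list, non_model_data_list, overall_mem_stats); everything else is illustrative, not the actual colossalai.gemini implementation:

class MemStats:
    # Plain container for traced memory usage; holds no sampling logic.

    def __init__(self) -> None:
        # per-step model data (params/grads) in bytes, keyed by device type
        self._model_data = {'cuda': [], 'cpu': []}
        # per-step non-model data (activations, buffers) in bytes, keyed by device type
        self._non_model_data = {'cuda': [], 'cpu': []}

    def append(self, device: str, model_data: int, non_model_data: int) -> None:
        self._model_data[device].append(model_data)
        self._non_model_data[device].append(non_model_data)

    def model_data_list(self, device: str) -> list:
        return self._model_data[device]

    def non_model_data_list(self, device: str) -> list:
        return self._non_model_data[device]

    def overall_mem_stats(self, device: str) -> list:
        # overall usage per step = model data + non-model data
        return [m + n for m, n in zip(self._model_data[device], self._non_model_data[device])]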


@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
-                 user_static_memstats: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        self.user_static_memstats = user_static_memstats
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            if self.user_static_memstats:
-                self._memstats_collector = StaticMemStatsCollector(self.module)
-            else:
-                self._memstats_collector = MemStatsCollector()
+            self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
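
Note that start_collection and finish_collection are wrapped in disposable, so each fires at most once per training run. A one-shot wrapper with that assumed behavior might look like this (a sketch, not the library's actual utility):

from functools import wraps
from typing import Callable

def disposable(func: Callable) -> Callable:
    # Wrap func so repeated calls after the first become no-ops.
    executed = False

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper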
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
                     f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
                     f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
                     f.write('CUDA model data (GB)\n')
-                    f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
+                    f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
                     f.write('\n')
                     f.write('CUDA non model data (GB)\n')
-                    f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
+                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
                     f.write('CPU non model data (GB)\n')
-                    f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
+                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
                     f.write('\n')

     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
-            if self.user_static_memstats:
-                self.init_mem_stats(*args)
             self._start_collect_memstats()

         for p in self.module.parameters():
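
After the split, the collector exposes its data through the _memstats attribute, which is why the dump above now reads self._memstats_collector._memstats.model_data_list('cuda') instead of calling the collector directly. A hedged skeleton of that relationship (sample_overall_data and the flag handling are assumptions; only start_collection, finish_collection and _memstats appear in this diff):

class MemStatsCollector:
    # Drives sampling; the numbers themselves live in the MemStats container.

    def __init__(self) -> None:
        self._memstats = MemStats()    # data container, as sketched earlier
        self._start_flag = False

    def start_collection(self) -> None:
        self._start_flag = True

    def finish_collection(self) -> None:
        self._start_flag = False

    def sample_overall_data(self, device: str, model_data: int, non_model_data: int) -> None:
        # hypothetical hook called by the runtime tracer each step
        if self._start_flag:
            self._memstats.append(device, model_data, non_model_data)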
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector.overall_mem_stats('cuda'))
+                self._memstats_collector._memstats.overall_mem_stats('cuda'))

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
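
The cuda margin space computed above is simply device capacity minus the peak traced usage; with made-up numbers:

# Hypothetical values, only to illustrate the arithmetic above.
capacity = 40.0e9                            # colo_device_memory_capacity(...) on a 40 GB GPU
peak_usage = max([28.1e9, 31.5e9, 30.2e9])   # max of overall_mem_stats('cuda') samples, in bytes
cuda_margin_space = capacity - peak_usage    # 8.5e9 bytes left for optimizer states (OS)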