Mirror of https://github.com/hpcaitech/ColossalAI.git
[Gemini] use MemStats to store the tracing data. Separate it from Collector. (#2084)
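The change moves the recorded tracing data out of the collector and into a dedicated MemStats object, which the collector now exposes as `_memstats`. A minimal sketch of that separation, assuming method bodies and an `append` helper that the diff does not show:

```python
from typing import Dict, List


class MemStats:
    """Plain data holder for tracing samples; no collection logic."""

    def __init__(self) -> None:
        self._model_data: Dict[str, List[int]] = {'cuda': [], 'cpu': []}
        self._non_model_data: Dict[str, List[int]] = {'cuda': [], 'cpu': []}

    def append(self, device: str, model_bytes: int, non_model_bytes: int) -> None:
        # Record one sampling step for the given device.
        self._model_data[device].append(model_bytes)
        self._non_model_data[device].append(non_model_bytes)

    def model_data_list(self, device: str) -> List[int]:
        return self._model_data[device]

    def non_model_data_list(self, device: str) -> List[int]:
        return self._non_model_data[device]

    def overall_mem_stats(self, device: str) -> List[int]:
        # Overall usage per step = model data + non-model data.
        return [m + n for m, n in zip(self._model_data[device],
                                      self._non_model_data[device])]


class MemStatsCollector:
    """Owns the sampling machinery; the data itself lives in self._memstats."""

    def __init__(self) -> None:
        self._memstats = MemStats()

    def start_collection(self) -> None:
        ...    # hook into the runtime and begin sampling

    def finish_collection(self) -> None:
        ...    # detach hooks and freeze the collected samples
```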
@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
-                 user_static_memstats: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
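With `user_static_memstats` dropped from the signature, tracer behavior is selected by `tensor_placement_policy` alone. A hypothetical caller after this change (argument names are taken from the surrounding signature; `module` and `shard_strategy` are placeholders):

```python
# Hypothetical construction; the static-memstats switch is no longer accepted.
model = ShardedModelV2(
    module,                          # the wrapped nn.Module
    shard_strategy,                  # a shard strategy instance
    tensor_placement_policy='auto',  # 'auto' is what enables the memory tracer
    gradient_predivide_factor=1.0,
    reuse_fp16_shard=False,
)
```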
@@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        self.user_static_memstats = user_static_memstats

         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            if self.user_static_memstats:
-                self._memstats_collector = StaticMemStatsCollector(self.module)
-            else:
-                self._memstats_collector = MemStatsCollector()
+            self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
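`disposable` wraps the collector's start/finish methods so that each fires at most once per tracing pass, even if the surrounding hooks run repeatedly. The real helper is imported from ColossalAI's utils; a sketch of an equivalent wrapper:

```python
import functools
from typing import Callable


def disposable(func: Callable) -> Callable:
    """Return a wrapper that executes ``func`` on the first call only."""
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)
        # Later calls are silently ignored.

    return wrapper
```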
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
             f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
             f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
             f.write('CUDA model data (GB)\n')
-            f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
             f.write('\n')
             f.write('CUDA non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
             f.write('CPU non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
             f.write('\n')

     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
-            if self.user_static_memstats:
-                self.init_mem_stats(*args)
             self._start_collect_memstats()

         for p in self.module.parameters():
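Note that the replacement calls drop the old `'GB'` unit argument: the `MemStats` lists appear to hold raw byte counts, while the written headers still say "(GB)". If the labels are to stay accurate, the dump would need an explicit conversion, e.g. (a sketch, assuming the lists hold bytes):

```python
def to_gb(byte_counts):
    # Match the 1e9 divisor used for the torch.cuda numbers above.
    return [b / 1e9 for b in byte_counts]

f.write('CUDA model data (GB)\n')
f.write(str(to_gb(self._memstats_collector._memstats.model_data_list('cuda'))))
f.write('\n')
```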
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector.overall_mem_stats('cuda'))
+                self._memstats_collector._memstats.overall_mem_stats('cuda'))

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
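The margin computation is unchanged apart from the accessor: it subtracts the peak overall CUDA usage seen by the tracer from the device capacity, and that slack is what the placement policy can hand to optimizer states (OS). A worked example with made-up numbers:

```python
# Hypothetical numbers, for illustration only (all in bytes).
device_capacity = 16e9                      # e.g. device capacity of a 16 GB card
overall_cuda_samples = [9e9, 11e9, 10.5e9]  # tracer's per-step overall CUDA usage

cuda_margin_space = device_capacity - max(overall_cuda_samples)
print(cuda_margin_space / 1e9)  # 5.0 GB left over for optimizer states
```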