[tensor] chunk manager monitor mem usage (#1076)

ver217
2022-06-07 15:00:00 +08:00
committed by GitHub
parent 98cdbf49c6
commit 1b17859328
2 changed files with 35 additions and 1 deletion
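The change threads a per-device byte counter through the chunk lifecycle: each Chunk caches its payload size once (`mem = size * element_size()`), exposes a `device_type` property ('cpu' or 'cuda'), and `ChunkManager.total_mem` is decremented before any operation that frees or relocates a chunk's storage and incremented after the storage is (re)materialized. The sketch below mirrors that subtract-before/add-after bookkeeping pattern in isolation; the `_Chunk`/`_Manager` names and their simplified methods are hypothetical stand-ins, not the ColossalAI API.

    from typing import Dict, List

    import torch


    class _Chunk:
        """Hypothetical stand-in for Chunk: one flat buffer backing many tensors."""

        def __init__(self, size: int, dtype: torch.dtype, device: str = 'cpu') -> None:
            self.data = torch.empty(size, dtype=dtype, device=device)
            # payload size in bytes, fixed at construction (mirrors `self.mem` in the diff)
            self.mem = size * self.data.element_size()
            self.is_free = False

        @property
        def device_type(self) -> str:
            # 'cpu' or 'cuda', regardless of device index
            return self.data.device.type

        def release(self) -> None:
            # drop the payload but keep the tensor object and its metadata
            self.data.storage().resize_(0)
            self.is_free = True

        def move_device(self, device: torch.device) -> None:
            self.data = self.data.to(device)


    class _Manager:
        """Hypothetical manager showing the total_mem bookkeeping pattern."""

        def __init__(self) -> None:
            self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
            self.chunks: List[_Chunk] = []

        def add_chunk(self, chunk: _Chunk) -> None:
            self.chunks.append(chunk)
            if not chunk.is_free:                      # only count materialized storage
                self.total_mem[chunk.device_type] += chunk.mem

        def release_chunk(self, chunk: _Chunk) -> None:
            chunk.release()
            if chunk.is_free:                          # storage is gone, stop counting it
                self.total_mem[chunk.device_type] -= chunk.mem

        def move_chunk(self, chunk: _Chunk, device: torch.device) -> None:
            if chunk.data.device == device:
                return
            if not chunk.is_free:
                self.total_mem[chunk.device_type] -= chunk.mem   # bytes leave the old device
                chunk.move_device(device)
                self.total_mem[chunk.device_type] += chunk.mem   # and land on the new one

Every state transition in the real ChunkManager (append, access, release, move, reduce) applies the same pattern around the operation that changes where, or whether, the chunk's storage lives, as the diff below shows.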


@@ -54,6 +54,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(0)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        self.mem = self.size * self.data.element_size()
 
     def append(self, tensor: torch.Tensor) -> None:
         assert tensor.dtype == self.dtype
@@ -167,6 +168,10 @@ class Chunk:
         self.data.copy_(dest_chunk.data)
         self._update_tensors_ptr()
 
+    @property
+    def device_type(self) -> str:
+        return self.data.device.type
+
 
 class ChunkManager:
@@ -184,6 +189,7 @@ class ChunkManager:
         self.lazy_release_tensors: List[torch.Tensor] = []
         if enable_distributed_storage and chunk_size is None:
             self.rank_load: Dict[str, torch.Tensor] = {}
+        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
         assert tensor not in self.tensor_chunk_map
@@ -202,6 +208,8 @@ class ChunkManager:
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
             chunk.append(tensor)
+            if not chunk.is_free:
+                self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
@@ -222,8 +230,11 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if chunk in self.accessed_chunks:
             return
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
+        self.total_mem[chunk.device_type] += chunk.mem
 
     def release_chunk(self, tensor: torch.Tensor) -> None:
         if not self.enable_distributed_storage:
@@ -234,11 +245,17 @@ class ChunkManager:
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
+            if chunk.is_free:
+                self.total_mem[chunk.device_type] -= chunk.mem
 
     def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
         chunk = self.tensor_chunk_map[tensor]
-        if chunk.can_move_device:
+        if chunk.data.device == device:
+            return
+        if chunk.can_move_device and not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
+            self.total_mem[chunk.device_type] += chunk.mem
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
@@ -248,7 +265,9 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
             return False
+        self.total_mem[chunk.device_type] -= chunk.mem
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        self.total_mem[chunk.device_type] += chunk.mem
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@@ -272,6 +291,7 @@ class ChunkManager:
     def __repr__(self) -> str:
         msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
             for i, chunk in enumerate(group):
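The extra `__repr__` line renders the counter as a single summary string per rank; a quick illustration of that join expression with made-up values:

    total_mem = {'cpu': 0, 'cuda': 536870912}
    msg = 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in total_mem.items()]) + '\n'
    # msg == 'Total memory: cpu=0B, cuda=536870912B\n'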