Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-24 19:17:30 +00:00
[tensor] chunk manager monitor mem usage (#1076)
@@ -54,6 +54,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(0)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        self.mem = self.size * self.data.element_size()
 
     def append(self, tensor: torch.Tensor) -> None:
         assert tensor.dtype == self.dtype
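The added field is plain payload arithmetic: element count times bytes per element. A minimal standalone sketch of the same computation (a hypothetical 1024-element fp16 buffer, not the Chunk class itself):

import torch

data = torch.empty(1024, dtype=torch.float16)
# Matches `self.size * self.data.element_size()`: 1024 elements * 2 bytes each.
mem = data.numel() * data.element_size()
print(mem)  # 2048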
@@ -167,6 +168,10 @@ class Chunk:
         self.data.copy_(dest_chunk.data)
         self._update_tensors_ptr()
 
+    @property
+    def device_type(self) -> str:
+        return self.data.device.type
+
 
 class ChunkManager:
 
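The new `device_type` property just exposes `self.data.device.type`, i.e. the string 'cpu' or 'cuda', which is what the ChunkManager below uses as a key into its per-device totals. A quick illustration with a bare tensor:

import torch

t = torch.empty(4)                   # CPU tensor
print(t.device.type)                 # 'cpu'
if torch.cuda.is_available():
    print(t.cuda().device.type)      # 'cuda' (type string only, no device index)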
@@ -184,6 +189,7 @@ class ChunkManager:
         self.lazy_release_tensors: List[torch.Tensor] = []
         if enable_distributed_storage and chunk_size is None:
             self.rank_load: Dict[str, torch.Tensor] = {}
+        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
         assert tensor not in self.tensor_chunk_map
@@ -202,6 +208,8 @@ class ChunkManager:
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
             chunk.append(tensor)
+            if not chunk.is_free:
+                self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
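At append time the payload is only counted when the chunk actually holds data: on the chunk's source rank the storage is allocated, while on other ranks it was resized to 0 (see the first hunk) and the `chunk.is_free` guard skips the accounting. A hedged sketch of that branch with made-up numbers (names mirror the diff but are illustrative):

total_mem = {'cpu': 0, 'cuda': 0}

# Hypothetical chunk of 1024 fp16 elements created on the CPU of its source rank.
chunk_mem, device_type, is_free = 2048, 'cpu', False

if not is_free:                       # skipped on non-source ranks, whose storage is empty
    total_mem[device_type] += chunk_mem
print(total_mem)                      # {'cpu': 2048, 'cuda': 0}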
@@ -222,8 +230,11 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if chunk in self.accessed_chunks:
             return
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
+        self.total_mem[chunk.device_type] += chunk.mem
 
     def release_chunk(self, tensor: torch.Tensor) -> None:
         if not self.enable_distributed_storage:
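access_chunk subtracts the old footprint only when something was actually counted (`not chunk.is_free`) and re-adds the full footprint after `chunk.access()`, so a first access on a rank whose local copy was empty nets +mem, while an already-resident chunk nets zero. The arithmetic as a toy example with assumed values:

total_mem = {'cpu': 0, 'cuda': 0}
mem, was_free = 2048, True            # a rank whose local copy was not yet materialized

if not was_free:                      # skipped: this chunk was never counted here
    total_mem['cuda'] -= mem
# chunk.access() would materialize/broadcast the payload at this point
total_mem['cuda'] += mem
print(total_mem)                      # {'cpu': 0, 'cuda': 2048} -- net +2048 on first access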
@@ -234,11 +245,17 @@ class ChunkManager:
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
+            if chunk.is_free:
+                self.total_mem[chunk.device_type] -= chunk.mem
 
     def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
         chunk = self.tensor_chunk_map[tensor]
-        if chunk.can_move_device:
+        if chunk.data.device == device:
+            return
+        if chunk.can_move_device and not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
+            self.total_mem[chunk.device_type] += chunk.mem
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
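move_chunk reads `chunk.device_type` twice: once before `move_device` (the source device's counter goes down) and once after (the destination's counter goes up), so a move shifts bytes between the 'cpu' and 'cuda' keys without changing the grand total. A sketch of that transfer with assumed values:

total_mem = {'cpu': 2048, 'cuda': 0}
mem = 2048

total_mem['cpu'] -= mem               # device_type evaluated before the move
# chunk.move_device(device) would relocate the storage here
total_mem['cuda'] += mem              # device_type evaluated again after the move
print(total_mem)                      # {'cpu': 0, 'cuda': 2048}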
@@ -248,7 +265,9 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
             return False
+        self.total_mem[chunk.device_type] -= chunk.mem
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        self.total_mem[chunk.device_type] += chunk.mem
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@@ -272,6 +291,7 @@ class ChunkManager:
 
     def __repr__(self) -> str:
         msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
             for i, chunk in enumerate(group):
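With the added line, `__repr__` now reports the per-device totals alongside the per-group details. The formatting of that one line, reproduced standalone with an assumed counter state:

total_mem = {'cpu': 0, 'cuda': 2048}
msg = 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in total_mem.items()]) + '\n'
print(msg, end='')                    # Total memory: cpu=0B, cuda=2048B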