Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)
* polish chunk manager
* polish unit test
* impl add_extern_static_tensor for chunk mgr
* add mem stats collector v2
* polish code
* polish unit test
* polish code
* polish get chunks
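The refactor replaces the tensor-keyed ChunkManager methods (access_chunk, release_chunk, move_chunk, reduce_chunk) with chunk-keyed ones; callers now look the chunk up once via get_chunk. A minimal, hypothetical call sequence under the new API (chunk_manager and param are placeholder names, not part of the diff):

    chunk = chunk_manager.get_chunk(param)    # map the tensor to its chunk once
    chunk_manager.access_chunk(chunk)         # make the chunk's payload available
    # ... compute with `param` ...
    chunk_manager.release_chunk(chunk)        # early-returns unless distributed storage is enabled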
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
 
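Because __hash__ is based on id() and __eq__ is identity, chunks behave as unique objects in hashed containers, which is what the accessed_chunks set and the deduplication in get_chunks below rely on. A small illustrative sketch (variable names are placeholders):

    accessed = set()
    accessed.add(chunk)
    assert chunk in accessed         # found by identity, regardless of tensor contents
    assert chunk != another_chunk    # distinct Chunk objects never compare equal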
@@ -226,8 +232,7 @@ class ChunkManager:
         src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
 
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
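exec_lazy_release now deduplicates through get_chunks, so a chunk shared by several lazily released tensors is released exactly once. A hypothetical sketch (t1 and t2 are assumed to live in the same chunk):

    chunks = chunk_manager.get_chunks([t1, t2])
    assert len(chunks) == 1                     # one shared chunk, returned once
    chunk_manager.release_chunk(chunks[0])      # so it is released a single time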
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add extern static tensor to chunk manager.
+        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
+        They are "static", which means their shape, dtype, device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
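add_extern_static_tensor only accounts for memory; the tensor itself stays outside chunk management. A hedged usage sketch for optimizer states (optimizer_states is a placeholder for however the caller collects them, not an API from this diff):

    for state in optimizer_states:                       # e.g. Adam's exp_avg / exp_avg_sq buffers
        chunk_manager.add_extern_static_tensor(state)    # bumps total_mem by numel * element_size
    # total_mem, keyed by device type, now also reflects these static tensors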