[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)

* polish chunk manager

* polish unit test

* impl add_extern_static_tensor for chunk mgr

* add mem stats collector v2

* polish code

* polish unit test

* polish code

* polish get chunks
Author: ver217
Date: 2022-06-09 20:56:34 +08:00 (committed by GitHub)
Parent: b3a03e4bfd
Commit: be01db37c8
6 changed files with 68 additions and 31 deletions


@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
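The new __hash__/__eq__ pair gives Chunk pure identity semantics, which is what lets chunks be kept in sets such as accessed_chunks and deduplicated by get_chunks later in this diff. A minimal sketch of that behavior, using a stand-in class rather than the real Chunk:

    # Stand-in class (not the real Chunk) demonstrating identity-based hashing:
    # two objects with equal payloads are still distinct set members.
    class _IdentityHashed:
        def __init__(self, payload):
            self.payload = payload

        def __hash__(self) -> int:
            return hash(id(self))

        def __eq__(self, other: object) -> bool:
            return self is other

    a = _IdentityHashed([1.0, 2.0])
    b = _IdentityHashed([1.0, 2.0])
    accessed = {a}
    assert a in accessed and b not in accessed  # membership is decided by identity, not value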
@@ -226,8 +232,7 @@ class ChunkManager:
         src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
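With access_chunk and release_chunk now taking a Chunk instead of a tensor, callers resolve the chunk once via get_chunk and reuse it. A hedged usage sketch; the import path and the helper function below are assumptions, not code from this PR:

    import torch

    from colossalai.tensor.chunk import ChunkManager  # import path assumed for this version

    def _with_materialized_param(chunk_manager: ChunkManager, param: torch.Tensor) -> None:
        # Resolve the tensor's chunk once, then drive the chunk-based API with it.
        chunk = chunk_manager.get_chunk(param)
        chunk_manager.access_chunk(chunk)    # gather the chunk so `param` is usable
        # ... compute with `param` here ...
        chunk_manager.release_chunk(chunk)   # no-op unless distributed storage is enabled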
@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
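is_chunk_free(tensor) is removed in this hunk; the same information is still reachable through get_chunk. A small illustrative wrapper (the function name is hypothetical, the import path assumed):

    import torch

    from colossalai.tensor.chunk import ChunkManager  # import path assumed, as above

    def is_tensor_chunk_free(manager: ChunkManager, t: torch.Tensor) -> bool:
        # Replaces the removed ChunkManager.is_chunk_free(tensor).
        return manager.get_chunk(t).is_free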
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
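exec_lazy_release now deduplicates through get_chunks (added at the end of this diff), so a chunk shared by several queued tensors is released exactly once. The ordered dedup it relies on boils down to this standalone sketch:

    from typing import Iterable, List, Tuple

    def _dedup_keep_order(items: Iterable[object]) -> Tuple[object, ...]:
        # Same shape as ChunkManager.get_chunks: first occurrence wins, order preserved.
        unique: List[object] = []
        for item in items:
            if item not in unique:   # uses __eq__, which for Chunk means identity
                unique.append(item)
        return tuple(unique)

    chunk_a, chunk_b = object(), object()   # stand-ins for two Chunk objects
    queued = [chunk_a, chunk_a, chunk_b]    # three queued tensors, two underlying chunks
    assert _dedup_keep_order(queued) == (chunk_a, chunk_b)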
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add an extern static tensor to the chunk manager.
+
+        Such tensors are not managed by the chunk manager, but we still want to monitor their memory usage.
+        They are "static", meaning their shape, dtype and device never change, so their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor, e.g. an optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
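add_extern_static_tensor only accounts for memory; it never hands the tensor to a chunk. A hedged usage sketch for tracking optimizer state (the helper function and the Adam-style `exp_avg` buffer are illustrative, and the import path is assumed):

    import torch

    from colossalai.tensor.chunk import ChunkManager  # import path assumed

    def track_optimizer_state(chunk_manager: ChunkManager, param: torch.Tensor) -> torch.Tensor:
        # Allocate an fp32 momentum buffer (e.g. Adam's exp_avg) outside the chunks,
        # but let the manager count its bytes so total_mem stays accurate.
        exp_avg = torch.zeros_like(param, dtype=torch.float32)
        chunk_manager.add_extern_static_tensor(exp_avg)
        return exp_avg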