[gemini] gemini mgr supports "cpu" placement policy (#1118)

* update gemini mgr

* update chunk

* add docstr

* polish placement policy

* update test chunk

* update test zero

* polish unit test

* remove useless unit test
This commit is contained in:
ver217
2022-06-15 15:05:19 +08:00
committed by GitHub
parent f99f56dff4
commit 7d14b473f0
7 changed files with 124 additions and 129 deletions

View File

@@ -36,8 +36,21 @@ class ChunkFullError(Exception):
pass
class Chunk:
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
class Chunk:
"""
A chunk is a contiguous memory space which contains multiple tensors.
@@ -46,26 +59,37 @@ class Chunk:
src_rank (int): the process which owns the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
"""
def __init__(self,
chunk_size: int,
src_rank: int,
dtype: torch.dtype,
init_device: Optional[torch.device] = None) -> None:
init_device: Optional[torch.device] = None,
force_data_on_cuda: bool = False) -> None:
self.size = chunk_size
self.utilized_size = 0
self.src_rank = src_rank
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
self.dtype = dtype
self.device = init_device or get_current_device()
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
device = init_device or get_current_device()
if force_data_on_cuda:
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
if device.type == 'cuda':
free_storage(self._cpu_data)
else:
free_storage(self.data)
else:
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
self._cpu_data = None
# we only keep the chunk in full in the process by which the tensor is owned
if not self.is_src_rank:
self.data.storage().resize_(0)
free_storage(self._payload)
# each tensor is associated with a TensorInfo to track meta info
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
self.mem = self.size * self.data.element_size()
@@ -83,16 +107,16 @@ class Chunk:
# raise exception when the chunk size is exceeded
if new_utilized_size > self.size:
raise ChunkFullError
# set tensor state
tensor_state = TensorState.FREE
# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if self.is_src_rank:
self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
tensor_state = TensorState.HOLD
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
else:
tensor.storage().resize_(0)
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
@@ -103,12 +127,12 @@ class Chunk:
Release the memory space on processes which do not own the chunk.
"""
if not self.is_src_rank:
self.data.storage().resize_(0)
free_storage(self._payload)
self._update_tensors_state(TensorState.FREE)
def _update_tensors_ptr(self) -> None:
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
@@ -122,8 +146,8 @@ class Chunk:
# recover the chunk on non-owner processes
# and broadcast the chunk from the source to all processes
if not self.is_src_rank:
self.data.storage().resize_(self.size)
self.data.data = self.data.to(get_current_device())
alloc_storage(self._payload)
self.move_device(get_current_device(), update_ptr=False)
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
# update tensor meta info
@@ -131,15 +155,32 @@ class Chunk:
if not self.is_src_rank:
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
def move_device(self, device: torch.device) -> None:
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to a target device.
Args:
device (torch.device): the target device for data movement.
"""
self.data.data = self.data.to(device)
self._update_tensors_ptr()
if self._payload.device == device:
return
if self._cpu_data is None:
self.data.data = self.data.to(device)
else:
if device.type == 'cuda':
# cpu -> cuda
src = self._cpu_data
dest = self.data
else:
# cuda -> cpu
src = self.data
dest = self._cpu_data
alloc_storage(dest)
dest.copy_(src)
free_storage(src)
if update_ptr:
self._update_tensors_ptr()
def reduce(self, is_all_reduce: bool = False) -> None:
"""
@@ -148,7 +189,7 @@ class Chunk:
Args:
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
"""
self.data.data = self.data.to(get_current_device())
self.move_device(get_current_device(), update_ptr=False)
if is_all_reduce:
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
else:
@@ -187,8 +228,8 @@ class Chunk:
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
tensor_info = self.tensors_info[tensor]
self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
@property
def can_release(self) -> bool:
@@ -225,7 +266,7 @@ class Chunk:
"""
Check whether the chunk is empty.
"""
return self.data.storage().size() == 0
return is_storage_empty(self._payload)
def __repr__(self) -> str:
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
@@ -235,8 +276,8 @@ class Chunk:
"""
Check if the chunk has inf or nan values.
"""
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
torch.isnan(self.data[:self.utilized_size]).any().item()
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
torch.isnan(self._payload[:self.utilized_size]).any().item()
def copy_(self, dest_chunk: 'Chunk'):
"""
@@ -246,7 +287,7 @@ class Chunk:
assert not dest_chunk.is_empty
assert self.size == dest_chunk.size
assert self.utilized_size == dest_chunk.utilized_size
self.data.copy_(dest_chunk.data)
self._payload.copy_(dest_chunk._payload)
self._update_tensors_ptr()
@property
@@ -254,7 +295,7 @@ class Chunk:
"""
Get the device type of the chunk.
"""
return self.data.device.type
return self._payload.device.type
def __hash__(self) -> int:
return hash(id(self))
@@ -265,6 +306,12 @@ class Chunk:
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
@property
def _payload(self) -> torch.Tensor:
if self._cpu_data is None or is_storage_empty(self._cpu_data):
return self.data
return self._cpu_data
class ChunkManager:
"""
@@ -285,6 +332,7 @@ class ChunkManager:
self.enable_distributed_storage = enable_distributed_storage
self.device = init_device or get_current_device()
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
self.groups_force_data_on_cuda: Dict[str, bool] = {}
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
self.accessed_chunks: Set[Chunk] = set()
self.lazy_release_tensors: List[torch.Tensor] = []
@@ -292,6 +340,17 @@ class ChunkManager:
self.rank_load: Dict[str, torch.Tensor] = {}
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
"""Create a chunk group.
Args:
group_name (str): group name
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False.
"""
assert group_name not in self.chunk_groups
self.chunk_groups[group_name] = deque()
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
"""
Append a tensor to a chunk.
@@ -304,19 +363,20 @@ class ChunkManager:
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
raise ValueError(
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
try:
# append the tensor to the last chunk
self.chunk_groups[group_name][-1].append(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
chunk_size = self.chunk_size or tensor.numel()
src_rank = self._get_next_src_rank(group_name)
chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
chunk = Chunk(chunk_size,
src_rank,
tensor.dtype,
self.device,
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
if self.enable_distributed_storage and self.chunk_size is None:
self.rank_load[group_name][src_rank] += chunk_size
@@ -387,7 +447,7 @@ class ChunkManager:
# update the memory consumption after releasing
self.total_mem[chunk.device_type] -= chunk.mem
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to the target device.
@@ -399,7 +459,7 @@ class ChunkManager:
return
if chunk.can_move_device and not chunk.is_empty:
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(device)
chunk.move_device(device, update_ptr=update_ptr)
self.total_mem[chunk.device_type] += chunk.mem
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: