mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[gemini] gemini mgr supports "cpu" placement policy (#1118)
* update gemini mgr * update chunk * add docstr * polish placement policy * update test chunk * update test zero * polish unit test * remove useless unit test
This commit is contained in:
@@ -36,8 +36,21 @@ class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Chunk:
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
class Chunk:
|
||||
"""
|
||||
A chunk is a contiguous memory space which contains multiple tensors.
|
||||
|
||||
@@ -46,26 +59,37 @@ class Chunk:
|
||||
src_rank (int): the process which owns the chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
||||
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
src_rank: int,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None) -> None:
|
||||
init_device: Optional[torch.device] = None,
|
||||
force_data_on_cuda: bool = False) -> None:
|
||||
self.size = chunk_size
|
||||
self.utilized_size = 0
|
||||
self.src_rank = src_rank
|
||||
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
||||
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
||||
self.dtype = dtype
|
||||
self.device = init_device or get_current_device()
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
|
||||
device = init_device or get_current_device()
|
||||
if force_data_on_cuda:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
||||
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
||||
if device.type == 'cuda':
|
||||
free_storage(self._cpu_data)
|
||||
else:
|
||||
free_storage(self.data)
|
||||
else:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||
self._cpu_data = None
|
||||
|
||||
# we only keep the chunk in full in the process by which the tensor is owned
|
||||
if not self.is_src_rank:
|
||||
self.data.storage().resize_(0)
|
||||
|
||||
free_storage(self._payload)
|
||||
|
||||
# each tensor is associated with a TensorInfo to track meta info
|
||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||
self.mem = self.size * self.data.element_size()
|
||||
@@ -83,16 +107,16 @@ class Chunk:
|
||||
# raise exception when the chunk size is exceeded
|
||||
if new_utilized_size > self.size:
|
||||
raise ChunkFullError
|
||||
|
||||
|
||||
# set tensor state
|
||||
tensor_state = TensorState.FREE
|
||||
|
||||
# if the process owns the rank, then copy the tensor to its chunk buffer
|
||||
# otherwise set its storage size to 0 to reduce memory consumption
|
||||
if self.is_src_rank:
|
||||
self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
|
||||
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
|
||||
tensor_state = TensorState.HOLD
|
||||
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||
else:
|
||||
tensor.storage().resize_(0)
|
||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||
@@ -103,12 +127,12 @@ class Chunk:
|
||||
Release the memory space on processes which do not own the chunk.
|
||||
"""
|
||||
if not self.is_src_rank:
|
||||
self.data.storage().resize_(0)
|
||||
free_storage(self._payload)
|
||||
self._update_tensors_state(TensorState.FREE)
|
||||
|
||||
def _update_tensors_ptr(self) -> None:
|
||||
for tensor, tensor_info in self.tensors_info.items():
|
||||
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||
for tensor_info in self.tensors_info.values():
|
||||
@@ -122,8 +146,8 @@ class Chunk:
|
||||
# recover the chunk on non-owner processes
|
||||
# and broadcast the chunk from the source to all processes
|
||||
if not self.is_src_rank:
|
||||
self.data.storage().resize_(self.size)
|
||||
self.data.data = self.data.to(get_current_device())
|
||||
alloc_storage(self._payload)
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
# update tensor meta info
|
||||
@@ -131,15 +155,32 @@ class Chunk:
|
||||
if not self.is_src_rank:
|
||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||
|
||||
def move_device(self, device: torch.device) -> None:
|
||||
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to a target device.
|
||||
|
||||
Args:
|
||||
device (torch.device): the target device for data movement.
|
||||
"""
|
||||
self.data.data = self.data.to(device)
|
||||
self._update_tensors_ptr()
|
||||
if self._payload.device == device:
|
||||
return
|
||||
if self._cpu_data is None:
|
||||
self.data.data = self.data.to(device)
|
||||
else:
|
||||
if device.type == 'cuda':
|
||||
# cpu -> cuda
|
||||
src = self._cpu_data
|
||||
dest = self.data
|
||||
else:
|
||||
# cuda -> cpu
|
||||
src = self.data
|
||||
dest = self._cpu_data
|
||||
alloc_storage(dest)
|
||||
dest.copy_(src)
|
||||
free_storage(src)
|
||||
|
||||
if update_ptr:
|
||||
self._update_tensors_ptr()
|
||||
|
||||
def reduce(self, is_all_reduce: bool = False) -> None:
|
||||
"""
|
||||
@@ -148,7 +189,7 @@ class Chunk:
|
||||
Args:
|
||||
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
||||
"""
|
||||
self.data.data = self.data.to(get_current_device())
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
if is_all_reduce:
|
||||
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
||||
else:
|
||||
@@ -187,8 +228,8 @@ class Chunk:
|
||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
|
||||
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
@@ -225,7 +266,7 @@ class Chunk:
|
||||
"""
|
||||
Check whether the chunk is empty.
|
||||
"""
|
||||
return self.data.storage().size() == 0
|
||||
return is_storage_empty(self._payload)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||
@@ -235,8 +276,8 @@ class Chunk:
|
||||
"""
|
||||
Check if the chunk has inf or nan values.
|
||||
"""
|
||||
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self.data[:self.utilized_size]).any().item()
|
||||
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
||||
|
||||
def copy_(self, dest_chunk: 'Chunk'):
|
||||
"""
|
||||
@@ -246,7 +287,7 @@ class Chunk:
|
||||
assert not dest_chunk.is_empty
|
||||
assert self.size == dest_chunk.size
|
||||
assert self.utilized_size == dest_chunk.utilized_size
|
||||
self.data.copy_(dest_chunk.data)
|
||||
self._payload.copy_(dest_chunk._payload)
|
||||
self._update_tensors_ptr()
|
||||
|
||||
@property
|
||||
@@ -254,7 +295,7 @@ class Chunk:
|
||||
"""
|
||||
Get the device type of the chunk.
|
||||
"""
|
||||
return self.data.device.type
|
||||
return self._payload.device.type
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(id(self))
|
||||
@@ -265,6 +306,12 @@ class Chunk:
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
@property
|
||||
def _payload(self) -> torch.Tensor:
|
||||
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
||||
return self.data
|
||||
return self._cpu_data
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
"""
|
||||
@@ -285,6 +332,7 @@ class ChunkManager:
|
||||
self.enable_distributed_storage = enable_distributed_storage
|
||||
self.device = init_device or get_current_device()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
|
||||
self.groups_force_data_on_cuda: Dict[str, bool] = {}
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.lazy_release_tensors: List[torch.Tensor] = []
|
||||
@@ -292,6 +340,17 @@ class ChunkManager:
|
||||
self.rank_load: Dict[str, torch.Tensor] = {}
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
|
||||
"""Create a chunk group.
|
||||
|
||||
Args:
|
||||
group_name (str): group name
|
||||
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False.
|
||||
"""
|
||||
assert group_name not in self.chunk_groups
|
||||
self.chunk_groups[group_name] = deque()
|
||||
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
|
||||
"""
|
||||
Append a tensor to a chunk.
|
||||
@@ -304,19 +363,20 @@ class ChunkManager:
|
||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||
raise ValueError(
|
||||
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
|
||||
if group_name not in self.chunk_groups:
|
||||
self.chunk_groups[group_name] = deque()
|
||||
|
||||
try:
|
||||
# append the tensor to the last chunk
|
||||
self.chunk_groups[group_name][-1].append(tensor)
|
||||
except (IndexError, ChunkFullError):
|
||||
# the except statement will be triggered when there is no chunk or
|
||||
# the except statement will be triggered when there is no chunk or
|
||||
# the last chunk in the chunk group is full
|
||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||
chunk_size = self.chunk_size or tensor.numel()
|
||||
src_rank = self._get_next_src_rank(group_name)
|
||||
chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
|
||||
chunk = Chunk(chunk_size,
|
||||
src_rank,
|
||||
tensor.dtype,
|
||||
self.device,
|
||||
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
|
||||
|
||||
if self.enable_distributed_storage and self.chunk_size is None:
|
||||
self.rank_load[group_name][src_rank] += chunk_size
|
||||
@@ -387,7 +447,7 @@ class ChunkManager:
|
||||
# update the memory consumption after releasing
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to the target device.
|
||||
|
||||
@@ -399,7 +459,7 @@ class ChunkManager:
|
||||
return
|
||||
if chunk.can_move_device and not chunk.is_empty:
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(device)
|
||||
chunk.move_device(device, update_ptr=update_ptr)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
|
||||
Reference in New Issue
Block a user