mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[feature] new zero implementation (#1623)
This commit is contained in:
@@ -1,10 +1,6 @@
|
||||
from .chunk import TensorInfo, Chunk, TensorState
|
||||
from .chunk_mgr import ChunkManager
|
||||
from .chunk import TensorInfo, TensorState
|
||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import TensorPlacementPolicyFactory
|
||||
from .gemini_mgr import GeminiManager
|
||||
|
||||
__all__ = [
|
||||
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk',
|
||||
'TensorState'
|
||||
]
|
||||
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState']
|
||||
|
@@ -1,316 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
COMPUTE = 1
|
||||
HOLD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
READY_FOR_REDUCE = 4
|
||||
|
||||
|
||||
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
||||
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
||||
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
||||
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
||||
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
||||
TensorState.HOLD))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
state: TensorState
|
||||
offset: int
|
||||
end: int
|
||||
|
||||
|
||||
class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
class Chunk:
|
||||
"""
|
||||
A chunk is a contiguous memory space which contains multiple tensors.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in a chunk
|
||||
src_rank (int): the process which owns the chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
||||
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
src_rank: int,
|
||||
process_group: ColoProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
force_data_on_cuda: bool = False) -> None:
|
||||
self.size = chunk_size
|
||||
self.utilized_size = 0
|
||||
self.src_rank = src_rank
|
||||
self.process_group = process_group
|
||||
self.is_src_rank = process_group.dp_local_rank() == src_rank
|
||||
self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
|
||||
self.dtype = dtype
|
||||
device = init_device or get_current_device()
|
||||
if force_data_on_cuda:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
||||
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
||||
if device.type == 'cuda':
|
||||
free_storage(self._cpu_data)
|
||||
else:
|
||||
free_storage(self.data)
|
||||
else:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||
self._cpu_data = None
|
||||
|
||||
# we only keep the chunk in full in the process by which the tensor is owned
|
||||
if not self.is_src_rank:
|
||||
free_storage(self._payload)
|
||||
|
||||
# each tensor is associated with a TensorInfo to track meta info
|
||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||
self.mem = self.size * self.data.element_size()
|
||||
|
||||
def append(self, tensor: torch.Tensor) -> None:
|
||||
"""
|
||||
Add a tensor to the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a tensor to be added to the chunk
|
||||
"""
|
||||
assert tensor.dtype == self.dtype
|
||||
new_utilized_size = self.utilized_size + tensor.numel()
|
||||
|
||||
# raise exception when the chunk size is exceeded
|
||||
if new_utilized_size > self.size:
|
||||
raise ChunkFullError
|
||||
|
||||
# set tensor state
|
||||
tensor_state = TensorState.FREE
|
||||
|
||||
# if the process owns the rank, then copy the tensor to its chunk buffer
|
||||
# otherwise set its storage size to 0 to reduce memory consumption
|
||||
if self.is_src_rank:
|
||||
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
|
||||
tensor_state = TensorState.HOLD
|
||||
assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
||||
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||
else:
|
||||
tensor.storage().resize_(0)
|
||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||
self.utilized_size = new_utilized_size
|
||||
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Release the memory space on processes which do not own the chunk.
|
||||
"""
|
||||
if not self.is_src_rank:
|
||||
free_storage(self._payload)
|
||||
self._update_tensors_state(TensorState.FREE)
|
||||
|
||||
def _update_tensors_ptr(self) -> None:
|
||||
assert type(self._payload) == torch.Tensor
|
||||
for tensor, tensor_info in self.tensors_info.items():
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if prev_state is None or tensor_info.state == prev_state:
|
||||
tensor_info.state = next_state
|
||||
|
||||
def access(self) -> None:
|
||||
"""
|
||||
Broadcast the chunk to synchronize the tensors across data parallel processes.
|
||||
"""
|
||||
# recover the chunk on non-owner processes
|
||||
# and broadcast the chunk from the source to all processes
|
||||
if not self.is_src_rank:
|
||||
alloc_storage(self._payload)
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
|
||||
|
||||
# update tensor meta info
|
||||
self._update_tensors_ptr()
|
||||
if not self.is_src_rank:
|
||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||
|
||||
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to a target device.
|
||||
|
||||
Args:
|
||||
device (torch.device): the target device for data movement.
|
||||
"""
|
||||
if self._payload.device == device:
|
||||
return
|
||||
if self._cpu_data is None:
|
||||
self.data.data = self.data.to(device)
|
||||
else:
|
||||
if device.type == 'cuda':
|
||||
# cpu -> cuda
|
||||
src = self._cpu_data
|
||||
dest = self.data
|
||||
else:
|
||||
# cuda -> cpu
|
||||
src = self.data
|
||||
dest = self._cpu_data
|
||||
alloc_storage(dest)
|
||||
dest.copy_(src)
|
||||
free_storage(src)
|
||||
|
||||
if update_ptr:
|
||||
self._update_tensors_ptr()
|
||||
|
||||
def reduce(self, is_all_reduce: bool = False) -> None:
|
||||
"""
|
||||
Reduce or all-reduce the chunk.
|
||||
|
||||
Args:
|
||||
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
||||
"""
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
if is_all_reduce:
|
||||
dist.all_reduce(self.data, group=self.process_group.dp_process_group())
|
||||
else:
|
||||
dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
|
||||
self._update_tensors_ptr()
|
||||
self._update_tensors_state(TensorState.HOLD)
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
"""
|
||||
Make a transition of the tensor into the next state.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a torch Tensor object.
|
||||
tensor_state (TensorState): the target state for transition.
|
||||
"""
|
||||
|
||||
# As the gradient hook can be triggered either before or after post-backward
|
||||
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
||||
# or compute -> ready_for_reduce -> hold_after_bwd
|
||||
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
|
||||
# this function only apply valid state transformation
|
||||
# invalid calls will be ignored and nothing changes
|
||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||
# print(
|
||||
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||
# )
|
||||
return
|
||||
self.tensors_info[tensor].state = tensor_state
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrive meta information
|
||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be released.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state != TensorState.HOLD:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_move_device(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be moved across devices.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_reduce(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be reduced.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state != TensorState.READY_FOR_REDUCE:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk is empty.
|
||||
"""
|
||||
return is_storage_empty(self._payload)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""
|
||||
Check if the chunk has inf or nan values.
|
||||
"""
|
||||
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
||||
|
||||
def copy_(self, dest_chunk: 'Chunk'):
|
||||
"""
|
||||
Copy the data of this chunk to a destination chunk.
|
||||
"""
|
||||
assert not self.is_empty
|
||||
assert not dest_chunk.is_empty
|
||||
assert self.size == dest_chunk.size
|
||||
assert self.utilized_size == dest_chunk.utilized_size
|
||||
self._payload.copy_(dest_chunk._payload)
|
||||
self._update_tensors_ptr()
|
||||
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
"""
|
||||
Get the device type of the chunk.
|
||||
"""
|
||||
return self._payload.device.type
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(id(self))
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self is __o
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
@property
|
||||
def _payload(self) -> torch.Tensor:
|
||||
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
||||
return self.data
|
||||
return self._cpu_data
|
3
colossalai/gemini/chunk/__init__.py
Normal file
3
colossalai/gemini/chunk/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
|
||||
from .manager import ChunkManager
|
||||
from .search_utils import clasify_params, search_chunk_configuration
|
@@ -1,14 +1,55 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
|
||||
free_storage, alloc_storage
|
||||
|
||||
|
||||
class ChunkV2:
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
COMPUTE = 1
|
||||
HOLD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
READY_FOR_REDUCE = 4
|
||||
|
||||
|
||||
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
||||
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
||||
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
||||
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
||||
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
||||
TensorState.HOLD))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
state: TensorState
|
||||
offset: int
|
||||
end: int
|
||||
|
||||
|
||||
class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
class Chunk:
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
@@ -19,18 +60,18 @@ class ChunkV2:
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
Chunk: A container owning a piece of contiguous memory space for tensors
|
||||
AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk.
|
||||
This kind of chunk is exclusively used for DDP and ZeRO DDP.
|
||||
Here we use all-gather operation to gather the whole chunk.
|
||||
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
|
||||
It is designed to make the full use of communication and PCIE bandwidth.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in a chunk
|
||||
chunk_size (int): the number of elements in the chunk
|
||||
process_group (ColoProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized
|
||||
The default value is None, which is the current GPU
|
||||
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
|
||||
pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory
|
||||
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
|
||||
"""
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
@@ -42,7 +83,8 @@ class ChunkV2:
|
||||
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||
|
||||
# the chunk size should be able to be divied by the size of GPU
|
||||
assert chunk_size % self.pg_size == 0
|
||||
if not keep_gathered:
|
||||
assert chunk_size % self.pg_size == 0
|
||||
self.shard_size = chunk_size // self.pg_size
|
||||
self.shard_begin = self.shard_size * self.pg_rank
|
||||
self.shard_end = self.shard_begin + self.shard_size
|
||||
@@ -80,18 +122,15 @@ class ChunkV2:
|
||||
|
||||
# we introduce the paired chunk here
|
||||
# it refers to another chunk having the same parameters
|
||||
# but with different dtype(such as fp16_chunk.mapping_chunk -> fp32_chunk
|
||||
# but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk
|
||||
self.paired_chunk = None
|
||||
# if the the gradient of this chunk is reduced, the flag is True
|
||||
# so the flag is False for unused parameters
|
||||
self.grad_reduced_flag = False
|
||||
# if this chunk is synchronized with the optimizer, the flag is True
|
||||
self.optim_sync_flag = True
|
||||
# if the cpu_shard has been visited during the training step, the flag is True
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
@property
|
||||
def memory_usage(self):
|
||||
def memory_usage(self) -> Dict[str, int]:
|
||||
cuda_memory = 0
|
||||
cpu_memory = 0
|
||||
|
||||
@@ -112,7 +151,7 @@ class ChunkV2:
|
||||
return dict(cuda=cuda_memory, cpu=cpu_memory)
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
def device_type(self) -> str:
|
||||
if self.chunk_temp is not None:
|
||||
return self.chunk_temp.device.type
|
||||
else:
|
||||
@@ -123,6 +162,56 @@ class ChunkV2:
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
@property
|
||||
def payload(self) -> torch.Tensor:
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if self.is_gathered:
|
||||
return self.chunk_total
|
||||
elif self.cuda_shard is not None:
|
||||
return self.cuda_shard
|
||||
else:
|
||||
return self.cpu_shard
|
||||
|
||||
@property
|
||||
def payload_mem(self) -> int:
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if self.is_gathered:
|
||||
return self.chunk_mem
|
||||
else:
|
||||
return self.shard_mem
|
||||
|
||||
@property
|
||||
def can_move(self) -> bool:
|
||||
return not self.is_gathered
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
if self.keep_gathered:
|
||||
return False
|
||||
else:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] + \
|
||||
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""Check if the chunk has inf or nan values in CUDA.
|
||||
"""
|
||||
if self.is_gathered:
|
||||
valid_tensor = self.chunk_total[:self.utilized_size]
|
||||
else:
|
||||
assert self.cuda_shard is not None # only check in CUDA
|
||||
valid_tensor = self.cuda_shard[:self.valid_end]
|
||||
|
||||
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor):
|
||||
"""Add a tensor to the chunk.
|
||||
|
||||
@@ -150,7 +239,10 @@ class ChunkV2:
|
||||
self.utilized_size = new_utilized_size
|
||||
|
||||
def close_chunk(self, shard_dev: Optional[torch.device] = None):
|
||||
"""Close the chunk. Any tensor can't be appended to a closed chunk.
|
||||
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
|
||||
|
||||
Args:
|
||||
shard_dev: the device where the shard locates
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is not None
|
||||
@@ -163,6 +255,7 @@ class ChunkV2:
|
||||
|
||||
if self.chunk_temp.device.type == 'cpu':
|
||||
self.chunk_total = self.chunk_temp.to(get_current_device())
|
||||
self.__update_tensors_ptr()
|
||||
else:
|
||||
self.chunk_total = self.chunk_temp
|
||||
self.chunk_temp = None
|
||||
@@ -186,6 +279,12 @@ class ChunkV2:
|
||||
self.cuda_shard = None
|
||||
|
||||
def shard_move(self, device: torch.device, force_copy: bool = False):
|
||||
"""Move the shard tensor in the chunk.
|
||||
|
||||
Args:
|
||||
device: the device to which the shard will move
|
||||
force_copy: if True, copy function is called mandatorily
|
||||
"""
|
||||
# sanity check
|
||||
assert not self.is_gathered
|
||||
# when the current chunk is not synchronized with the optimizer
|
||||
@@ -223,8 +322,7 @@ class ChunkV2:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
"""Make the chunk usable for the parameters inside it.
|
||||
It is an operation done in CUDA.
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
@@ -234,8 +332,7 @@ class ChunkV2:
|
||||
self.__update_tensors_ptr()
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk.
|
||||
It is an operation done in CUDA.
|
||||
"""Release the usable chunk. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
@@ -244,8 +341,7 @@ class ChunkV2:
|
||||
self.__scatter()
|
||||
|
||||
def reduce(self):
|
||||
"""Reduce scatter all the gradients.
|
||||
It is an operation done in CUDA.
|
||||
"""Reduce scatter all the gradients. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
@@ -267,7 +363,6 @@ class ChunkV2:
|
||||
free_storage(self.chunk_total)
|
||||
self.is_gathered = False
|
||||
self.__update_tensors_state(TensorState.HOLD)
|
||||
self.grad_reduced_flag = True
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
"""
|
||||
@@ -285,9 +380,6 @@ class ChunkV2:
|
||||
# this function only apply valid state transformation
|
||||
# invalid calls will be ignored and nothing changes
|
||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||
# print(
|
||||
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||
# )
|
||||
return
|
||||
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
|
||||
|
||||
@@ -306,46 +398,56 @@ class ChunkV2:
|
||||
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_move(self) -> bool:
|
||||
return not self.is_gathered
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload.
|
||||
"""
|
||||
if self.keep_gathered:
|
||||
return False
|
||||
return self.utilized_size
|
||||
else:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] + \
|
||||
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
return self.valid_end
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
def init_pair(self, friend_chunk: 'Chunk') -> None:
|
||||
"""Initialize the paired chunk.
|
||||
"""
|
||||
Check if the chunk has inf or nan values in CUDA.
|
||||
"""
|
||||
if self.is_gathered:
|
||||
valid_tensor = self.chunk_total[:self.utilized_size]
|
||||
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
|
||||
self.paired_chunk = friend_chunk
|
||||
friend_chunk.paired_chunk = self
|
||||
else:
|
||||
assert self.cuda_shard is not None # only check in CUDA
|
||||
valid_tensor = self.cuda_shard[:self.valid_end]
|
||||
assert self.paired_chunk is friend_chunk
|
||||
assert friend_chunk.paired_chunk is self
|
||||
|
||||
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
def optim_update(self) -> None:
|
||||
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.paired_chunk is not None
|
||||
|
||||
friend_chunk = self.paired_chunk
|
||||
if self.is_gathered is True:
|
||||
assert friend_chunk.is_gathered is True
|
||||
self.chunk_total.copy_(friend_chunk.chunk_total)
|
||||
self.optim_sync_flag = True
|
||||
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
|
||||
self.cuda_shard.copy_(friend_chunk.cuda_shard)
|
||||
self.optim_sync_flag = True
|
||||
self.cpu_vis_flag = False
|
||||
else:
|
||||
assert friend_chunk.device_type == 'cpu'
|
||||
assert self.device_type == 'cpu'
|
||||
self.optim_sync_flag = False
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def __gather(self):
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
if self.pg_size == 1:
|
||||
self.chunk_total = self.cuda_shard
|
||||
else:
|
||||
alloc_storage(self.chunk_total)
|
||||
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
alloc_storage(self.chunk_total)
|
||||
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
|
||||
@@ -404,9 +506,9 @@ class ChunkV2:
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self is __o
|
||||
|
||||
def __repr__(self, detailed: bool = False):
|
||||
def __repr__(self, detailed: bool = True):
|
||||
output = [
|
||||
"AgChunk Information:\n",
|
||||
"Chunk Information:\n",
|
||||
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
|
||||
self.pg_size),
|
||||
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
|
||||
@@ -442,6 +544,3 @@ class ChunkV2:
|
||||
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
|
||||
|
||||
return ''.join(output)
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
@@ -4,23 +4,19 @@ from collections import deque
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.gemini.chunk import ChunkFullError, TensorState
|
||||
from colossalai.gemini.update import ChunkV2 as Chunk
|
||||
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
|
||||
|
||||
|
||||
class ChunkManagerV2:
|
||||
class ChunkManager:
|
||||
"""
|
||||
A manager class to manipulate the tensors in chunks.
|
||||
|
||||
Args:
|
||||
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
|
||||
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
||||
pin_memory (bool): if ture, all chunks have a piece of pinned memory in CPU.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_configuration: Dict[int, Dict],
|
||||
init_device: Optional[torch.device] = None,
|
||||
pin_memory: bool = False) -> None:
|
||||
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
|
||||
|
||||
self.device = init_device or get_current_device()
|
||||
self.size_config: Dict[int, int] = dict()
|
||||
@@ -28,7 +24,6 @@ class ChunkManagerV2:
|
||||
for k, v in self.kwargs_config.items():
|
||||
self.size_config[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
v['pin_memory'] = pin_memory
|
||||
|
||||
self.chunk_groups: Dict[str, Deque] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
@@ -36,8 +31,14 @@ class ChunkManagerV2:
|
||||
self.lazy_release_tensors: List[torch.Tensor] = list()
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
|
||||
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
|
||||
"""Append a tensor to a chunk.
|
||||
|
||||
Args:
|
||||
tensor: the tensor appended to the chunk
|
||||
group_type: the data type of the group
|
||||
config_key: the key of the group's name, usually the size of the dp world
|
||||
pin_memory: whether the chunk is pinned in the cpu memory
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
|
||||
@@ -66,7 +67,8 @@ class ChunkManagerV2:
|
||||
chunk_size=chunk_size,
|
||||
process_group=tensor.process_group,
|
||||
dtype=tensor.dtype,
|
||||
**chunk_kwargs
|
||||
pin_memory=pin_memory,
|
||||
**chunk_kwargs,
|
||||
)
|
||||
|
||||
chunk_group.append(chunk)
|
||||
@@ -87,6 +89,8 @@ class ChunkManagerV2:
|
||||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
if chunk.device_type == 'cpu':
|
||||
chunk.shard_move(get_current_device())
|
||||
chunk.access_chunk()
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
self.accessed_chunks.add(chunk)
|
||||
@@ -102,13 +106,13 @@ class ChunkManagerV2:
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
self.accessed_chunks.remove(chunk)
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
|
||||
"""Move the shard of the chunk to the target device.
|
||||
"""
|
||||
if not chunk.can_move or chunk.device_type == device.type:
|
||||
return
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.shard_move(device)
|
||||
chunk.shard_move(device, force_copy)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
@@ -123,7 +127,7 @@ class ChunkManagerV2:
|
||||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.release_chunk()
|
||||
chunk.reduce()
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
return True
|
||||
|
||||
@@ -165,14 +169,14 @@ class ChunkManagerV2:
|
||||
self.release_chunk(chunk)
|
||||
self.lazy_release_tensors.clear()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ['Chunk Manager Information:\n',
|
||||
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n']
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg.append(f'Group {group_name}:\n')
|
||||
for i, chunk in enumerate(group):
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]:
|
||||
chunk_list = []
|
||||
for group_name in self.chunk_groups:
|
||||
if group_type in group_name:
|
||||
for chunk in self.chunk_groups[group_name]:
|
||||
if chunk.device_type == 'cuda' and chunk.can_move:
|
||||
chunk_list.append(chunk)
|
||||
return chunk_list
|
||||
|
||||
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
|
||||
"""
|
||||
@@ -200,6 +204,17 @@ class ChunkManagerV2:
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = [
|
||||
'Chunk Manager Information:\n',
|
||||
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
|
||||
]
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg.append(f'Group {group_name}:\n')
|
||||
for i, chunk in enumerate(group):
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
@@ -208,8 +223,9 @@ class ChunkManagerV2:
|
||||
return self.chunk_groups[group_name]
|
||||
|
||||
def __close_one_chunk(self, chunk: Chunk):
|
||||
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.close_chunk(self.device)
|
||||
chunk.close_chunk(device)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def __sub_memroy_usage(self, usage: Dict[str, int]):
|
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
@@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter
|
||||
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
||||
"""Filter those parameters whose size is too large from others.
|
||||
"""
|
||||
params_size = [p.numel() for p in model.parameters()]
|
||||
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params_size_arr = np.array(params_size)
|
||||
|
||||
std = np.std(params_size_arr)
|
||||
@@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
||||
params_dict: Dict[int, List[ColoParameter]] = dict()
|
||||
for param in model.parameters():
|
||||
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||
if getattr(param, '_ddp_to_ignore', False):
|
||||
continue
|
||||
|
||||
param_key = param.process_group.dp_world_size()
|
||||
|
||||
if param_key not in params_dict:
|
||||
@@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
||||
|
||||
def search_chunk_configuration(
|
||||
model: nn.Module,
|
||||
search_range_mb: int,
|
||||
search_range_mb: float,
|
||||
search_interval_byte: int, # hidden size is the best value for the interval
|
||||
min_chunk_size_mb: int = 32,
|
||||
filter_exlarge_params: bool = True):
|
||||
search_range_byte = search_range_mb * 1024**2
|
||||
min_chunk_size_byte = min_chunk_size_mb * 1024**2
|
||||
assert search_range_byte % search_interval_byte == 0
|
||||
min_chunk_size_mb: float = 32,
|
||||
filter_exlarge_params: bool = True) -> Dict:
|
||||
search_range_byte = round(search_range_mb * 1024**2)
|
||||
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
|
||||
assert search_range_byte >= 0
|
||||
|
||||
params_dict = clasify_params(model)
|
||||
config_dict: Dict[int, Dict] = dict()
|
||||
@@ -75,11 +79,12 @@ def search_chunk_configuration(
|
||||
max_size = min_chunk_size_byte
|
||||
for key in size_dict:
|
||||
max_size = max(max_size, max(size_dict[key]))
|
||||
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
|
||||
|
||||
min_chunk_waste = float('+inf')
|
||||
best_chunk_size = max_size
|
||||
best_chunk_size = start_size
|
||||
|
||||
for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
|
||||
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
|
||||
temp_waste = 0
|
||||
for key in size_dict:
|
||||
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
|
@@ -1,344 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||
from collections import deque
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
"""
|
||||
A manager class to manipulate the tensors in chunks.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the size of a chunk.
|
||||
process_group (ColoProcessGroup): process group of the chunk.
|
||||
enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
|
||||
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: Optional[int],
|
||||
process_group: ColoProcessGroup,
|
||||
enable_distributed_storage: bool = False,
|
||||
init_device: Optional[torch.device] = None) -> None:
|
||||
assert chunk_size is None or chunk_size > 0
|
||||
assert isinstance(process_group, ColoProcessGroup)
|
||||
self.chunk_size = chunk_size
|
||||
self.process_group = process_group
|
||||
self.enable_distributed_storage = enable_distributed_storage
|
||||
self.device = init_device or get_current_device()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
|
||||
self.groups_force_data_on_cuda: Dict[str, bool] = {}
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.lazy_release_tensors: List[torch.Tensor] = []
|
||||
if enable_distributed_storage and chunk_size is None:
|
||||
self.rank_load: Dict[str, torch.Tensor] = {}
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
|
||||
"""Create a chunk group.
|
||||
|
||||
Args:
|
||||
group_name (str): group name
|
||||
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False.
|
||||
"""
|
||||
assert group_name not in self.chunk_groups
|
||||
self.chunk_groups[group_name] = deque()
|
||||
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
|
||||
"""
|
||||
Append a tensor to a chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a tensor to append to the chunk.
|
||||
group_name (str): the name of the chunk group.
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
if isinstance(tensor, ColoTensor):
|
||||
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
|
||||
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
|
||||
try:
|
||||
# append the tensor to the last chunk
|
||||
self.chunk_groups[group_name][-1].append(tensor)
|
||||
except (IndexError, ChunkFullError):
|
||||
# the except statement will be triggered when there is no chunk or
|
||||
# the last chunk in the chunk group is full
|
||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
else:
|
||||
chunk_size = self.chunk_size or tensor.numel()
|
||||
src_rank = self._get_next_src_rank(group_name)
|
||||
chunk = Chunk(chunk_size,
|
||||
src_rank,
|
||||
self.process_group,
|
||||
tensor.dtype,
|
||||
self.device,
|
||||
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
|
||||
|
||||
if self.enable_distributed_storage and self.chunk_size is None:
|
||||
self.rank_load[group_name][src_rank] += chunk_size
|
||||
|
||||
self.chunk_groups[group_name].append(chunk)
|
||||
chunk.append(tensor)
|
||||
if not chunk.is_empty:
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
|
||||
if not self.enable_distributed_storage:
|
||||
# as distributed storage is not enabled, there is no need to broadcast
|
||||
# chunks, thus we set these chunks as accessed
|
||||
self.accessed_chunks.add(self.chunk_groups[group_name][-1])
|
||||
|
||||
def _get_next_src_rank(self, group_name: str) -> int:
|
||||
if not self.enable_distributed_storage:
|
||||
# the chunk is owned by the current rank if no distributed storage is enabled
|
||||
return self.process_group.dp_local_rank()
|
||||
if self.chunk_size is None:
|
||||
if group_name not in self.rank_load:
|
||||
self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
|
||||
|
||||
# the process owning the tensor will be the process with the smallest number of elements
|
||||
src_rank = torch.argmin(self.rank_load[group_name]).item()
|
||||
else:
|
||||
# chunk is owned by processes in a round-robin fashion
|
||||
chunk_idx = len(self.chunk_groups[group_name])
|
||||
src_rank = chunk_idx % self.process_group.dp_world_size()
|
||||
return src_rank
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
"""
|
||||
Synchronize the chunks via broadcast.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to synchronize.
|
||||
"""
|
||||
if chunk in self.accessed_chunks:
|
||||
if chunk.device_type != 'cuda':
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(get_current_device())
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
return
|
||||
if not chunk.is_empty:
|
||||
# as tensor is moved to the target device
|
||||
# the memory consumption of the original device is reduced
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.access()
|
||||
self.accessed_chunks.add(chunk)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""
|
||||
Release the memory space of a chunk.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to release memory space
|
||||
"""
|
||||
|
||||
if not self.enable_distributed_storage:
|
||||
return
|
||||
if chunk not in self.accessed_chunks:
|
||||
return
|
||||
if chunk.can_release:
|
||||
chunk.release()
|
||||
self.accessed_chunks.remove(chunk)
|
||||
if chunk.is_empty:
|
||||
# update the memory consumption after releasing
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to the target device.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to move to target device
|
||||
device (torch.device): target device
|
||||
"""
|
||||
if chunk.device_type == device.type:
|
||||
return
|
||||
if chunk.can_move_device and not chunk.is_empty:
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(device, update_ptr=update_ptr)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
"""
|
||||
Transit tensor state according to pre-defined state machine.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor for state transititon
|
||||
state (TensorState): next tensor state for transtition
|
||||
"""
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.tensor_trans_state(tensor, state)
|
||||
|
||||
def reduce_chunk(self, chunk: Chunk) -> bool:
|
||||
"""
|
||||
Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used.
|
||||
Otherwise, this method uses reduce.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk for reduction.
|
||||
"""
|
||||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
return True
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data to the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrive meta information
|
||||
data (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.copy_tensor_to_chunk_slice(tensor, data)
|
||||
|
||||
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
|
||||
"""
|
||||
Return the chunk owning the tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a torch tensor object
|
||||
"""
|
||||
return self.tensor_chunk_map[tensor]
|
||||
|
||||
def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
|
||||
"""
|
||||
Add tensors to the buffer for lazy release.
|
||||
|
||||
Args:
|
||||
tensors (List[torch.Tensor]): the tensors to be released lazily
|
||||
"""
|
||||
self.lazy_release_tensors.extend(tensors)
|
||||
|
||||
def exec_lazy_release(self) -> None:
|
||||
"""
|
||||
Execute release for tensors added to the lazy release buffer.
|
||||
"""
|
||||
|
||||
for chunk in self.get_chunks(self.lazy_release_tensors):
|
||||
self.release_chunk(chunk)
|
||||
self.lazy_release_tensors.clear()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = f'Rank {self.process_group.dp_local_rank()}:\n'
|
||||
msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg += f'Group {group_name}:\n'
|
||||
for i, chunk in enumerate(group):
|
||||
msg += f'[{i}] {chunk}\n'
|
||||
return msg
|
||||
|
||||
@staticmethod
|
||||
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
|
||||
"""
|
||||
Calculate the utilization rate of a chunk.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the size of a chunk
|
||||
params_numel (List[int]): the list of integers representing the number of elements of parameters
|
||||
"""
|
||||
assert len(params_numel) > 0
|
||||
total_size = 0
|
||||
total_utilized_size = 0
|
||||
cur_chunk_utilized_size = 0
|
||||
for size in params_numel:
|
||||
assert chunk_size >= size
|
||||
total_utilized_size += size
|
||||
if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
|
||||
total_size += chunk_size
|
||||
cur_chunk_utilized_size = 0
|
||||
cur_chunk_utilized_size += size
|
||||
return total_utilized_size / total_size
|
||||
|
||||
@staticmethod
|
||||
def search_chunk_size(module: torch.nn.Module,
|
||||
search_range: int,
|
||||
n_grids: int,
|
||||
min_chunk_size: Optional[int] = None,
|
||||
filter_exlarge_params: bool = True) -> int:
|
||||
"""
|
||||
Search for the chunk size for optimal chunk utilization.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): a torch module object
|
||||
search_range (int): the range of chunk size to search. The actual search range will be from
|
||||
max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
|
||||
n_grids (int): the number of intervals in the search range
|
||||
min_chunk_size (int): optional, the minimum size for a chunk. The default is None.
|
||||
|
||||
"""
|
||||
assert search_range % n_grids == 0
|
||||
# TODO(ver217): sort params and filter unused ones
|
||||
params_numel = [p.numel() for p in module.parameters()]
|
||||
if filter_exlarge_params:
|
||||
params_numel = _filter_exlarge_params(params_numel)
|
||||
max_param_numel = max(params_numel)
|
||||
if min_chunk_size is not None:
|
||||
assert min_chunk_size >= max_param_numel
|
||||
else:
|
||||
min_chunk_size = max_param_numel
|
||||
step_size = search_range // n_grids
|
||||
max_chunk_util = -1
|
||||
best_chunk_size = -1
|
||||
for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
|
||||
chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
|
||||
if chunk_util > max_chunk_util:
|
||||
max_chunk_util = chunk_util
|
||||
best_chunk_size = chunk_size
|
||||
return best_chunk_size
|
||||
|
||||
def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
|
||||
"""
|
||||
Copy chunk data from one group to another group.
|
||||
|
||||
Args:
|
||||
dest_group_name (str): the destination group which receives the copied data
|
||||
src_group_name (str): the source group which provides the data to copy
|
||||
"""
|
||||
for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
|
||||
if not dest_chunk.is_empty:
|
||||
dest_chunk.copy_(src_chunk)
|
||||
|
||||
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
|
||||
"""
|
||||
Get all chunks owning the input tensors.
|
||||
|
||||
Args:
|
||||
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
|
||||
"""
|
||||
chunks = []
|
||||
for tensor in tensors:
|
||||
chunk = self.get_chunk(tensor)
|
||||
if chunk not in chunks:
|
||||
chunks.append(chunk)
|
||||
return tuple(chunks)
|
||||
|
||||
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
|
||||
"""Add extern static tensor to chunk manager.
|
||||
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
|
||||
They are "static", which means their shape, dtype, device never change.
|
||||
Thus, their memory usage never changes.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
|
||||
params_numel_arr = np.array(params_numel)
|
||||
std = np.std(params_numel_arr)
|
||||
mean = np.mean(params_numel_arr)
|
||||
upper_limit = mean + 3 * std
|
||||
return list(filter(lambda x: x <= upper_limit, params_numel))
|
@@ -3,7 +3,7 @@ import functools
|
||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||
from typing import List, Optional, Tuple
|
||||
from time import time
|
||||
from colossalai.gemini import Chunk, ChunkManager
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
from .placement_policy import PlacementPolicyFactory
|
||||
|
||||
|
||||
@@ -56,37 +56,44 @@ class GeminiManager:
|
||||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None:
|
||||
""" Adjust the layout of statefuil tensor according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
start = time()
|
||||
self._record_chunks_order(chunks)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type)
|
||||
self._layout_time += time() - start
|
||||
vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
|
||||
|
||||
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
|
||||
cuda_demand=cuda_demand,
|
||||
warmup=self._warmup,
|
||||
compute_list=self._compute_list,
|
||||
compute_idx=self._compute_idx)
|
||||
|
||||
self._d2h_volume += vol
|
||||
self._evict_time += evict_time
|
||||
# move COMPUTE tensors to CUDA
|
||||
self._h2d_volume += cuda_demand
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str):
|
||||
start = time()
|
||||
cuda_demand = 0
|
||||
for chunk in chunks:
|
||||
if chunk.device_type == 'cpu' or chunk.is_empty:
|
||||
cuda_demand += chunk.mem
|
||||
if chunk.device_type == 'cuda':
|
||||
if chunk.is_gathered:
|
||||
pass
|
||||
else:
|
||||
cuda_demand += chunk.chunk_mem - chunk.shard_mem
|
||||
elif chunk.device_type == 'cpu':
|
||||
cuda_demand += chunk.chunk_mem
|
||||
else:
|
||||
raise RuntimeError
|
||||
self._comp_cuda_demand_time += time() - start
|
||||
can_evict_chunks = []
|
||||
for chunk in self._chunk_manager.chunk_groups[group_name]:
|
||||
if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
|
||||
can_evict_chunks.append(chunk)
|
||||
|
||||
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type)
|
||||
return cuda_demand, can_evict_chunks
|
||||
|
||||
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
|
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.gemini import ChunkManager
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
|
||||
import torch
|
||||
import time
|
||||
|
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||
from typing import Type
|
||||
import functools
|
||||
from colossalai.gemini import Chunk, ChunkManager
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
|
||||
|
||||
class PlacementPolicy(ABC):
|
||||
@@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
|
||||
self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
|
||||
|
||||
@abstractmethod
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy):
|
||||
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
volume = 0
|
||||
start = time()
|
||||
for chunk in can_evict_chunks:
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||
volume += chunk.mem
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
volume += chunk.shard_mem
|
||||
return volume, time() - start
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy):
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
@@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = True
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
_warmup_non_model_data_ratio: float = 0.8
|
||||
_steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
@@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: List[Tuple[Chunk, ...]] = [],
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> int:
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
"""
|
||||
Evict tensors from CUDA device.
|
||||
|
||||
Args:
|
||||
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
|
||||
can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted.
|
||||
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
|
||||
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
|
||||
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
|
||||
@@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
for chunk in to_free_chunks:
|
||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||
break
|
||||
freed_cuda_model_data += chunk.mem
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
freed_cuda_model_data += chunk.shard_mem
|
||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||
raise RuntimeError(
|
||||
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||
)
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
|
||||
return freed_cuda_model_data, time() - start
|
||||
|
||||
@staticmethod
|
||||
@@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, PlacementPolicy] = {
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'cpu': CPUPlacementPolicy,
|
||||
'cuda': CUDAPlacementPolicy,
|
||||
'auto': AutoPlacementPolicy
|
||||
|
@@ -1,131 +0,0 @@
|
||||
import queue
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Dict
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
|
||||
def evict_check(st: StatefulTensor) -> bool:
|
||||
if st.state is not TensorState.COMPUTE and st.device.type == 'cuda':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Here ST means Stateful Tensor
|
||||
class BaseSTContainer(ABC):
|
||||
"""A type of container that store all potential stateful tensors which can be evicted from
|
||||
CUDA. This kind of stateful tensor should satisfy two conditions. One is that it hasn't been
|
||||
evicted, meaning the type of its device is CUDA, the other is that it isn't pinned in CUDA
|
||||
memory, meaning its state isn't COMPUTE.
|
||||
|
||||
This container should get a stateful tensor when it become HOLD_LIKE from COMPUTE.
|
||||
And it pops stateful tensors in function, `evict_tensors`.
|
||||
|
||||
In order to acquire an optimal eviction policy, users may need to offer computation step
|
||||
index of each stateful tensor. So we can use a heap to maintain all potential evictable
|
||||
statefule tensors. When poping, we can get the stateful tensor that used furthest in
|
||||
current computation step.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
self.compute_step_dict = compute_step_dict
|
||||
self.total_step = total_step
|
||||
|
||||
@abstractmethod
|
||||
def empty(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
pass
|
||||
|
||||
|
||||
class QueueSTContainer(BaseSTContainer):
|
||||
"""Queue type stateful tensor container. This is used in 'cpu' tensor placement policy.
|
||||
It pops potential evictable stateful tensors in FIFO.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
super().__init__(compute_step_dict, total_step)
|
||||
self.container = None
|
||||
|
||||
def empty(self) -> bool:
|
||||
assert self.container is not None
|
||||
return self.container.empty()
|
||||
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
self.container = queue.SimpleQueue()
|
||||
for stateful_tensor in stateful_tensor_list:
|
||||
self.container.put(stateful_tensor)
|
||||
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
self.container.put(stateful_tensor)
|
||||
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
ret = None
|
||||
while not self.empty():
|
||||
out_tensor = self.container.get()
|
||||
if evict_check(out_tensor):
|
||||
ret = out_tensor
|
||||
break
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class HeapSTContainer(BaseSTContainer):
|
||||
"""Heap type stateful tensor container. This is used in 'auto' tensor placement policy.
|
||||
It pops potential evictable stateful tensors in the order of the distance between current
|
||||
step and next used step.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
super().__init__(compute_step_dict, total_step)
|
||||
self.container = None
|
||||
|
||||
def empty(self) -> bool:
|
||||
assert self.container is not None
|
||||
return self.container == []
|
||||
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
self.container = []
|
||||
for stateful_tensor in stateful_tensor_list:
|
||||
# we want to pop the tensor which has the greatest next_step
|
||||
# so the weight is next_step multiplied by -1
|
||||
weight = -self.__get_next_compute_step(stateful_tensor, -1)
|
||||
self.container.append((weight, stateful_tensor))
|
||||
heapq.heapify(self.container)
|
||||
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
# we want to pop the tensor which has the greatest next_step
|
||||
# so the weight is next_step multiplied by -1
|
||||
weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
|
||||
heapq.heappush(self.container, (weight, stateful_tensor))
|
||||
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
ret = None
|
||||
while not self.empty():
|
||||
_, out_tensor = heapq.heappop(self.container)
|
||||
if evict_check(out_tensor):
|
||||
ret = out_tensor
|
||||
break
|
||||
return ret
|
||||
|
||||
def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
|
||||
# compute the id of next step
|
||||
# if the tensor is not used in the furture
|
||||
# next_step is set to the maximum
|
||||
next_step = self.total_step
|
||||
step_list = self.compute_step_dict[stateful_tensor]
|
||||
for step in step_list:
|
||||
if step > cur_step:
|
||||
next_step = step
|
||||
break
|
||||
return next_step
|
@@ -1,3 +0,0 @@
|
||||
from .chunkv2 import ChunkV2
|
||||
from .chunk_mgrv2 import ChunkManagerV2
|
||||
from .search_utils import clasify_params, search_chunk_configuration
|
@@ -3,16 +3,18 @@ import itertools
|
||||
import torch.distributed as dist
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.gemini.chunk import TensorState, Chunk
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
from colossalai.logging import get_dist_logger
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from .reducer import Reducer
|
||||
|
||||
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
|
||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
@@ -208,28 +210,34 @@ class ZeroDDP(ColoDDP):
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False) -> None:
|
||||
super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.fp32_params: List[ColoParameter] = []
|
||||
self.fp32_params: List[ColoTensor] = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
||||
self.chunk_manager.create_group('fp32_param')
|
||||
|
||||
# TODO: get param order and filter unused params
|
||||
for p in module.parameters():
|
||||
assert isinstance(p, ColoParameter)
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
p.data = p.half()
|
||||
continue
|
||||
fp32_p = p.float().detach()
|
||||
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
fp32_data = p.float().data
|
||||
p.data = p.half()
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param')
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self._cast_buffers()
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
@@ -248,10 +256,7 @@ class ZeroDDP(ColoDDP):
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
continue
|
||||
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
|
||||
p.grad = None
|
||||
else:
|
||||
p.grad = p.data
|
||||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
@@ -276,21 +281,22 @@ class ZeroDDP(ColoDDP):
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
if self.dp_world_size > 1:
|
||||
grad = grad / self.dp_world_size
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if reduced and not chunk.is_empty:
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
@@ -311,14 +317,11 @@ class ZeroDDP(ColoDDP):
|
||||
['bias', 'weight']
|
||||
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
record_flag = (not only_rank_0) or is_rank_0
|
||||
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
@@ -326,7 +329,7 @@ class ZeroDDP(ColoDDP):
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
@@ -339,30 +342,30 @@ class ZeroDDP(ColoDDP):
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
# record the original device of the chunk
|
||||
org_chunk_dev_typ = chunk.device_type
|
||||
self.chunk_manager.access_chunk(chunk)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor in chunk.get_tensors():
|
||||
rec_p = torch.empty([0])
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
rec_p = tensor.cpu() # move the whole tensor to CPU mem
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = rec_p
|
||||
# release the actual memory of the chunk
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
rec_p = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = record_parameter
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
@@ -466,40 +469,61 @@ class ZeroDDP(ColoDDP):
|
||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(name, dest_tensor, copy_func):
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
def load(param_name, dest_tensor, copy_func):
|
||||
state_key = prefix + param_name
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(key, input_param.shape,
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
|
||||
ex.args))
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_p(fp32_p, data):
|
||||
if fp32_p.storage().size() > 0:
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
load(name, fp32_p, partial(load_fp32_p, fp32_p))
|
||||
self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
20
colossalai/nn/parallel/utils.py
Normal file
20
colossalai/nn/parallel/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
if chunk.is_gathered:
|
||||
return chunk.chunk_total
|
||||
|
||||
if chunk.cuda_shard is not None:
|
||||
shard_temp = chunk.cuda_shard
|
||||
else:
|
||||
shard_temp = chunk.cpu_shard.to(get_current_device())
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
|
||||
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
|
||||
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
|
||||
|
||||
return total_temp
|
@@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook):
|
||||
|
||||
@contextmanager
|
||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||
old_training_phase = self._training_phase
|
||||
try:
|
||||
old_training_phase = self._training_phase
|
||||
self._training_phase = training_phase
|
||||
yield
|
||||
finally:
|
||||
|
@@ -2,17 +2,14 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from enum import Enum
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn import Parameter
|
||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||
from typing import Dict
|
||||
from typing import Dict, Tuple, Set
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
|
||||
from collections import defaultdict, abc as container_abcs
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from torch._six import inf
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
@@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
Args:
|
||||
optim (Optimizer): An Optimizer instance.
|
||||
module (ZeroDDP): A ``ZeroDDP`` instance.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
||||
Defaults to 0.0.
|
||||
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
@@ -61,11 +58,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
assert isinstance(module, ZeroDDP)
|
||||
self.module = module
|
||||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager = self.gemini_manager.chunk_manager
|
||||
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
|
||||
for p, fp32_p in zip(module.parameters(), module.fp32_params):
|
||||
self.fp16_param_to_fp32_param[p] = fp32_p
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
if chunk_16 not in self.chunk16_set:
|
||||
self.chunk16_set.add(chunk_16)
|
||||
|
||||
self.__init__optimizer()
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
@@ -75,7 +81,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
@@ -90,16 +96,26 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
self._register_states = disposable(self._register_states_)
|
||||
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if not self.module.chunk_manager.get_chunk(p).is_empty:
|
||||
p.data = self.fp16_param_to_fp32_param[p]
|
||||
else:
|
||||
assert p.grad is None
|
||||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
fake_param.data = chunk16.payload[begin:end]
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = chunk32.payload[begin:end]
|
||||
|
||||
def _update_fp16_params(self):
|
||||
self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
none_tensor = torch.empty([0])
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
assert fake_param.grad is None
|
||||
fake_param.data = none_tensor
|
||||
|
||||
for chunk16 in self.chunk16_set:
|
||||
chunk16.optim_update()
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
@@ -128,6 +144,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
self._set_grad_ptr()
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
@@ -138,45 +155,14 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return
|
||||
self._update_params_ptr()
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
self._register_states()
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def compute_grad_norm(self, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return compute_grad_norm(self.module.parameters(), norm_type)
|
||||
|
||||
non_distributed_params = []
|
||||
distributed_params = []
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
non_distributed_params.append(p)
|
||||
else:
|
||||
distributed_params.append(p)
|
||||
non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
|
||||
distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
|
||||
device=get_current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(distributed_norm_tensor,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
|
||||
else:
|
||||
dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = non_distributed_norm + distributed_norm_tensor.item()
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
total_norm = self.compute_grad_norm(norm_type)
|
||||
_clip_grad_norm(self.module.parameters(), max_norm, total_norm)
|
||||
return total_norm
|
||||
raise NotImplementedError
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
@@ -197,24 +183,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
|
||||
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_params_used_cuda_margin_mem = 0
|
||||
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
|
||||
self.chunk_manager.chunk_groups['fp32_param']):
|
||||
if fp32_param_chunk.is_empty:
|
||||
continue
|
||||
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
|
||||
self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
|
||||
for p in fp16_param_chunk.get_tensors():
|
||||
state = self.optim.state[p]
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
if chunk32.device_type == 'cuda':
|
||||
continue
|
||||
|
||||
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(chunk32, get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(chunk16, get_current_device())
|
||||
self.module.set_chunk_grad_device(chunk16, get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
if chunk32.device_type == 'cuda':
|
||||
state = self.optim.state[fake_param]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(get_current_device())
|
||||
|
||||
self.module._setup_grads_ptr()
|
||||
|
||||
def _register_states_(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
@@ -223,110 +216,27 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
if isinstance(val, torch.Tensor):
|
||||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
||||
def state_dict(self, only_rank_0: bool = True):
|
||||
r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
|
||||
This saves memory usage.
|
||||
def __init__optimizer(self):
|
||||
|
||||
It contains two entries:
|
||||
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
||||
param_info = local_chunk.tensors_info[local_param]
|
||||
begin = max(0, param_info.offset - local_chunk.shard_begin)
|
||||
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
|
||||
return begin, end
|
||||
|
||||
* state - a dict holding current optimization state. Its content
|
||||
differs between optimizer classes.
|
||||
* param_groups - a list containing all parameter groups where each
|
||||
parameter group is a dict
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
|
||||
return
|
||||
optim_state_dict = super().state_dict()
|
||||
scaler_state_dict = self.grad_scaler.state_dict()
|
||||
optim_state_dict['scaler'] = scaler_state_dict
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return optim_state_dict
|
||||
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
|
||||
if not self.chunk_manager.process_group.has_cpu_groups:
|
||||
self.chunk_manager.process_group.set_cpu_groups()
|
||||
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
|
||||
if only_rank_0:
|
||||
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
|
||||
dist.gather_object(local_state,
|
||||
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
|
||||
dst=dst_rank,
|
||||
group=self.chunk_manager.process_group.cpu_dp_process_group())
|
||||
if not is_rank_0:
|
||||
return
|
||||
else:
|
||||
dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
|
||||
for state in output:
|
||||
optim_state_dict['state'].update(state)
|
||||
return optim_state_dict
|
||||
for group in self.optim.param_groups:
|
||||
fake_params_list = list()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
r"""Loads the optimizer state.
|
||||
for param in group['params']:
|
||||
chunk16 = self.chunk_manager.get_chunk(param)
|
||||
range_pair = get_range_pair(chunk16, param)
|
||||
if range_pair[0] >= range_pair[1]:
|
||||
continue
|
||||
|
||||
Args:
|
||||
state_dict (dict): optimizer state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
if 'scaler' not in state_dict:
|
||||
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
|
||||
else:
|
||||
self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))
|
||||
fake_param = torch.nn.Parameter(torch.empty([0]))
|
||||
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
||||
self.param_to_range[fake_param] = range_pair
|
||||
|
||||
# Validate the state_dict
|
||||
groups = self.param_groups
|
||||
saved_groups = deepcopy(state_dict['param_groups'])
|
||||
fake_params_list.append(fake_param)
|
||||
|
||||
if len(groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of "
|
||||
"parameter groups")
|
||||
param_lens = (len(g['params']) for g in groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Update the state
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in groups)))
|
||||
}
|
||||
|
||||
def cast(param, value):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
state = defaultdict(dict)
|
||||
for k, v in state_dict['state'].items():
|
||||
if k in id_map:
|
||||
param = self.fp16_param_to_fp32_param[id_map[k]]
|
||||
if param.storage().size() > 0:
|
||||
state[param] = cast(param, deepcopy(v))
|
||||
else:
|
||||
state[k] = deepcopy(v)
|
||||
|
||||
# Update parameter groups, setting their 'params' value
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
|
||||
self.__setstate__({'state': state, 'param_groups': param_groups})
|
||||
|
||||
|
||||
def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
|
||||
return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
|
||||
group['params'] = fake_params_list
|
||||
|
Reference in New Issue
Block a user