Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,4 +3,4 @@ from .manager import ChunkManager
 from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
 from .utils import init_chunk_manager
 
-__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
+__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"]
@@ -17,12 +17,17 @@ class TensorState(Enum):
     READY_FOR_REDUCE = 4
 
 
-STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
-                                                                                               TensorState.HOLD),
-               (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
-               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
-                                                                            TensorState.HOLD))
+STATE_TRANS = (
+    (TensorState.FREE, TensorState.HOLD),
+    (TensorState.FREE, TensorState.COMPUTE),
+    (TensorState.HOLD, TensorState.FREE),
+    (TensorState.HOLD, TensorState.COMPUTE),
+    (TensorState.COMPUTE, TensorState.HOLD),
+    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
+    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
+    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
+)
 
 
 @dataclass
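The hunk above only reflows `STATE_TRANS` with black; the set of legal (old, new) tensor-state transitions is unchanged. A minimal, self-contained sketch of how such a table gates a transition (the `check_transition` helper and the numeric values of the states other than `READY_FOR_REDUCE = 4` are illustrative assumptions, not part of this commit):

```python
from enum import Enum


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)


def check_transition(old: TensorState, new: TensorState) -> bool:
    # A transition is legal only if the (old, new) pair appears in the table.
    return (old, new) in STATE_TRANS


assert check_transition(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD)
assert not check_transition(TensorState.FREE, TensorState.READY_FOR_REDUCE)
```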
@@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None:
 class Chunk:
     _total_number = 0
 
-    def __init__(self,
-                 chunk_size: int,
-                 process_group: ProcessGroup,
-                 dtype: torch.dtype,
-                 init_device: Optional[torch.device] = None,
-                 cpu_shard_init: bool = False,
-                 keep_gathered: bool = False,
-                 pin_memory: bool = False) -> None:
+    def __init__(
+        self,
+        chunk_size: int,
+        process_group: ProcessGroup,
+        dtype: torch.dtype,
+        init_device: Optional[torch.device] = None,
+        cpu_shard_init: bool = False,
+        keep_gathered: bool = False,
+        pin_memory: bool = False,
+    ) -> None:
         """
         Chunk: A container owning a piece of contiguous memory space for tensors
         Here we use all-gather operation to gather the whole chunk.
@@ -99,9 +106,9 @@ class Chunk:
         device = init_device or get_current_device()
 
         # chunk_temp is a global chunk, which only exists during building the chunks.
-        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
 
-        self.cuda_global_chunk = None    # we force cuda_global_chunk located in CUDA
+        self.cuda_global_chunk = None  # we force cuda_global_chunk located in CUDA
 
         # cuda local chunk, which is sharded on GPUs
         self.cuda_shard = None
@@ -134,7 +141,7 @@ class Chunk:
         # they are treated the same as that of the parameters in DDP during training.
         self.keep_gathered = keep_gathered
         if self.keep_gathered:
-            pin_memory = False    # since this chunk is gathered, it doesn't need to pin
+            pin_memory = False  # since this chunk is gathered, it doesn't need to pin
 
         # if pin_memory is True, we allocate a piece of CPU pin-memory
         # for it all the time
@@ -160,7 +167,7 @@ class Chunk:
 
         if self.chunk_temp is not None:
             # this chunk is not closed
-            if self.chunk_temp.device.type == 'cuda':
+            if self.chunk_temp.device.type == "cuda":
                 cuda_memory += self.chunk_mem
             else:
                 cpu_memory += self.chunk_mem
@@ -180,11 +187,11 @@ class Chunk:
             return self.chunk_temp.device.type
         else:
             if self.is_gathered:
-                return 'cuda'
+                return "cuda"
             elif self.cuda_shard is not None:
-                return 'cuda'
+                return "cuda"
             else:
-                return 'cpu'
+                return "cpu"
 
     @property
     def payload(self) -> torch.Tensor:
@@ -217,8 +224,10 @@ class Chunk:
         if self.keep_gathered:
             return False
         else:
-            return self.tensor_state_cnter[TensorState.HOLD] + \
-                self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
+            return (
+                self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD]
+                == self.num_tensors
+            )
 
     @property
     def can_reduce(self):
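The rewrapped expression above is the release condition: a chunk can be released only when every tensor it owns sits in HOLD or HOLD_AFTER_BWD. A small runnable sketch of that bookkeeping with a plain Counter (the variable names and counts are illustrative; the real class keeps the counter on the Chunk instance):

```python
from collections import Counter
from enum import Enum, auto


class TensorState(Enum):
    FREE = auto()
    COMPUTE = auto()
    HOLD = auto()
    HOLD_AFTER_BWD = auto()
    READY_FOR_REDUCE = auto()


# Pretend the chunk tracks four tensors in these states.
tensor_state_cnter = Counter({TensorState.HOLD: 3, TensorState.HOLD_AFTER_BWD: 1})
num_tensors = 4

can_release = (
    tensor_state_cnter[TensorState.HOLD] + tensor_state_cnter[TensorState.HOLD_AFTER_BWD]
    == num_tensors
)
print(can_release)  # True: every tensor is either HOLD or HOLD_AFTER_BWD
```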
@@ -226,27 +235,25 @@ class Chunk:
 
     @property
     def has_inf_or_nan(self) -> bool:
-        """Check if the chunk has inf or nan values on CUDA.
-        """
+        """Check if the chunk has inf or nan values on CUDA."""
         if self.is_gathered:
-            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
+            valid_tensor = self.cuda_global_chunk[: self.utilized_size]
         else:
-            assert self.cuda_shard is not None    # only check on CUDA
-            valid_tensor = self.cuda_shard[:self.valid_end]
+            assert self.cuda_shard is not None  # only check on CUDA
+            valid_tensor = self.cuda_shard[: self.valid_end]
 
         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
 
     def set_l2_norm(self) -> None:
-        """Record l2 norm of this chunks on CUDA.
-        """
+        """Record l2 norm of this chunks on CUDA."""
         assert self.l2_norm is None, "you are calculating the l2 norm twice"
         if self.is_gathered:
-            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
+            valid_tensor = self.cuda_global_chunk[: self.utilized_size]
         else:
-            assert self.cuda_shard is not None    # calculate on CUDA
-            valid_tensor = self.cuda_shard[:self.valid_end]
+            assert self.cuda_shard is not None  # calculate on CUDA
+            valid_tensor = self.cuda_shard[: self.valid_end]
         chunk_l2_norm = valid_tensor.data.float().norm(2)
-        self.l2_norm = chunk_l2_norm.item()**2
+        self.l2_norm = chunk_l2_norm.item() ** 2
 
     def append_tensor(self, tensor: torch.Tensor):
         """Add a tensor to the chunk.
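Both methods only differ by whitespace after black, so their behavior is identical. A self-contained sketch of the two checks on a toy CPU tensor (the chunk code runs them on CUDA; the buffer and sizes below are made up for illustration):

```python
import torch

# A fake "payload" with one NaN in the utilized region.
payload = torch.zeros(8)
payload[2] = float("nan")
utilized_size = 4

valid_tensor = payload[:utilized_size]
has_inf_or_nan = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
print(has_inf_or_nan)  # True

# The recorded value is the squared L2 norm of the valid slice.
clean = torch.arange(4, dtype=torch.float32)
l2_norm = clean.float().norm(2).item() ** 2
print(l2_norm)  # ~14.0  (0^2 + 1^2 + 2^2 + 3^2)
```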
@@ -263,9 +270,9 @@ class Chunk:
         if new_utilized_size > self.chunk_size:
             raise ChunkFullError
 
-        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
+        self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten())
         assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
-        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
+        tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape)
 
         # record all the information about the tensor
         self.num_tensors += 1
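The slice spacing is the only change here; the append logic stays the same: copy the tensor's data into the next free region of the flat buffer, then re-point the tensor at that region. A stripped-down, runnable sketch of that mechanism (plain variables stand in for the Chunk attributes):

```python
import torch

chunk_size = 16
chunk_temp = torch.zeros(chunk_size)  # the flat staging buffer
utilized_size = 0

tensor = torch.ones(2, 3)  # a parameter to be absorbed into the chunk

new_utilized_size = utilized_size + tensor.numel()
assert new_utilized_size <= chunk_size, "chunk is full"

# Copy the flattened data into the buffer, then alias the tensor to that slice.
chunk_temp[utilized_size:new_utilized_size].copy_(tensor.data.flatten())
tensor.data = chunk_temp[utilized_size:new_utilized_size].view(tensor.shape)
utilized_size = new_utilized_size

# Writing through the tensor now writes into the chunk's memory.
tensor.data.fill_(2.0)
print(chunk_temp[:6])  # tensor([2., 2., 2., 2., 2., 2.])
```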
@@ -275,8 +282,7 @@ class Chunk:
         self.utilized_size = new_utilized_size
 
     def close_chunk(self):
-        """Close the chunk. Any tensor can't be appended to a closed chunk later.
-        """
+        """Close the chunk. Any tensor can't be appended to a closed chunk later."""
         # sanity check
         assert self.chunk_temp is not None
 
@@ -286,7 +292,7 @@ class Chunk:
         elif self.utilized_size < self.shard_end:
             self.valid_end = self.utilized_size - self.shard_begin
 
-        if self.chunk_temp.device.type == 'cpu':
+        if self.chunk_temp.device.type == "cpu":
             self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
             self.__update_tensors_ptr()
         else:
@@ -298,12 +304,12 @@ class Chunk:
         if self.keep_gathered:
             return
 
-        if self.pin_memory or self.shard_device.type == 'cpu':
+        if self.pin_memory or self.shard_device.type == "cpu":
             self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
             self.cpu_shard.copy_(self.cuda_shard)
-            self.cpu_vis_flag = True    # cpu_shard has been visited
+            self.cpu_vis_flag = True  # cpu_shard has been visited
 
-        if self.shard_device.type == 'cpu':
+        if self.shard_device.type == "cpu":
             self.cuda_shard = None
 
     def shard_move(self, device: torch.device, force_copy: bool = False):
@@ -318,12 +324,12 @@ class Chunk:
         # when the current chunk is not synchronized with the optimizer
         # just use another way for the movement
         if not self.optim_sync_flag:
-            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
+            assert device.type == "cuda", "each chunk should first be moved to CUDA"
             self.__paired_shard_move()
             self.optim_sync_flag = True
             return
 
-        if device.type == 'cuda':
+        if device.type == "cuda":
             assert device == get_current_device(), "can't move chunk to another device"
 
             if self.cuda_shard:
@@ -333,7 +339,7 @@ class Chunk:
 
             if not self.pin_memory:
                 self.cpu_shard = None
-        elif device.type == 'cpu':
+        elif device.type == "cpu":
             if self.cuda_shard is None:
                 return
 
@@ -350,8 +356,7 @@ class Chunk:
             raise NotImplementedError
 
     def access_chunk(self):
-        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
-        """
+        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
 
@@ -360,8 +365,7 @@ class Chunk:
         self.__update_tensors_ptr()
 
     def release_chunk(self):
-        """Release the usable chunk. It's an operation done in CUDA.
-        """
+        """Release the usable chunk. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
 
@@ -369,8 +373,7 @@ class Chunk:
             self.__scatter()
 
     def reduce(self):
-        """Reduce scatter all the gradients. It's an operation done in CUDA.
-        """
+        """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
 
@@ -423,20 +426,18 @@ class Chunk:
         assert self.is_gathered
 
         tensor_info = self.tensors_info[tensor]
-        self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
-        tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
+        self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
+        tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
 
     def get_valid_length(self) -> int:
-        """Get the valid length of the chunk's payload.
-        """
+        """Get the valid length of the chunk's payload."""
         if self.keep_gathered:
             return self.utilized_size
         else:
             return self.valid_end
 
-    def init_pair(self, friend_chunk: 'Chunk') -> None:
-        """Initialize the paired chunk.
-        """
+    def init_pair(self, friend_chunk: "Chunk") -> None:
+        """Initialize the paired chunk."""
         if self.paired_chunk is None and friend_chunk.paired_chunk is None:
             self.paired_chunk = friend_chunk
             friend_chunk.paired_chunk = self
@@ -445,8 +446,7 @@ class Chunk:
             assert friend_chunk.paired_chunk is self
 
     def optim_update(self) -> None:
-        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
-        """
+        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer."""
         # sanity check
         assert self.paired_chunk is not None
 
@@ -455,15 +455,15 @@ class Chunk:
             assert friend_chunk.is_gathered is True
             self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
             self.optim_sync_flag = True
-        elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
+        elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
             self.cuda_shard.copy_(friend_chunk.cuda_shard)
             self.optim_sync_flag = True
             self.cpu_vis_flag = False
         else:
             # optim_sync_flag is set to False
             # see shard_move function for more details
-            assert friend_chunk.device_type == 'cpu'
-            assert self.device_type == 'cpu'
+            assert friend_chunk.device_type == "cpu"
+            assert self.device_type == "cpu"
             self.optim_sync_flag = False
             self.cpu_vis_flag = False
 
@@ -492,7 +492,7 @@ class Chunk:
 
         self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
 
-        self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
+        self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end])
 
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
@@ -518,7 +518,7 @@ class Chunk:
         assert type(self.cuda_global_chunk) == torch.Tensor
 
         for tensor, tensor_info in self.tensors_info.items():
-            tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
+            tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
 
     def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
         self.tensor_state_cnter[tensor_info.state] -= 1
@@ -539,38 +539,41 @@ class Chunk:
     def __repr__(self, detailed: bool = True):
         output = [
             "Chunk Information:\n",
-            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
-                                                                                 self.pg_size),
+            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
+                self.chunk_size, self.dtype, self.pg_size
+            ),
             "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
-                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
+                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size
+            ),
         ]
 
-        def print_tensor(tensor, prefix=''):
-            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
-                                                                        tensor.device))
+        def print_tensor(tensor, prefix=""):
+            output.append(
+                "{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device)
+            )
 
         if self.chunk_temp is not None:
             output.append("\tchunk temp:\n")
-            print_tensor(tensor=self.chunk_temp, prefix='\t\t')
+            print_tensor(tensor=self.chunk_temp, prefix="\t\t")
 
         if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
             output.append("\tchunk total:\n")
-            print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
+            print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t")
 
         if self.cuda_shard is not None:
             output.append("\tcuda shard:\n")
-            print_tensor(tensor=self.cuda_shard, prefix='\t\t')
+            print_tensor(tensor=self.cuda_shard, prefix="\t\t")
 
         if self.cpu_shard is not None:
             output.append("\tcpu shard:\n")
-            print_tensor(tensor=self.cpu_shard, prefix='\t\t')
+            print_tensor(tensor=self.cpu_shard, prefix="\t\t")
 
         memory_info = self.memory_usage
-        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
+        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"]))
 
         if detailed:
             output.append("\ttensor state monitor:\n")
             for st in TensorState:
                 output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
 
-        return ''.join(output)
+        return "".join(output)
@@ -20,27 +20,28 @@ class ChunkManager:
     """
 
     def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
 
         self.device = init_device or get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
-            v['init_device'] = self.device
+            self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
+            v["init_device"] = self.device
 
         self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
         self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
-        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
+        self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
 
-    def register_tensor(self,
-                        tensor: torch.Tensor,
-                        group_type: str,
-                        config_key: int,
-                        process_group: ProcessGroup,
-                        cpu_offload: bool = False,
-                        pin_memory: bool = False) -> None:
+    def register_tensor(
+        self,
+        tensor: torch.Tensor,
+        group_type: str,
+        config_key: int,
+        process_group: ProcessGroup,
+        cpu_offload: bool = False,
+        pin_memory: bool = False,
+    ) -> None:
         """
         Register a tensor to the chunk manager.
         Then, the tensor should be accessed by `get_chunks`.
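The constructor consumes a nested configuration dict and pops the chunk size per DP degree; only the quoting changes in this hunk. A small runnable sketch of the shape of that dict and of the pop loop (the sizes are invented; the real dict is produced by `search_chunk_configuration`):

```python
import torch

# One entry per data-parallel degree; every value carries at least "chunk_size".
chunk_configuration = {
    8: {"chunk_size": 1024 * 1024},
    1: {"chunk_size": 32 * 1024},
}

device = torch.device("cpu")  # stand-in for get_current_device()
dp_degree_chunk_size_dict = {}

for k, v in chunk_configuration.items():
    # Exactly what the constructor above does: pop the size, record the init device.
    dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
    v["init_device"] = device

print(dp_degree_chunk_size_dict)  # {8: 1048576, 1: 32768}
print(chunk_configuration[1])     # {'init_device': device(type='cpu')}
```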
@@ -94,25 +95,22 @@ class ChunkManager:
         self.tensor_chunk_map[tensor] = chunk_group[-1]
 
     def close_all_groups(self):
-        """Close all the chunks of all groups.
-        """
+        """Close all the chunks of all groups."""
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])
 
     def access_chunk(self, chunk: Chunk) -> None:
-        """Make the chunk can be used for calculation.
-        """
+        """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
             return
         self.__sub_memory_usage(chunk.memory_usage)
-        if chunk.device_type == 'cpu':
+        if chunk.device_type == "cpu":
             chunk.shard_move(get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
 
     def release_chunk(self, chunk: Chunk) -> None:
-        """Scatter the chunk in CUDA.
-        """
+        """Scatter the chunk in CUDA."""
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -121,8 +119,7 @@ class ChunkManager:
             self.__add_memory_usage(chunk.memory_usage)
 
     def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
-        """Move the shard of the chunk to the target device.
-        """
+        """Move the shard of the chunk to the target device."""
         if not chunk.can_move or chunk.device_type == device.type:
             return
         self.__sub_memory_usage(chunk.memory_usage)
@@ -130,14 +127,12 @@ class ChunkManager:
         self.__add_memory_usage(chunk.memory_usage)
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
-        """Transit tensor state according to pre-defined state machine.
-        """
+        """Transit tensor state according to pre-defined state machine."""
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
     def reduce_chunk(self, chunk: Chunk) -> bool:
-        """Reduce or all reduce the chunk.
-        """
+        """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
@@ -213,18 +208,17 @@ class ChunkManager:
 
     def __repr__(self) -> str:
         msg = [
-            'Chunk Manager Information:\n',
-            'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
+            "Chunk Manager Information:\n",
+            "Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n",
         ]
         for group_name, group in self.chunk_groups.items():
-            msg.append(f'Group {group_name}:\n')
+            msg.append(f"Group {group_name}:\n")
             for i, chunk in enumerate(group):
-                msg.append(f'[{i}] {chunk}\n')
-        return ''.join(msg)
+                msg.append(f"[{i}] {chunk}\n")
+        return "".join(msg)
 
     def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
-        """Register a chunk group.
-        """
+        """Register a chunk group."""
         if group_name not in self.chunk_groups:
             self.chunk_groups[group_name] = deque()
         return self.chunk_groups[group_name]
@@ -76,8 +76,9 @@ def _tensor_numel(local_param: ColoParameter) -> int:
     return local_param.numel()
 
 
-def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
-                                 process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
+def classify_params_by_dp_degree(
+    param_order: OrderedParamGenerator, process_group: ProcessGroup
+) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree
 
     Classify the parameters by their dp degree
@@ -105,14 +106,15 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
 
 
 def search_chunk_configuration(
-        model: nn.Module,
-        search_range_m: float,
-        search_interval: int,    # hidden size is the best value for the interval
-        min_chunk_size_m: float = 32,
-        filter_exlarge_params: bool = True,
-        strict_ddp_flag: bool = False,
-        process_group: Optional[ProcessGroup] = None,
-        memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
+    model: nn.Module,
+    search_range_m: float,
+    search_interval: int,  # hidden size is the best value for the interval
+    min_chunk_size_m: float = 32,
+    filter_exlarge_params: bool = True,
+    strict_ddp_flag: bool = False,
+    process_group: Optional[ProcessGroup] = None,
+    memstas: Optional[MemStats] = None,
+) -> Tuple[Dict, int, int]:
     """search_chunk_configuration
 
     Search the chunk configuration for a model.
@@ -168,7 +170,7 @@ def search_chunk_configuration(
         max_size = max(max_size, max(size_dict[key]))
     start_size = int(math.ceil(max_size / search_interval) * search_interval)
 
-    min_chunk_waste = float('+inf')
+    min_chunk_waste = float("+inf")
    best_chunk_size = start_size
 
     for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
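The search itself is untouched by this commit apart from formatting: it scans candidate chunk sizes on a grid of `search_interval` and keeps the one with the least waste. A simplified, self-contained sketch of that idea; the greedy waste model and the numbers below are invented for illustration and are not the library's actual implementation:

```python
import math

# Hypothetical parameter sizes (number of elements per parameter).
param_sizes = [4096, 12288, 5120, 8192, 20480]

search_interval = 1024   # typically the model's hidden size
search_range = 8 * 1024  # how far above the minimum to search

max_size = max(param_sizes)
start_size = int(math.ceil(max_size / search_interval) * search_interval)

min_chunk_waste = float("+inf")
best_chunk_size = start_size

for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
    # Greedily pack parameters into chunks of this size and count leftover space.
    used, waste = 0, 0
    for size in param_sizes:
        if used + size > chunk_size:
            waste += chunk_size - used  # close the current chunk
            used = 0
        used += size
    waste += chunk_size - used          # tail of the last chunk
    if waste < min_chunk_waste:
        min_chunk_waste = waste
        best_chunk_size = chunk_size

print(best_chunk_size, min_chunk_waste)
```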
@@ -5,8 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.utils import is_ddp_ignored
 
 from .manager import ChunkManager
 from .search_utils import search_chunk_configuration
 
@@ -17,15 +15,17 @@ def safe_div(a, b):
     return a / b
 
 
-def init_chunk_manager(model: nn.Module,
-                       init_device: Optional[torch.device] = None,
-                       hidden_dim: Optional[int] = None,
-                       verbose: bool = False,
-                       **kwargs) -> ChunkManager:
+def init_chunk_manager(
+    model: nn.Module,
+    init_device: Optional[torch.device] = None,
+    hidden_dim: Optional[int] = None,
+    verbose: bool = False,
+    **kwargs,
+) -> ChunkManager:
     if hidden_dim:
         search_interval = hidden_dim
     else:
-        search_interval = 1024    # defaults to 1024
+        search_interval = 1024  # defaults to 1024
     kwargs["search_interval"] = search_interval
 
     dist.barrier()
@@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module,
     wasted_size /= mega_unit
 
     if verbose and dist.get_rank() == 0:
-        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-              "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
-              "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
-              sep='',
-              flush=True)
+        print(
+            "searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
+            "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
+            "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
+            sep="",
+            flush=True,
+        )
     dist.barrier()
 
     chunk_manager = ChunkManager(config_dict, init_device)
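The reflow above only splits the verbose `print` call across lines. A tiny runnable sketch of the numbers in that report; the `safe_div` guard against a zero denominator is inferred from its name (only its `return a / b` line is visible in the diff), and the sizes are invented:

```python
def safe_div(a, b):
    # Assumed behavior: avoid ZeroDivisionError when nothing was allocated.
    if b == 0:
        return 0
    return a / b


mega_unit = 2 ** 20
total_size = 123_456_789 / mega_unit  # used elements, in units of 2^20
wasted_size = 1_234_567 / mega_unit   # wasted elements, in units of 2^20

print(
    "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
    "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
    sep="",
)
```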