[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu, 2023-09-19 14:20:26 +08:00 (committed by GitHub)
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
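Context for the hunks below: the reformat is mechanical and repeats the same few changes across all 1268 files — string quotes are normalized to double quotes, over-long signatures and literals are broken out one item per line with a trailing comma, and slice bounds built from expressions gain spaces around the colon. Per the last bullet of the commit message, CUDA sources are presumably excluded from the clang-format hook. The snippet below is a minimal sketch of the mechanism, assuming the Python hook is black with a 120-character line length; both are inferences from the output style, not read from this commit's configuration.

import black  # assumed formatter behind the updated pre-commit hook

# Old-style line as it appears on the "before" side of the first hunk.
legacy = (
    "__all__ = ['GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', "
    "'zero_optim_wrapper', 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', "
    "'get_static_torch_model']\n"
)

# Too long for one line, so black explodes the list one element per line,
# normalizes quotes to double quotes, and appends a trailing comma,
# reproducing the "after" side of the hunk.
print(black.format_str(legacy, mode=black.Mode(line_length=120)))

Running the hooks once over the whole tree, rather than only on touched files, is what makes this diff so large; later commits should only see formatting churn on lines they actually modify.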

View File

@@ -10,6 +10,13 @@ from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper
__all__ = [
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
"GeminiDDP",
"GeminiOptimizer",
"GeminiAdamOptimizer",
"zero_model_wrapper",
"zero_optim_wrapper",
"LowLevelZeroOptimizer",
"ColoInitContext",
"post_process_colo_init_ctx",
"get_static_torch_model",
]

View File

@@ -6,6 +6,15 @@ from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model
__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
"GeminiManager",
"TensorInfo",
"TensorState",
"ChunkManager",
"search_chunk_configuration",
"GeminiDDP",
"get_static_torch_model",
"GeminiAdamOptimizer",
"GeminiOptimizer",
"ColoInitContext",
"post_process_colo_init_ctx",
]

View File

@@ -3,4 +3,4 @@ from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager
__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"]

View File

@@ -17,12 +17,17 @@ class TensorState(Enum):
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
STATE_TRANS = (
(TensorState.FREE, TensorState.HOLD),
(TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE),
(TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
(TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
@dataclass
@@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class Chunk:
_total_number = 0
def __init__(self,
chunk_size: int,
process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
def __init__(
self,
chunk_size: int,
process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Here we use all-gather operation to gather the whole chunk.
@@ -99,9 +106,9 @@ class Chunk:
device = init_device or get_current_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA
self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA
# cuda local chunk, which is sharded on GPUs
self.cuda_shard = None
@@ -134,7 +141,7 @@ class Chunk:
# they are treated the same as that of the parameters in DDP during training.
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
pin_memory = False # since this chunk is gathered, it doesn't need to pin
# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
@@ -160,7 +167,7 @@ class Chunk:
if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == 'cuda':
if self.chunk_temp.device.type == "cuda":
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
@@ -180,11 +187,11 @@ class Chunk:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return 'cuda'
return "cuda"
elif self.cuda_shard is not None:
return 'cuda'
return "cuda"
else:
return 'cpu'
return "cpu"
@property
def payload(self) -> torch.Tensor:
@@ -217,8 +224,10 @@ class Chunk:
if self.keep_gathered:
return False
else:
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
return (
self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD]
== self.num_tensors
)
@property
def can_reduce(self):
@@ -226,27 +235,25 @@ class Chunk:
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values on CUDA.
"""
"""Check if the chunk has inf or nan values on CUDA."""
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[: self.utilized_size]
else:
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[: self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
"""Record l2 norm of this chunks on CUDA."""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[: self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[: self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
self.l2_norm = chunk_l2_norm.item() ** 2
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
@@ -263,9 +270,9 @@ class Chunk:
if new_utilized_size > self.chunk_size:
raise ChunkFullError
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten())
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape)
# record all the information about the tensor
self.num_tensors += 1
@@ -275,8 +282,7 @@ class Chunk:
self.utilized_size = new_utilized_size
def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
"""
"""Close the chunk. Any tensor can't be appended to a closed chunk later."""
# sanity check
assert self.chunk_temp is not None
@@ -286,7 +292,7 @@ class Chunk:
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
if self.chunk_temp.device.type == "cpu":
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
@@ -298,12 +304,12 @@ class Chunk:
if self.keep_gathered:
return
if self.pin_memory or self.shard_device.type == 'cpu':
if self.pin_memory or self.shard_device.type == "cpu":
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
self.cpu_vis_flag = True # cpu_shard has been visited
if self.shard_device.type == 'cpu':
if self.shard_device.type == "cpu":
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
@@ -318,12 +324,12 @@ class Chunk:
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
assert device.type == "cuda", "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return
if device.type == 'cuda':
if device.type == "cuda":
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
@@ -333,7 +339,7 @@ class Chunk:
if not self.pin_memory:
self.cpu_shard = None
elif device.type == 'cpu':
elif device.type == "cpu":
if self.cuda_shard is None:
return
@@ -350,8 +356,7 @@ class Chunk:
raise NotImplementedError
def access_chunk(self):
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
"""
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
# sanity check
assert self.chunk_temp is None
@@ -360,8 +365,7 @@ class Chunk:
self.__update_tensors_ptr()
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA.
"""
"""Release the usable chunk. It's an operation done in CUDA."""
# sanity check
assert self.chunk_temp is None
@@ -369,8 +373,7 @@ class Chunk:
self.__scatter()
def reduce(self):
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
"""Reduce scatter all the gradients. It's an operation done in CUDA."""
# sanity check
assert self.is_gathered
@@ -423,20 +426,18 @@ class Chunk:
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
"""
"""Get the valid length of the chunk's payload."""
if self.keep_gathered:
return self.utilized_size
else:
return self.valid_end
def init_pair(self, friend_chunk: 'Chunk') -> None:
"""Initialize the paired chunk.
"""
def init_pair(self, friend_chunk: "Chunk") -> None:
"""Initialize the paired chunk."""
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
@@ -445,8 +446,7 @@ class Chunk:
assert friend_chunk.paired_chunk is self
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer."""
# sanity check
assert self.paired_chunk is not None
@@ -455,15 +455,15 @@ class Chunk:
assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
else:
# optim_sync_flag is set to False
# see shard_move function for more details
assert friend_chunk.device_type == 'cpu'
assert self.device_type == 'cpu'
assert friend_chunk.device_type == "cpu"
assert self.device_type == "cpu"
self.optim_sync_flag = False
self.cpu_vis_flag = False
@@ -492,7 +492,7 @@ class Chunk:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end])
free_storage(self.cuda_global_chunk)
self.is_gathered = False
@@ -518,7 +518,7 @@ class Chunk:
assert type(self.cuda_global_chunk) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensor_state_cnter[tensor_info.state] -= 1
@@ -539,38 +539,41 @@ class Chunk:
def __repr__(self, detailed: bool = True):
output = [
"Chunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
self.chunk_size, self.dtype, self.pg_size
),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size
),
]
def print_tensor(tensor, prefix=''):
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
tensor.device))
def print_tensor(tensor, prefix=""):
output.append(
"{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device)
)
if self.chunk_temp is not None:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
print_tensor(tensor=self.chunk_temp, prefix="\t\t")
if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t")
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
print_tensor(tensor=self.cuda_shard, prefix="\t\t")
if self.cpu_shard is not None:
output.append("\tcpu shard:\n")
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
print_tensor(tensor=self.cpu_shard, prefix="\t\t")
memory_info = self.memory_usage
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"]))
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
return ''.join(output)
return "".join(output)

View File

@@ -20,27 +20,28 @@ class ChunkManager:
"""
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
v["init_device"] = self.device
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
def register_tensor(self,
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
def register_tensor(
self,
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
"""
Register a tensor to the chunk manager.
Then, the tensor should be accessed by `get_chunks`.
@@ -94,25 +95,22 @@ class ChunkManager:
self.tensor_chunk_map[tensor] = chunk_group[-1]
def close_all_groups(self):
"""Close all the chunks of all groups.
"""
"""Close all the chunks of all groups."""
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
"""Make the chunk can be used for calculation.
"""
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
if chunk.device_type == "cpu":
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
"""
"""Scatter the chunk in CUDA."""
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
@@ -121,8 +119,7 @@ class ChunkManager:
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
"""Move the shard of the chunk to the target device."""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memory_usage(chunk.memory_usage)
@@ -130,14 +127,12 @@ class ChunkManager:
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""Transit tensor state according to pre-defined state machine.
"""
"""Transit tensor state according to pre-defined state machine."""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.
"""
"""Reduce or all reduce the chunk."""
if not chunk.can_reduce:
return False
self.__sub_memory_usage(chunk.memory_usage)
@@ -213,18 +208,17 @@ class ChunkManager:
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
"Chunk Manager Information:\n",
"Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n",
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
msg.append(f"Group {group_name}:\n")
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
msg.append(f"[{i}] {chunk}\n")
return "".join(msg)
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
"""Register a chunk group.
"""
"""Register a chunk group."""
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
return self.chunk_groups[group_name]

View File

@@ -76,8 +76,9 @@ def _tensor_numel(local_param: ColoParameter) -> int:
return local_param.numel()
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
def classify_params_by_dp_degree(
param_order: OrderedParamGenerator, process_group: ProcessGroup
) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@@ -105,14 +106,15 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
def search_chunk_configuration(
model: nn.Module,
search_range_m: float,
search_interval: int, # hidden size is the best value for the interval
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
process_group: Optional[ProcessGroup] = None,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
model: nn.Module,
search_range_m: float,
search_interval: int, # hidden size is the best value for the interval
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
process_group: Optional[ProcessGroup] = None,
memstas: Optional[MemStats] = None,
) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
Search the chunk configuration for a model.
@@ -168,7 +170,7 @@ def search_chunk_configuration(
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval) * search_interval)
min_chunk_waste = float('+inf')
min_chunk_waste = float("+inf")
best_chunk_size = start_size
for chunk_size in range(start_size, start_size + search_range + 1, search_interval):

View File

@@ -5,8 +5,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.utils import is_ddp_ignored
from .manager import ChunkManager
from .search_utils import search_chunk_configuration
@@ -17,15 +15,17 @@ def safe_div(a, b):
return a / b
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
verbose: bool = False,
**kwargs) -> ChunkManager:
def init_chunk_manager(
model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
verbose: bool = False,
**kwargs,
) -> ChunkManager:
if hidden_dim:
search_interval = hidden_dim
else:
search_interval = 1024 # defaults to 1024
search_interval = 1024 # defaults to 1024
kwargs["search_interval"] = search_interval
dist.barrier()
@@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module,
wasted_size /= mega_unit
if verbose and dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='',
flush=True)
print(
"searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep="",
flush=True,
)
dist.barrier()
chunk_manager = ChunkManager(config_dict, init_device)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Iterator, Optional, Tuple, Union
import torch
from torch import nn
@@ -12,7 +12,7 @@ from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
def _named_params_with_replica(
module: nn.Module,
prefix: str = '',
prefix: str = "",
recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
@@ -21,16 +21,17 @@ def _named_params_with_replica(
for name, val in mod._parameters.items():
if val is None:
continue
name = mod_prefix + ('.' if mod_prefix else '') + name
name = mod_prefix + ("." if mod_prefix else "") + name
yield name, val
def _convert_to_coloparam(param: torch.nn.Parameter,
device: torch.device,
dtype=torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec: Optional[Any] = None) -> ColoParameter:
def _convert_to_coloparam(
param: torch.nn.Parameter,
device: torch.device,
dtype=torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec: Optional[Any] = None,
) -> ColoParameter:
if type(param) is ColoParameter:
return param
# detaching tensor is necessary for optimizers.
@@ -66,12 +67,13 @@ def ColoModulize(module):
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self,
device: torch.device = torch.device('cpu'),
dtype: torch.dtype = torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None):
def __init__(
self,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None,
):
"""
Args:
device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
@@ -89,6 +91,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def _register_colo_modules(self):
from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())
@@ -105,25 +108,25 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
if type(param) is ColoParameter:
continue
split = name.rfind('.')
if split >= 0: # param in submodule
split = name.rfind(".")
if split >= 0: # param in submodule
module_name = name[:split]
param_name = name[split + 1:]
param_name = name[split + 1 :]
else:
module_name = '' # param in current module
module_name = "" # param in current module
param_name = name
name_list.append((module_name, param_name))
replaced_tensors = dict(
) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
for module_name, param_name in name_list:
submodule = module.get_submodule(module_name)
param = submodule.get_parameter(param_name)
if param in replaced_tensors:
colo_param = replaced_tensors[param]
else:
colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg,
self._default_dist_spec)
colo_param = _convert_to_coloparam(
param, self._device, self._dtype, self._default_pg, self._default_dist_spec
)
replaced_tensors[param] = colo_param
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
@@ -136,11 +139,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
for param in module.parameters():
param_number += 1
meta_param_number += (param.device.type == 'meta')
meta_param_number += param.device.type == "meta"
for buffer in module.buffers():
buffer_number += 1
meta_buffer_number += (buffer.device.type == 'meta')
meta_buffer_number += buffer.device.type == "meta"
if meta_param_number > 0 and meta_param_number != param_number:
raise ValueError("Meta parameters and valued parameters can not be in the same model")
@@ -152,11 +155,13 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
buffer.data = buffer.data.to(device=self._device)
def post_process_colo_init_ctx(model: torch.nn.Module,
device: torch.device = torch.device('cpu'),
dtype: torch.dtype = torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None):
def post_process_colo_init_ctx(
model: torch.nn.Module,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None,
):
"""post_process_colo_init_ctx
This function is called after `ColoInitContext`.
@@ -178,8 +183,8 @@ def post_process_colo_init_ctx(model: torch.nn.Module,
# print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
torch_params.append((n, p))
for (n, param) in torch_params:
name_list = n.split('.')
for n, param in torch_params:
name_list = n.split(".")
module = model
for i in range(len(name_list) - 1):
module = module._modules[name_list[i]]

View File

@@ -10,7 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
@@ -27,10 +27,10 @@ from .utils import get_temp_total_chunk_on_cuda
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
__all__ = [
'GeminiDDP',
"GeminiDDP",
]
@@ -54,27 +54,28 @@ class GeminiDDP(ModelWrapper):
"""
def __init__(
self,
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device('cpu'),
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
verbose: bool = False) -> None:
self,
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
@@ -82,22 +83,26 @@ class GeminiDDP(ModelWrapper):
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
self.chunk_manager = init_chunk_manager(model=module,
init_device=chunk_init_device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
self.gemini_manager = GeminiManager(placement_policy,
self.chunk_manager,
memstats,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
self.chunk_manager = init_chunk_manager(
model=module,
init_device=chunk_init_device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
placement_policy,
self.chunk_manager,
memstats,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
@@ -126,13 +131,15 @@ class GeminiDDP(ModelWrapper):
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
param_name = m_name + "." + p_name if m_name else p_name
self.name2param[param_name] = p_var
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
self._init_chunks(
param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != "cuda",
pin_memory=pin_memory,
)
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
@@ -146,19 +153,18 @@ class GeminiDDP(ModelWrapper):
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.module.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.module.named_buffers(prefix, recurse)
def named_children(self):
return self.module.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
def named_modules(
self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
return self.module.named_modules(memo, prefix, remove_duplicate)
@staticmethod
@@ -184,11 +190,9 @@ class GeminiDDP(ModelWrapper):
# as save/load state dict is overwrited, only return self
return self
def _get_non_persistent_buffers_set(self,
module,
memo: Optional[Set[nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
def _get_non_persistent_buffers_set(
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Args:
memo: a memo to store the set of modules already added to the result
@@ -204,19 +208,20 @@ class GeminiDDP(ModelWrapper):
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
remove_duplicate)
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = self._get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
def _post_forward(self):
"""This function is only triggered for inference.
"""
"""This function is only triggered for inference."""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
@@ -233,7 +238,8 @@ class GeminiDDP(ModelWrapper):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
assert (
not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup()
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
@@ -250,8 +256,7 @@ class GeminiDDP(ModelWrapper):
return outputs
def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference.
"""
"""This function is only triggered for inference."""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
if not self.scatter_after_inference:
# gather all chunks
@@ -287,12 +292,14 @@ class GeminiDDP(ModelWrapper):
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}")
raise RuntimeError(
"ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}",
)
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
self.gemini_manager.post_iter()
@@ -314,8 +321,10 @@ class GeminiDDP(ModelWrapper):
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
raise RuntimeError(
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter."
)
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -339,12 +348,9 @@ class GeminiDDP(ModelWrapper):
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self,
destination=None,
prefix='',
keep_vars=False,
only_rank_0: bool = True,
dtype: torch.dtype = torch.float16):
def state_dict(
self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
@@ -391,7 +397,7 @@ class GeminiDDP(ModelWrapper):
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu()
assert tensor not in chunk_to_save_data
chunk_to_save_data[tensor] = record_tensor
@@ -399,8 +405,9 @@ class GeminiDDP(ModelWrapper):
del temp_chunk
return chunk_to_save_data
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
dtype: torch.dtype) -> Dict:
def _get_param_to_save_data(
self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
) -> Dict:
"""
get param content from chunks.
@@ -459,11 +466,13 @@ class GeminiDDP(ModelWrapper):
destination[prefix + name] = buf if keep_vars else buf.detach()
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
destination[extra_state_key] = self.get_extra_state()
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
@@ -491,32 +500,38 @@ class GeminiDDP(ModelWrapper):
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
state_dict._metadata = metadata # type: ignore[attr-defined]
prefix = ''
prefix = ""
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys)))
0,
"Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
),
)
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
)
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
)
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -564,19 +579,21 @@ class GeminiDDP(ModelWrapper):
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape)
)
return
try:
with torch.no_grad():
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args)
)
elif strict:
missing_keys.append(state_key)
@@ -600,15 +617,15 @@ class GeminiDDP(ModelWrapper):
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
del temp_chunk
@@ -622,8 +639,10 @@ class GeminiDDP(ModelWrapper):
load(name, buf, buf.copy_)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state",
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
if (
getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state)
is not torch.nn.Module.set_extra_state
):
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
@@ -634,7 +653,7 @@ class GeminiDDP(ModelWrapper):
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = key[len(prefix) :]
if input_name not in local_state:
unexpected_keys.append(key)
@@ -659,18 +678,22 @@ class GeminiDDP(ModelWrapper):
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
@@ -694,7 +717,7 @@ class GeminiDDP(ModelWrapper):
if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision)
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None:
"""Convert parameter to ColoParameter in-place.
Args:
p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
@@ -709,12 +732,14 @@ class GeminiDDP(ModelWrapper):
p.__class__ = ColoParameter
p.__init__(p, requires_grad=requires_grad)
def state_dict_shard(self,
prefix: str = '',
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
def state_dict_shard(
self,
prefix: str = "",
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
dtype: torch.dtype = torch.float16,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Both parameters and persistent buffers (e.g. running averages) are included.
@@ -770,8 +795,10 @@ class GeminiDDP(ModelWrapper):
yield block, block_size
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = self.get_extra_state()
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:

View File

@@ -17,7 +17,6 @@ class TrainingPhase(Enum):
class GeminiZeROHook(ColoParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()
self._gemini_manager = gemini_manager
@@ -40,7 +39,11 @@ class GeminiZeROHook(ColoParamOpHook):
def post_op(self, params):
params = [p for p in params if not is_ddp_ignored(p)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
tensor_state = (
TensorState.HOLD
if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad
else TensorState.HOLD_AFTER_BWD
)
self._chunk_manager.trans_tensor_state(p, tensor_state)
def pre_forward(self, params: List[torch.Tensor]) -> None:

View File

@@ -26,12 +26,13 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""
def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs) -> None:
def __init__(
self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs,
) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
@@ -39,8 +40,9 @@ class GeminiManager:
self._premade_memstats_ = memstats is not None
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._mem_stats_collector = (
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
@@ -62,7 +64,7 @@ class GeminiManager:
@property
def need_warmup(self) -> bool:
return self.policy_name in ('auto', 'const')
return self.policy_name in ("auto", "const")
def is_warmup(self):
return self._warmup
@@ -85,15 +87,14 @@ class GeminiManager:
self._mem_stats_collector.start_collection()
def post_iter(self):
"""This function must be called when each iteration finishes
"""
"""This function must be called when each iteration finishes"""
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
@@ -102,11 +103,13 @@ class GeminiManager:
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
vol, evict_time = self._placement_policy.evict_tensors(
can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx,
)
self._d2h_volume += vol
self._evict_time += evict_time
@@ -118,12 +121,12 @@ class GeminiManager:
start = time()
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cuda':
if chunk.device_type == "cuda":
if chunk.is_gathered:
pass
else:
cuda_demand += chunk.chunk_mem - chunk.shard_mem
elif chunk.device_type == 'cpu':
elif chunk.device_type == "cpu":
cuda_demand += chunk.chunk_mem
else:
raise RuntimeError
@@ -159,6 +162,7 @@ class GeminiManager:
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
self._placement_policy.setup_grads_device(params, grads_device_map)

View File

@@ -10,34 +10,35 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
def __init__(
self,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.module = module
def check_local_overflow(self) -> bool:
@@ -77,25 +78,28 @@ class GeminiOptimizer(OptimizerWrapper):
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
"""
def __init__(self,
optim: Optimizer,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any):
def __init__(
self,
optim: Optimizer,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any,
):
super().__init__(optim)
assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
f"{_AVAIL_OPTIM_LIST}"
assert type(optim) in _AVAIL_OPTIM_LIST, (
"You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}"
)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -118,8 +122,10 @@ class GeminiOptimizer(OptimizerWrapper):
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!")
warnings.warn(
f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!"
)
else:
ddp_param_list.append(param)
@@ -132,14 +138,16 @@ class GeminiOptimizer(OptimizerWrapper):
self.__init__optimizer()
if module.mixed_precision is torch.float16:
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(
module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif module.mixed_precision is torch.bfloat16:
self.mix_precision_mixin = BF16MixedPrecisionMixin()
else:
@@ -148,12 +156,15 @@ class GeminiOptimizer(OptimizerWrapper):
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0"
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
optim, 'num_fp32_shards_per_param', 0) >= 2
self._should_move_fp32_params_h2d: bool = (
self.gemini_manager.is_cuda_margin_mem_avail
and self.gpu_margin_mem_ratio > 0.0
and getattr(optim, "num_fp32_shards_per_param", 0) >= 2
)
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
@@ -161,7 +172,7 @@ class GeminiOptimizer(OptimizerWrapper):
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
begin, end = self.param_to_range[fake_param]
chunk16 = chunk32.paired_chunk
@@ -173,7 +184,7 @@ class GeminiOptimizer(OptimizerWrapper):
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor.to(fake_param.device)
@@ -198,7 +209,7 @@ class GeminiOptimizer(OptimizerWrapper):
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm
c16.l2_norm = None # clear l2 norm
c16.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
@@ -230,9 +241,9 @@ class GeminiOptimizer(OptimizerWrapper):
if self.mix_precision_mixin.should_skip_step():
if self.verbose:
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._logger.info(f"Found overflow. Skip step")
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@@ -269,11 +280,11 @@ class GeminiOptimizer(OptimizerWrapper):
fp32_params_used_cuda_margin_mem = 0
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
chunk16 = chunk32.paired_chunk
if chunk32.device_type == 'cuda':
if chunk32.device_type == "cuda":
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
@@ -284,9 +295,9 @@ class GeminiOptimizer(OptimizerWrapper):
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
if chunk32.device_type == 'cuda':
if chunk32.device_type == "cuda":
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
@@ -294,14 +305,13 @@ class GeminiOptimizer(OptimizerWrapper):
def _register_states_(self):
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
state = self.optim.state[p]
for val in state.values():
if isinstance(val, torch.Tensor):
self.chunk_manager.add_extern_static_tensor(val)
def __init__optimizer(self):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
@@ -313,10 +323,9 @@ class GeminiOptimizer(OptimizerWrapper):
param_id = -1
for group in self.optim.param_groups:
fake_params_list = list()
group_backup = {k: v for k, v in group.items() if k != 'params'}
group_backup = {k: v for k, v in group.items() if k != "params"}
group_ids = []
for param in group['params']:
for param in group["params"]:
# Record the mapping of id to current param.
param_id += 1
self.id_to_real_params[param_id] = param
@@ -337,12 +346,12 @@ class GeminiOptimizer(OptimizerWrapper):
fake_params_list.append(fake_param)
# Update self.optim.param_groups as well as backup group.
group['params'] = fake_params_list
group_backup['params'] = group_ids
group["params"] = fake_params_list
group_backup["params"] = group_ids
self.param_groups_backup.append(group_backup)
def get_offsets(self, param_id: int) -> tuple:
'''
"""
Args:
param_id(int): The id of parameter.
@@ -351,7 +360,7 @@ class GeminiOptimizer(OptimizerWrapper):
shard_offset(int): Offset of its optimizer state shard
relative to the whole optimizer state.
shard_size(int): Length of parameter shard owned by current process.
'''
"""
if param_id not in self.id_to_fake_params:
return -1, -1, -1
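For intuition about the tuple documented above, here is a toy sketch under an even-split assumption; it is not Gemini's actual chunk bookkeeping (the real method also reports an offset inside the owning chunk), and the helper name is hypothetical.

def even_shard_offsets(numel: int, world_size: int, rank: int) -> tuple:
    """Toy analogue of get_offsets(): offset and length of the state slice owned by `rank`.

    Assumes the flattened parameter is padded to a multiple of world_size and split
    evenly, which is a simplification of Gemini's chunk-based sharding.
    """
    padded = numel + (-numel) % world_size  # pad so the split is even
    shard_size = padded // world_size
    shard_offset = rank * shard_size
    return shard_offset, shard_size


# e.g. a parameter with 10 elements on 4 ranks: rank 1 owns elements [3, 6)
print(even_shard_offsets(10, 4, 1))  # -> (3, 3)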
@@ -425,11 +434,11 @@ class GeminiOptimizer(OptimizerWrapper):
if is_collector:
states = self.optim.state[fake_param]
for state_name in state_names:
if state_name == 'step':
if state_name == "step":
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(states['step'],
dtype=torch.float32,
requires_grad=False).cpu()
collected_states[state_name] = torch.tensor(
states["step"], dtype=torch.float32, requires_grad=False
).cpu()
else:
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
@@ -441,12 +450,13 @@ class GeminiOptimizer(OptimizerWrapper):
# Collector gets prepared for state collecting.
if is_collector:
for state_name in state_names:
if state_name == 'step':
if state_name == "step":
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
else:
collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
requires_grad=False).cpu()
collected_states[state_name] = torch.zeros(
param.numel(), dtype=torch.float32, requires_grad=False
).cpu()
# Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
@@ -465,8 +475,9 @@ class GeminiOptimizer(OptimizerWrapper):
shard_size = state_shard[2]
if compacted_states is None:
continue
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
self.load_from_compacted_states(
compacted_states, collected_states, state_names, shard_offset, shard_size
)
# Reshape tensors
if is_collector:
@@ -476,14 +487,16 @@ class GeminiOptimizer(OptimizerWrapper):
return collected_states
def pack_optimizer_states_to_tensor(self,
param_id: int,
state_names: list,
device: torch.device = torch.device('cuda'),
dtype: torch.dtype = torch.float32) -> torch.Tensor:
'''
def pack_optimizer_states_to_tensor(
self,
param_id: int,
state_names: list,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Given a param id, pack its optimizer states into a compact tensor and return it.
'''
"""
if param_id not in self.id_to_fake_params:
return None
@@ -493,7 +506,7 @@ class GeminiOptimizer(OptimizerWrapper):
shard_size = param_range[1] - param_range[0]
compacted_size = 0
for name in state_names:
if name == 'step':
if name == "step":
compacted_size += 1
else:
compacted_size += shard_size
@@ -502,7 +515,7 @@ class GeminiOptimizer(OptimizerWrapper):
next_state_offset = 0
for state_name, state_tensor in states.items():
# State 'step' needs special operation.
if state_name == 'step':
if state_name == "step":
if isinstance(state_tensor, torch.Tensor):
compacted_states[next_state_offset] = state_tensor[0].item()
else:
@@ -511,47 +524,53 @@ class GeminiOptimizer(OptimizerWrapper):
next_state_offset += 1
else:
assert state_tensor.numel() == shard_size
compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor)
compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor)
next_state_offset += shard_size
return compacted_states
def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
shard_start: int, shard_size: int):
'''
def load_from_compacted_states(
self,
compacted_states: torch.Tensor,
collected_states: dict,
state_names: list,
shard_start: int,
shard_size: int,
):
"""
Given a tensor carrying compacted optimizer states,
update these states to collected_states.
'''
"""
shard_end = shard_start + shard_size
next_state_offset = 0
for state_name in state_names:
if state_name == 'step':
collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
dtype=torch.float32,
requires_grad=False).cpu()
if state_name == "step":
collected_states["step"].data = torch.tensor(
compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False
).cpu()
next_state_offset += 1
else:
target_segment = collected_states[state_name][shard_start:shard_end]
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size])
next_state_offset += shard_size
def get_param_groups_for_saving(self) -> list:
'''
"""
Return the param_groups in Pytorch format when saving to checkpoint.
'''
"""
param_groups = copy.deepcopy(self.param_groups_backup)
# To be compatible with pytorch checkpointing,
# store extra hyperparameters used by pytorch Adam optimizer.
torch_special_hyperparameters = {
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': False
"amsgrad": False,
"maximize": False,
"foreach": None,
"capturable": False,
"differentiable": False,
"fused": False,
}
for group in param_groups:
@@ -580,13 +599,13 @@ class GeminiOptimizer(OptimizerWrapper):
so it should be called only when memory resources are abundant.
"""
state_dict = {}
state_dict['param_groups'] = self.get_param_groups_for_saving()
state_dict["param_groups"] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict['state'] = dict()
state_dict["state"] = dict()
for param_id in self.id_to_real_params.keys():
dist.barrier()
state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
return state_dict
def load_param_groups(self, saved_param_groups: list):
@@ -601,13 +620,13 @@ class GeminiOptimizer(OptimizerWrapper):
for group in saved_param_groups:
fake_params_list = list()
updated_group = {k: v for k, v in group.items() if k != 'params'}
for param_id in group['params']:
updated_group = {k: v for k, v in group.items() if k != "params"}
for param_id in group["params"]:
if param_id not in self.id_to_fake_params:
continue
fake_param = self.id_to_fake_params[param_id]
fake_params_list.append(fake_param)
updated_group['params'] = fake_params_list
updated_group["params"] = fake_params_list
self.optim.param_groups.append(updated_group)
def load_single_param_states(self, param_id: int, saved_states: dict):
@@ -621,15 +640,14 @@ class GeminiOptimizer(OptimizerWrapper):
"""
assert isinstance(value, torch.Tensor)
ret_val = value
if (key == "step"):
if key == "step":
assert value.numel() == 1
ret_val = int(value.item())
else:
state_start, state_end = state_range
ret_val = torch.zeros(state_end - state_start,
dtype=torch.float32,
device=param.device,
requires_grad=False)
ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
)
ret_val.copy_(value.flatten()[state_start:state_end])
return ret_val
@@ -642,7 +660,7 @@ class GeminiOptimizer(OptimizerWrapper):
updated_states = dict()
for k, v in saved_states.items():
updated_states[k] = cast(fake_param, state_range, v, k)
del v # clean loaded states
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
def load_param_states(self, param_states: dict):
@@ -658,8 +676,8 @@ class GeminiOptimizer(OptimizerWrapper):
def optimizer_loading_epilogue(self):
# Epilogue when loading state_dict to pytorch optimizer.
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault('differentiable', False)
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault("differentiable", False)
def load_state_dict(self, state_dict: dict):
"""Loads optimizer state from complete optimizer state_dict.
@@ -669,16 +687,15 @@ class GeminiOptimizer(OptimizerWrapper):
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
assert 'state' in state_dict
self.load_param_groups(state_dict['param_groups'])
self.load_param_states(state_dict['state'])
assert "param_groups" in state_dict
assert "state" in state_dict
self.load_param_groups(state_dict["param_groups"])
self.load_param_states(state_dict["state"])
self.optimizer_loading_epilogue()
def state_shard(self,
prefix: str = '',
max_shard_size: int = 1024,
only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
def state_shard(
self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing shards of optimizer states one by one.
The max size of each dictionary shard is specified by ``max_shard_size``.
@@ -694,7 +711,6 @@ class GeminiOptimizer(OptimizerWrapper):
sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
@@ -705,19 +721,20 @@ class GeminiOptimizer(OptimizerWrapper):
yield sharder.current_block, sharder.current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
raise NotImplementedError("Gemini does not support clip_grad_by_value")
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> torch.Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
def clip_grad_by_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs,
) -> torch.Tensor:
warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)
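For context on how these pieces are usually combined, here is a hedged sketch of the typical workflow (not taken from this diff): GeminiAdamOptimizer builds a HybridAdam over the module's parameters and hands it to GeminiOptimizer, so the module passed in must already be wrapped by GeminiDDP. The launch call and the GeminiDDP keyword below are assumptions.

import colossalai
import torch
import torch.nn as nn

from colossalai.zero import GeminiAdamOptimizer, GeminiDDP

colossalai.launch_from_torch(config={})  # run under torchrun; an empty config suffices here

# toy model; the placement_policy keyword is assumed to be accepted by this GeminiDDP version
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
gemini_model = GeminiDDP(model, placement_policy="static")
optimizer = GeminiAdamOptimizer(gemini_model, lr=1e-3)  # HybridAdam wrapped by GeminiOptimizer

data = torch.randn(8, 32, device="cuda")
loss = gemini_model(data).sum()
optimizer.backward(loss)  # the wrapper owns loss scaling, so backward goes through it
optimizer.step()
optimizer.zero_grad()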

View File

@@ -1,10 +1,14 @@
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats',
'OrderedParamGenerator'
"AsyncMemoryMonitor",
"SyncCudaMemoryMonitor",
"MemStatsCollector",
"ChunkMemStatsCollector",
"MemStats",
"OrderedParamGenerator",
]

View File

@@ -8,7 +8,6 @@ from .memstats_collector import MemStatsCollector
class ChunkMemStatsCollector(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
"""
@@ -27,10 +26,11 @@ class ChunkMemStatsCollector(MemStatsCollector):
record model data volume on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
cuda_mem = self._chunk_manager.total_mem["cuda"]
self._memstats.record_max_cuda_model_data(cuda_mem)
@property
def cuda_margin_mem(self) -> float:
from colossalai.legacy.utils.memory import colo_device_memory_capacity
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
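A back-of-the-envelope illustration of how this margin feeds the gpu_margin_mem_ratio logic in GeminiOptimizer; the numbers are invented, and in practice they come from the memory tracer.

# hypothetical numbers, in bytes
device_capacity = 16 * 1024**3   # a 16 GiB device
max_overall_cuda = 12 * 1024**3  # peak CUDA usage observed by the tracer
cuda_margin_mem = device_capacity - max_overall_cuda  # 4 GiB never touched

gpu_margin_mem_ratio = 0.5  # user-supplied GeminiOptimizer argument
fp32_params_available_cuda_margin_mem = cuda_margin_mem * gpu_margin_mem_ratio
print(f"{fp32_params_available_cuda_margin_mem / 1024**3:.1f} GiB available for fp32 shards")  # 2.0 GiB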

View File

@@ -111,6 +111,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
def _measure_usage(self):
from colossalai.legacy.utils import colo_device_memory_used
max_usage = 0
while self.keep_measuring:
max_usage = max(

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import List, Optional
import torch
@@ -6,7 +6,6 @@ from .param_runtime_order import OrderedParamGenerator
class MemStats(object):
def __init__(self) -> None:
"""
Store the non model data statistics used for Gemini and GeminiOptimizer.
@@ -92,17 +91,17 @@ class MemStats(object):
return self._param_runtime_order
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
if device_type == "cuda":
return self._non_model_data_cuda_list
elif device_type == 'cpu':
elif device_type == "cpu":
return self._non_model_data_cpu_list
else:
raise TypeError
def max_non_model_data(self, device_type: str) -> float:
if device_type == 'cuda':
if device_type == "cuda":
return max(self._non_model_data_cuda_list)
elif device_type == 'cpu':
elif device_type == "cpu":
return max(self._non_model_data_cpu_list)
else:
raise TypeError

View File

@@ -40,11 +40,12 @@ class MemStatsCollector:
Returns:
int: max non model data memory usage of current sampling period
"""
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
assert not self._start_flag, "Cannot get mem stats info during collection phase."
assert self._step_total > 0, "Cannot get mem stats info before collection phase."
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, (
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "
f"step total {self._step_total}"
)
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data
@@ -60,9 +61,9 @@ class MemStatsCollector:
def finish_collection(self):
self.sample_overall_data()
# self._step_total = len(self._sampling_time)
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._step_total = len(self._memstats.non_model_data_list("cuda"))
self._start_flag = False
print(f'finish_collection {self._step_total}')
print(f"finish_collection {self._step_total}")
# deprecated
def record_model_data_volume(self) -> None:
@@ -73,7 +74,7 @@ class MemStatsCollector:
from colossalai.legacy.zero.gemini import StatefulTensor
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"]
self._memstats.record_max_cuda_model_data(cuda_mem)
def sample_overall_data(self) -> None:
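To make the cyclic replay of warmup samples concrete, here is a standalone toy (not the collector itself; the sample values are invented):

non_model_data_cuda = [100, 250, 180, 90]  # hypothetical warmup samples, in bytes
step_total = len(non_model_data_cuda)


def next_period_usage(step_idx: int) -> tuple:
    """Return the sample for this step and the wrapped index of the next one,
    mirroring how the collector cycles _step_idx through _step_total."""
    return non_model_data_cuda[step_idx], (step_idx + 1) % step_total


idx, replay = 0, []
for _ in range(6):
    value, idx = next_period_usage(idx)
    replay.append(value)
print(replay)  # [100, 250, 180, 90, 100, 250]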

View File

@@ -4,7 +4,6 @@ import torch
class ParamGenerator(ABC):
def append(self, param: torch.nn.Parameter):
pass

View File

@@ -10,10 +10,10 @@ from colossalai.utils import _cast_float
from .memory_stats import MemStats
__all__ = ['RuntimeMemTracer']
__all__ = ["RuntimeMemTracer"]
class RuntimeMemTracer():
class RuntimeMemTracer:
"""RuntimeMemTracer for the module training using ColoParameter.
Trace non-model memory usage during fwd+bwd process.

View File

@@ -15,9 +15,9 @@ from .chunk_memstats_collector import ChunkMemStatsCollector
class ModuleInfos:
def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
parent_module: torch.nn.Module):
def __init__(
self, module: torch.nn.Module, module_name: str, module_full_name: str, parent_module: torch.nn.Module
):
self.module = module
self.module_name = module_name
self.module_full_name = module_full_name
@@ -35,14 +35,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector):
self.module_info_list = []
def init_mem_stats(self, *inputs):
self.register_opnodes_recursively(self.module)
self.refactor_module()
self.module = self.module.cpu()
self.module.train()
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
data = [MetaTensor(torch.rand(inp.shape, device="meta"), fake_device="cpu") for inp in inputs]
gm = symbolic_trace(self.module)
interp = MetaInfoProp(gm)
interp.propagate(*data)
@@ -87,12 +86,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector):
for modInfo in self.module_info_list:
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
def register_opnodes_recursively(self,
module: torch.nn.Module,
name: str = "",
full_name: str = "",
parent_module: Optional[torch.nn.Module] = None):
def register_opnodes_recursively(
self,
module: torch.nn.Module,
name: str = "",
full_name: str = "",
parent_module: Optional[torch.nn.Module] = None,
):
assert isinstance(module, torch.nn.Module)
for child_name, child in module.named_children():

View File

@@ -14,7 +14,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
"""
if optim is None:
return 0, 0
assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
assert hasattr(optim, "get_memory_usage"), f"{type(optim)} has no attr get_memory_usage()"
return optim.get_memory_usage()
@@ -35,16 +35,16 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
return 0, 0
assert isinstance(t, torch.Tensor)
_cpu_mem_usage, _cuda_mem_usage = 0, 0
if t.device.type == 'cpu':
if t.device.type == "cpu":
_cpu_mem_usage += t.numel() * t.element_size()
elif t.device.type == 'cuda':
elif t.device.type == "cuda":
_cuda_mem_usage += t.numel() * t.element_size()
return _cuda_mem_usage, _cpu_mem_usage
cuda_mem_usage = 0
cpu_mem_usage = 0
for param in model.parameters():
if hasattr(param, 'colo_attr'):
if hasattr(param, "colo_attr"):
t_cuda, t_cpu = param.colo_attr.get_memory_usage()
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu

View File

@@ -17,10 +17,9 @@ from .memory_tracer import ChunkMemStatsCollector
class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
**kwargs) -> None:
def __init__(
self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@@ -29,23 +28,25 @@ class PlacementPolicy(ABC):
raise NotImplementedError
@abstractmethod
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
raise NotImplementedError
class StaticPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs) -> None:
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
offload_param_frac = 0.0
self.shard_param_frac = shard_param_frac
self.offload_optim_frac = offload_optim_frac
@@ -66,13 +67,14 @@ class StaticPlacementPolicy(PlacementPolicy):
for chunk in can_evict_chunks:
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
break
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
# real saved mem is shard_mem, for simplicity we use chunk_mem
can_offload_chunk_mem -= chunk.chunk_mem
return 0, 0.0
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
@@ -85,7 +87,7 @@ class StaticPlacementPolicy(PlacementPolicy):
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
device = get_current_device()
else:
device = torch.device('cpu')
device = torch.device("cpu")
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
offloaded_optim_chunk_mem += chunk.chunk_mem
for p in params:
@@ -97,12 +99,14 @@ class StaticPlacementPolicy(PlacementPolicy):
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs) -> None:
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
@@ -110,13 +114,15 @@ class AutoPlacementPolicy(PlacementPolicy):
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
def evict_tensors(
self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs,
) -> Tuple[int, float]:
"""
Evict tensors from CUDA device.
@@ -135,13 +141,13 @@ class AutoPlacementPolicy(PlacementPolicy):
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
@@ -160,11 +166,13 @@ class AutoPlacementPolicy(PlacementPolicy):
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
freed_cuda_model_data += chunk.chunk_mem
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
return freed_cuda_model_data, time() - start
@staticmethod
@@ -178,8 +186,9 @@ class AutoPlacementPolicy(PlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
for p in params:
chunk = self.chunk_manager.get_chunk(p)
# init offload optim settings
@@ -187,13 +196,13 @@ class AutoPlacementPolicy(PlacementPolicy):
if chunk.keep_gathered:
grads_device_map[p] = get_current_device()
else:
grads_device_map[p] = torch.device('cpu')
grads_device_map[p] = torch.device("cpu")
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
'auto': AutoPlacementPolicy,
'static': StaticPlacementPolicy,
"auto": AutoPlacementPolicy,
"static": StaticPlacementPolicy,
}
@staticmethod
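A hedged numeric walk-through of the eviction budget computed in evict_tensors; the to_free formula is inferred from the surrounding variable names and may differ in detail from the real control flow, and all numbers are invented.

# invented numbers, in bytes
cuda_capacity = 16 * 1024**3
used_cuda_model_data = 9 * 1024**3  # chunks currently resident on CUDA
max_cuda_non_model_data_per_period = 5 * 1024**3  # from the mem stats collector
steady_cuda_cap_ratio = 0.9
cuda_demand = 3 * 1024**3  # chunk memory needed for the next compute

cuda_capacity *= steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

# if the demand does not fit, chunks are moved to CPU until roughly this much is freed
to_free_cuda_model_data = max(0.0, cuda_demand - avail_cuda_model_data)
print(f"need to evict about {to_free_cuda_model_data / 1024**3:.2f} GiB of chunks")  # 2.60 GiB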

View File

@@ -27,16 +27,15 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
return total_temp
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
"""
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""):
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules."""
if memo is None:
memo = set()
if module not in memo:
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
submodule_prefix = prefix + ("." if prefix else "") + name
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
yield m
@@ -60,10 +59,9 @@ def _get_shallow_copy_model(model: nn.Module):
return old_to_new[model]
def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
def get_static_torch_model(
zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original GeminiDDP model is not modified.
Thus, you can use the original model in further training.
@@ -79,6 +77,7 @@ def get_static_torch_model(zero_ddp_model,
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
assert isinstance(zero_ddp_model, GeminiDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
@@ -86,15 +85,17 @@ def get_static_torch_model(zero_ddp_model,
torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0:
for (name, colo_module), (_, torch_module) in \
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
for (name, colo_module), (_, torch_module) in zip(
_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)
):
# clean the parameter list of the new torch module
torch_module._parameters = OrderedDict()
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
# get the full name of the parameter
full_param_name = name + ('.' if name else '') + sufix_param_name
assert full_param_name in state_dict, \
f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
full_param_name = name + ("." if name else "") + sufix_param_name
assert (
full_param_name in state_dict
), f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
state_param = state_dict[full_param_name]
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
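A hedged usage sketch for get_static_torch_model; it assumes an already-initialized distributed run and a GeminiDDP instance named gemini_model, neither of which is taken from this diff. The returned module is a plain torch.nn.Module, so it can be saved or inspected like any other model.

import torch
import torch.distributed as dist

from colossalai.zero import get_static_torch_model

# `gemini_model` is assumed to be a GeminiDDP instance created earlier in the training script
torch_model = get_static_torch_model(gemini_model, device=torch.device("cpu"), dtype=torch.float32)

# with only_rank_0=True (the default) only rank 0 holds the full weights
if dist.get_rank() == 0:
    torch.save(torch_model.state_dict(), "model.pt")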

View File

@@ -1,3 +1,3 @@
from .low_level_optim import LowLevelZeroOptimizer
__all__ = ['LowLevelZeroOptimizer']
__all__ = ["LowLevelZeroOptimizer"]

View File

@@ -44,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
for partition_id in range(partitions_count):
partition_tensors = partitions[partition_id]
for item in partition_tensors:
tensor_index_mapping[item['index']] = len(new_tensor_list)
new_tensor_list.append(item['tensor'])
tensor_index_mapping[item["index"]] = len(new_tensor_list)
new_tensor_list.append(item["tensor"])
return new_tensor_list, tensor_index_mapping
@@ -107,11 +107,13 @@ def split_by_dtype(tensor_list):
return buckets
def reduce_tensor_dp_group(tensor: torch.Tensor,
dtype: Optional[torch.dtype] = None,
dst_local_rank: Optional[int] = None,
dst_global_rank: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None):
def reduce_tensor_dp_group(
tensor: torch.Tensor,
dtype: Optional[torch.dtype] = None,
dst_local_rank: Optional[int] = None,
dst_global_rank: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None,
):
"""
Reduce the tensor in the data parallel process group
@@ -173,7 +175,7 @@ def has_inf_or_nan(tensor):
raise
return True
else:
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
return True
return False
@@ -184,8 +186,7 @@ def release_param_grad(tensor_list):
def calculate_global_norm_from_list(norm_list):
""" Compute total from a list of norms
"""
"""Compute total from a list of norms"""
total_norm = 0.0
for norm in norm_list:
total_norm += norm**2.0
@@ -221,7 +222,7 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
total_norm = 0.0
for g in gradients:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
total_norm += param_norm.item() ** 2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
@@ -230,9 +231,9 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
total_norm = -1
return total_norm
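A minimal single-process sketch of the norm arithmetic above; the dist.all_reduce calls are what would make the result global across DP/TP ranks, and are omitted here.

import torch


def local_grad_norm(gradients, norm_type: float = 2.0) -> float:
    """Sum each tensor's norm raised to norm_type, then take the norm_type-th root.

    This mirrors compute_norm() without the distributed all_reduce steps.
    """
    total = 0.0
    for g in gradients:
        total += g.data.double().norm(norm_type).item() ** norm_type
    total_norm = total ** (1.0 / norm_type)
    if total_norm in (float("inf"), -float("inf")) or total_norm != total_norm:  # inf / nan guard
        return -1
    return total_norm


grads = [torch.full((4,), 3.0), torch.full((9,), 2.0)]
print(local_grad_norm(grads))  # sqrt(4 * 9 + 9 * 4) = sqrt(72) ≈ 8.485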

View File

@@ -3,4 +3,4 @@ from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]

View File

@@ -3,7 +3,6 @@ from torch.distributed import ProcessGroup
class BaseStore:
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)

View File

@@ -9,7 +9,6 @@ from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
@@ -38,8 +37,7 @@ class BucketStore(BaseStore):
return self._num_elements_in_bucket
def reset_num_elements_in_bucket(self):
"""Set the number of elements in bucket to zero.
"""
"""Set the number of elements in bucket to zero."""
self._num_elements_in_bucket = 0
@@ -54,7 +52,7 @@ class BucketStore(BaseStore):
self._param_list.append(param)
self._padding_size.append(padding_size)
self._num_elements_in_bucket += (param.numel() + padding_size)
self._num_elements_in_bucket += param.numel() + padding_size
self.current_group_id = group_id
# number of tensors in current bucket
@@ -119,8 +117,7 @@ class BucketStore(BaseStore):
return self.grad_to_param_mapping[id(grad)]
def reset(self):
"""Reset the bucket storage after reduction, only release the tensors have been reduced
"""
"""Reset the bucket storage after reduction, only release the tensors have been reduced"""
cur_offset = self.offset_list.pop(0)
self._param_list = self._param_list[cur_offset:]
self._padding_size = self._padding_size[cur_offset:]

View File

@@ -1,13 +1,11 @@
from typing import List
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
"""

View File

@@ -5,7 +5,6 @@ from .base_store import BaseStore
class ParameterStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)

View File

@@ -2,7 +2,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class TensorBucket:
def __init__(self, size):
self._max_size = size
self._current_size = 0
@@ -26,8 +25,7 @@ class TensorBucket:
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
msg = f"The param bucket max size {self._max_size} is exceeded" \
+ f"by tensor (size {tensor_size})"
msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})"
raise RuntimeError(msg)
self._bucket.append(tensor)

View File

@@ -17,6 +17,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
@@ -32,19 +33,21 @@ from .bookkeeping import BucketStore, GradientStore, ParameterStore
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
num_working_param_groups: int,
grad_store: GradientStore,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
def __init__(
self,
num_working_param_groups: int,
grad_store: GradientStore,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
@@ -57,32 +60,31 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
class LowLevelZeroOptimizer(OptimizerWrapper):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
"""Optimizer used for ZeRO-1 and ZeRO-2."""
def __init__(
self,
optimizer: Optimizer,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
self,
optimizer: Optimizer,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
@@ -115,7 +117,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if forced_dtype:
for group in self.optim.param_groups:
group_params = group['params']
group_params = group["params"]
for param in group_params:
param.data = param.data.to(forced_dtype)
self._dtype = forced_dtype
@@ -134,7 +136,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = list()
for param in param_group['params']:
for param in param_group["params"]:
if param.requires_grad:
group_params.append(param)
@@ -148,7 +150,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = master_param_current_rank
param_group["params"] = master_param_current_rank
# initialize communication stream for
# communication-computation overlapping
@@ -164,15 +166,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
@@ -185,17 +189,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups)
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
assert torch.cuda.is_available(), "CUDA is required"
for param_group in self.optim.param_groups:
group_params = param_group['params']
group_params = param_group["params"]
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()
device = "cpu" if self._cpu_offload else get_current_device()
for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@@ -275,8 +280,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
param_id)) < self._world_size:
if (
len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
< self._world_size
):
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
@@ -307,8 +314,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# if full, will reduce the grads already in the bucket
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
group_id != self._bucket_store.current_group_id:
if (
self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
or group_id != self._bucket_store.current_group_id
):
self._run_reduction()
padding_size = self._param_store.get_param_padding_size(param)
@@ -319,8 +328,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
################################
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and no_sync are not compatible"
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -339,8 +349,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.zero_grad()
def backward_by_grad(self, tensor, grad):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
@@ -380,14 +391,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
####################
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
assert closure is None, "closure is not supported by step()"
if not self.require_grad_sync:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self._logger.info(f"Found overflow. Skip step")
self.zero_grad()
return
@@ -428,7 +439,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._grad_store.reset_grads_by_group_id(group_id)
# update the params in the optimizer
self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@@ -445,16 +456,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update working partition updated by the current rank
dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]['params']
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
#############################
# Mixed Precision Utilities #
@@ -466,14 +477,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.mixed_precision_mixin is not None:
div_scale = self.mixed_precision_mixin.get_grad_div_scale()
if self._clip_grad_norm > 0.:
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
div_scale = clip * div_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / div_scale)
grad.data.mul_(1.0 / div_scale)
############################
# Gradient Synchronization #
@@ -518,18 +529,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
packed = {k: v for k, v in group.items() if k != "params"}
param_mappings.update(
{id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
{id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
)
packed["params"] = [param_mappings[id(p)] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in self.optim.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
return {'state': packed_state, 'param_groups': param_groups}
return {"state": packed_state, "param_groups": param_groups}
def state_dict(self) -> Dict:
"""Return a state_dict same with DDP
@@ -541,14 +553,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
@@ -562,16 +575,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
state_dict (dict): A pytorch form state_dict
"""
zero_state_dict = copy.deepcopy(state_dict)
for param_idx, state in zero_state_dict['state'].items():
for param_idx, state in zero_state_dict["state"].items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@@ -588,7 +601,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict()
ret_block_size = 0
local_states = self.optim.state_dict()['state']
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
@@ -601,11 +614,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = self._param_store.master_to_working_param[id(master_param)]
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
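The gather-and-trim pattern used throughout state_dict() and state_shard() can be checked on a single process; this is a toy sketch, and the real code obtains the shards with dist.all_gather over the DP group instead of slicing locally.

import torch

world_size = 4
working_param = torch.arange(10.0).reshape(2, 5)  # the unpadded working parameter

# pretend each rank holds one padded shard of an optimizer state (e.g. exp_avg)
flat = torch.nn.functional.pad(working_param.flatten(), [0, (-working_param.numel()) % world_size])
shards = list(flat.chunk(world_size))  # stand-in for what dist.all_gather would fill in

# stack the gathered shards, drop the padding, and restore the parameter's shape
state_tensor = torch.stack(shards).view(-1)[: working_param.numel()].reshape_as(working_param)
assert torch.equal(state_tensor, working_param)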

View File

@@ -7,10 +7,9 @@ import torch.nn as nn
from .gemini import GeminiDDP
def zero_model_wrapper(model: nn.Module,
zero_stage: int = 1,
gemini_config: Optional[Dict] = None,
verbose: bool = False):
def zero_model_wrapper(
model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False
):
"""This wrapper function is used to wrap your training model for ZeRO DDP.
Example:
@@ -50,19 +49,21 @@ def zero_model_wrapper(model: nn.Module,
return wrapped_model
def zero_optim_wrapper(model: nn.Module,
optimizer: torch.optim.Optimizer,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
optim_config: Optional[Dict] = None,
verbose: bool = False):
def zero_optim_wrapper(
model: nn.Module,
optimizer: torch.optim.Optimizer,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
optim_config: Optional[Dict] = None,
verbose: bool = False,
):
"""This wrapper function is used to wrap your training optimizer for ZeRO DDP.
Args:
@@ -95,20 +96,22 @@ def zero_optim_wrapper(model: nn.Module,
else:
config_dict = copy(optim_config)
config_dict['initial_scale'] = initial_scale
config_dict['growth_factor'] = growth_factor
config_dict['backoff_factor'] = backoff_factor
config_dict['growth_interval'] = growth_interval
config_dict['hysteresis'] = hysteresis
config_dict['min_scale'] = min_scale
config_dict['max_scale'] = max_scale
config_dict["initial_scale"] = initial_scale
config_dict["growth_factor"] = growth_factor
config_dict["backoff_factor"] = backoff_factor
config_dict["growth_interval"] = growth_interval
config_dict["hysteresis"] = hysteresis
config_dict["min_scale"] = min_scale
config_dict["max_scale"] = max_scale
if zero_stage in [1, 2]:
from colossalai.zero.low_level import LowLevelZeroOptimizer
config_dict['partition_grad'] = zero_stage == 2
config_dict['clip_grad_norm'] = max_norm
config_dict["partition_grad"] = zero_stage == 2
config_dict["clip_grad_norm"] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
else:
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
config_dict['clipping_norm'] = max_norm
config_dict["clipping_norm"] = max_norm
return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)
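To tie the two wrappers together, here is a hedged end-to-end sketch; the gemini_config key, the HybridAdam choice, and the prior colossalai launch are assumptions of the usual workflow rather than content of this diff.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

# colossalai.launch_from_torch(config={}) is assumed to have been called already
model = nn.Linear(64, 64).cuda()

# zero_stage=3 selects the Gemini path; the gemini_config key must be a valid GeminiDDP kwarg (assumed)
model = zero_model_wrapper(model, zero_stage=3, gemini_config=dict(placement_policy="auto"))

optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**16, max_norm=1.0)

out = model(torch.randn(4, 64, device="cuda"))
optimizer.backward(out.sum())  # ZeRO wrappers expose backward() so they can apply loss scaling
optimizer.step()
optimizer.zero_grad()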