Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-29 21:49:54 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -10,6 +10,13 @@ from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper

__all__ = [
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
"GeminiDDP",
"GeminiOptimizer",
"GeminiAdamOptimizer",
"zero_model_wrapper",
"zero_optim_wrapper",
"LowLevelZeroOptimizer",
"ColoInitContext",
"post_process_colo_init_ctx",
"get_static_torch_model",
]

@@ -6,6 +6,15 @@ from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model

__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
"GeminiManager",
"TensorInfo",
"TensorState",
"ChunkManager",
"search_chunk_configuration",
"GeminiDDP",
"get_static_torch_model",
"GeminiAdamOptimizer",
"GeminiOptimizer",
"ColoInitContext",
"post_process_colo_init_ctx",
]

@@ -3,4 +3,4 @@ from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager

__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"]

@@ -17,12 +17,17 @@ class TensorState(Enum):
READY_FOR_REDUCE = 4


STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
STATE_TRANS = (
(TensorState.FREE, TensorState.HOLD),
(TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE),
(TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
(TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)


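(Aside, not part of the diff: STATE_TRANS above is the whitelist of legal tensor-state transitions for a chunk. A minimal sketch of how such a table is typically consulted follows; the helper name is hypothetical and assumes the module's TensorState/STATE_TRANS are in scope.)

def is_valid_transition(cur: TensorState, nxt: TensorState) -> bool:
    # Hypothetical helper, for illustration only: a transition is legal
    # exactly when the (current, next) pair appears in STATE_TRANS.
    return (cur, nxt) in STATE_TRANS
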
@dataclass

@@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class Chunk:
_total_number = 0

def __init__(self,
chunk_size: int,
process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
def __init__(
self,
chunk_size: int,
process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Here we use all-gather operation to gather the whole chunk.
@@ -99,9 +106,9 @@ class Chunk:
device = init_device or get_current_device()

# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero

self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA
self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA

# cuda local chunk, which is sharded on GPUs
self.cuda_shard = None
@@ -134,7 +141,7 @@ class Chunk:
# they are treated the same as that of the parameters in DDP during training.
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
pin_memory = False # since this chunk is gathered, it doesn't need to pin

# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
@@ -160,7 +167,7 @@ class Chunk:

if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == 'cuda':
if self.chunk_temp.device.type == "cuda":
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
@@ -180,11 +187,11 @@ class Chunk:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return 'cuda'
return "cuda"
elif self.cuda_shard is not None:
return 'cuda'
return "cuda"
else:
return 'cpu'
return "cpu"

@property
def payload(self) -> torch.Tensor:
@@ -217,8 +224,10 @@ class Chunk:
if self.keep_gathered:
return False
else:
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
return (
self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD]
== self.num_tensors
)

@property
def can_reduce(self):
@@ -226,27 +235,25 @@ class Chunk:

@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values on CUDA.
"""
"""Check if the chunk has inf or nan values on CUDA."""
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[: self.utilized_size]
else:
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[: self.valid_end]

return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

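(Aside, not part of the diff: the check above ORs two boolean scalars produced by `.item()`. Assuming the same `valid_tensor`, an equivalent single-pass formulation using `torch.isfinite` is sketched below purely for illustration.)

import torch

def has_inf_or_nan(valid_tensor: torch.Tensor) -> bool:
    # Illustrative equivalent of the check above: a tensor contains inf or
    # nan exactly when not all of its entries are finite.
    return not torch.isfinite(valid_tensor).all().item()
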
def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
"""Record l2 norm of this chunks on CUDA."""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[: self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[: self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
self.l2_norm = chunk_l2_norm.item() ** 2

def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
@@ -263,9 +270,9 @@ class Chunk:
if new_utilized_size > self.chunk_size:
raise ChunkFullError

self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten())
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape)

# record all the information about the tensor
self.num_tensors += 1
@@ -275,8 +282,7 @@ class Chunk:
self.utilized_size = new_utilized_size

def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
"""
"""Close the chunk. Any tensor can't be appended to a closed chunk later."""
# sanity check
assert self.chunk_temp is not None

@@ -286,7 +292,7 @@ class Chunk:
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin

if self.chunk_temp.device.type == 'cpu':
if self.chunk_temp.device.type == "cpu":
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
@@ -298,12 +304,12 @@ class Chunk:
if self.keep_gathered:
return

if self.pin_memory or self.shard_device.type == 'cpu':
if self.pin_memory or self.shard_device.type == "cpu":
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
self.cpu_vis_flag = True # cpu_shard has been visited

if self.shard_device.type == 'cpu':
if self.shard_device.type == "cpu":
self.cuda_shard = None

def shard_move(self, device: torch.device, force_copy: bool = False):
@@ -318,12 +324,12 @@ class Chunk:
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
assert device.type == "cuda", "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return

if device.type == 'cuda':
if device.type == "cuda":
assert device == get_current_device(), "can't move chunk to another device"

if self.cuda_shard:
@@ -333,7 +339,7 @@ class Chunk:

if not self.pin_memory:
self.cpu_shard = None
elif device.type == 'cpu':
elif device.type == "cpu":
if self.cuda_shard is None:
return

@@ -350,8 +356,7 @@ class Chunk:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
|
||||
"""
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
@@ -360,8 +365,7 @@ class Chunk:
|
||||
self.__update_tensors_ptr()
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk. It's an operation done in CUDA.
|
||||
"""
|
||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
@@ -369,8 +373,7 @@ class Chunk:
|
||||
self.__scatter()
|
||||
|
||||
def reduce(self):
|
||||
"""Reduce scatter all the gradients. It's an operation done in CUDA.
|
||||
"""
|
||||
"""Reduce scatter all the gradients. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
||||
@@ -423,20 +426,18 @@ class Chunk:
|
||||
assert self.is_gathered
|
||||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload.
|
||||
"""
|
||||
"""Get the valid length of the chunk's payload."""
|
||||
if self.keep_gathered:
|
||||
return self.utilized_size
|
||||
else:
|
||||
return self.valid_end
|
||||
|
||||
def init_pair(self, friend_chunk: 'Chunk') -> None:
|
||||
"""Initialize the paired chunk.
|
||||
"""
|
||||
def init_pair(self, friend_chunk: "Chunk") -> None:
|
||||
"""Initialize the paired chunk."""
|
||||
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
|
||||
self.paired_chunk = friend_chunk
|
||||
friend_chunk.paired_chunk = self
|
||||
@@ -445,8 +446,7 @@ class Chunk:
|
||||
assert friend_chunk.paired_chunk is self
|
||||
|
||||
def optim_update(self) -> None:
|
||||
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
|
||||
"""
|
||||
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer."""
|
||||
# sanity check
|
||||
assert self.paired_chunk is not None
|
||||
|
||||
@@ -455,15 +455,15 @@ class Chunk:
|
||||
assert friend_chunk.is_gathered is True
|
||||
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
|
||||
self.optim_sync_flag = True
|
||||
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
|
||||
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
|
||||
self.cuda_shard.copy_(friend_chunk.cuda_shard)
|
||||
self.optim_sync_flag = True
|
||||
self.cpu_vis_flag = False
|
||||
else:
|
||||
# optim_sync_flag is set to False
|
||||
# see shard_move function for more details
|
||||
assert friend_chunk.device_type == 'cpu'
|
||||
assert self.device_type == 'cpu'
|
||||
assert friend_chunk.device_type == "cpu"
|
||||
assert self.device_type == "cpu"
|
||||
self.optim_sync_flag = False
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
@@ -492,7 +492,7 @@ class Chunk:
|
||||
|
||||
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
|
||||
|
||||
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
|
||||
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end])
|
||||
|
||||
free_storage(self.cuda_global_chunk)
|
||||
self.is_gathered = False
|
||||
@@ -518,7 +518,7 @@ class Chunk:
|
||||
assert type(self.cuda_global_chunk) == torch.Tensor
|
||||
|
||||
for tensor, tensor_info in self.tensors_info.items():
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
|
||||
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
|
||||
self.tensor_state_cnter[tensor_info.state] -= 1
|
||||
@@ -539,38 +539,41 @@ class Chunk:
|
||||
def __repr__(self, detailed: bool = True):
|
||||
output = [
|
||||
"Chunk Information:\n",
|
||||
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
|
||||
self.pg_size),
|
||||
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
|
||||
self.chunk_size, self.dtype, self.pg_size
|
||||
),
|
||||
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
|
||||
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
|
||||
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size
|
||||
),
|
||||
]
|
||||
|
||||
def print_tensor(tensor, prefix=''):
|
||||
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
|
||||
tensor.device))
|
||||
def print_tensor(tensor, prefix=""):
|
||||
output.append(
|
||||
"{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device)
|
||||
)
|
||||
|
||||
if self.chunk_temp is not None:
|
||||
output.append("\tchunk temp:\n")
|
||||
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
|
||||
print_tensor(tensor=self.chunk_temp, prefix="\t\t")
|
||||
|
||||
if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
|
||||
output.append("\tchunk total:\n")
|
||||
print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
|
||||
print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t")
|
||||
|
||||
if self.cuda_shard is not None:
|
||||
output.append("\tcuda shard:\n")
|
||||
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
|
||||
print_tensor(tensor=self.cuda_shard, prefix="\t\t")
|
||||
|
||||
if self.cpu_shard is not None:
|
||||
output.append("\tcpu shard:\n")
|
||||
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
|
||||
print_tensor(tensor=self.cpu_shard, prefix="\t\t")
|
||||
|
||||
memory_info = self.memory_usage
|
||||
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
|
||||
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"]))
|
||||
|
||||
if detailed:
|
||||
output.append("\ttensor state monitor:\n")
|
||||
for st in TensorState:
|
||||
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
|
||||
|
||||
return ''.join(output)
|
||||
return "".join(output)
|
||||
|
||||
@@ -20,27 +20,28 @@ class ChunkManager:
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
|
||||
|
||||
self.device = init_device or get_current_device()
|
||||
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
|
||||
self.kwargs_config = chunk_configuration
|
||||
for k, v in self.kwargs_config.items():
|
||||
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
|
||||
v["init_device"] = self.device
|
||||
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.accessed_mem: int = 0
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
|
||||
|
||||
def register_tensor(self,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
def register_tensor(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Register a tensor to the chunk manager.
|
||||
Then, the tensor should be accessed by `get_chunks`.
|
||||
@@ -94,25 +95,22 @@ class ChunkManager:
|
||||
self.tensor_chunk_map[tensor] = chunk_group[-1]
|
||||
|
||||
def close_all_groups(self):
|
||||
"""Close all the chunks of all groups.
|
||||
"""
|
||||
"""Close all the chunks of all groups."""
|
||||
for group_name in self.chunk_groups:
|
||||
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
"""Make the chunk can be used for calculation.
|
||||
"""
|
||||
"""Make the chunk can be used for calculation."""
|
||||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
if chunk.device_type == 'cpu':
|
||||
if chunk.device_type == "cpu":
|
||||
chunk.shard_move(get_current_device())
|
||||
self.__add_accessed_chunk(chunk)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""Scatter the chunk in CUDA.
|
||||
"""
|
||||
"""Scatter the chunk in CUDA."""
|
||||
if chunk not in self.accessed_chunks:
|
||||
return
|
||||
if chunk.can_release:
|
||||
@@ -121,8 +119,7 @@ class ChunkManager:
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
|
||||
"""Move the shard of the chunk to the target device.
|
||||
"""
|
||||
"""Move the shard of the chunk to the target device."""
|
||||
if not chunk.can_move or chunk.device_type == device.type:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
@@ -130,14 +127,12 @@ class ChunkManager:
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
"""Transit tensor state according to pre-defined state machine.
|
||||
"""
|
||||
"""Transit tensor state according to pre-defined state machine."""
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.tensor_trans_state(tensor, state)
|
||||
|
||||
def reduce_chunk(self, chunk: Chunk) -> bool:
|
||||
"""Reduce or all reduce the chunk.
|
||||
"""
|
||||
"""Reduce or all reduce the chunk."""
|
||||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
@@ -213,18 +208,17 @@ class ChunkManager:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = [
|
||||
'Chunk Manager Information:\n',
|
||||
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
|
||||
"Chunk Manager Information:\n",
|
||||
"Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n",
|
||||
]
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg.append(f'Group {group_name}:\n')
|
||||
msg.append(f"Group {group_name}:\n")
|
||||
for i, chunk in enumerate(group):
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
msg.append(f"[{i}] {chunk}\n")
|
||||
return "".join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
"""Register a chunk group."""
|
||||
if group_name not in self.chunk_groups:
|
||||
self.chunk_groups[group_name] = deque()
|
||||
return self.chunk_groups[group_name]
|
||||
|
||||
@@ -76,8 +76,9 @@ def _tensor_numel(local_param: ColoParameter) -> int:
|
||||
return local_param.numel()
|
||||
|
||||
|
||||
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
||||
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
|
||||
def classify_params_by_dp_degree(
|
||||
param_order: OrderedParamGenerator, process_group: ProcessGroup
|
||||
) -> Dict[int, List[ColoParameter]]:
|
||||
"""classify_params_by_dp_degree
|
||||
|
||||
Classify the parameters by their dp degree
|
||||
@@ -105,14 +106,15 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
||||
|
||||
|
||||
def search_chunk_configuration(
|
||||
model: nn.Module,
|
||||
search_range_m: float,
|
||||
search_interval: int, # hidden size is the best value for the interval
|
||||
min_chunk_size_m: float = 32,
|
||||
filter_exlarge_params: bool = True,
|
||||
strict_ddp_flag: bool = False,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
|
||||
model: nn.Module,
|
||||
search_range_m: float,
|
||||
search_interval: int, # hidden size is the best value for the interval
|
||||
min_chunk_size_m: float = 32,
|
||||
filter_exlarge_params: bool = True,
|
||||
strict_ddp_flag: bool = False,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstas: Optional[MemStats] = None,
|
||||
) -> Tuple[Dict, int, int]:
|
||||
"""search_chunk_configuration
|
||||
|
||||
Search the chunk configuration for a model.
|
||||
@@ -168,7 +170,7 @@ def search_chunk_configuration(
|
||||
max_size = max(max_size, max(size_dict[key]))
|
||||
start_size = int(math.ceil(max_size / search_interval) * search_interval)
|
||||
|
||||
min_chunk_waste = float('+inf')
|
||||
min_chunk_waste = float("+inf")
|
||||
best_chunk_size = start_size
|
||||
|
||||
for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
|
||||
|
||||
@@ -5,8 +5,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
from .manager import ChunkManager
|
||||
from .search_utils import search_chunk_configuration
|
||||
|
||||
@@ -17,15 +15,17 @@ def safe_div(a, b):
|
||||
return a / b
|
||||
|
||||
|
||||
def init_chunk_manager(model: nn.Module,
|
||||
init_device: Optional[torch.device] = None,
|
||||
hidden_dim: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs) -> ChunkManager:
|
||||
def init_chunk_manager(
|
||||
model: nn.Module,
|
||||
init_device: Optional[torch.device] = None,
|
||||
hidden_dim: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
) -> ChunkManager:
|
||||
if hidden_dim:
|
||||
search_interval = hidden_dim
|
||||
else:
|
||||
search_interval = 1024 # defaults to 1024
|
||||
search_interval = 1024 # defaults to 1024
|
||||
kwargs["search_interval"] = search_interval
|
||||
|
||||
dist.barrier()
|
||||
@@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module,
|
||||
wasted_size /= mega_unit
|
||||
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep='',
|
||||
flush=True)
|
||||
print(
|
||||
"searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep="",
|
||||
flush=True,
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
chunk_manager = ChunkManager(config_dict, init_device)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Iterator, Optional, Tuple, Union
|
||||
from typing import Any, Iterator, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -12,7 +12,7 @@ from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
|
||||
def _named_params_with_replica(
|
||||
module: nn.Module,
|
||||
prefix: str = '',
|
||||
prefix: str = "",
|
||||
recurse: bool = True,
|
||||
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
|
||||
modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
|
||||
@@ -21,16 +21,17 @@ def _named_params_with_replica(
|
||||
for name, val in mod._parameters.items():
|
||||
if val is None:
|
||||
continue
|
||||
name = mod_prefix + ('.' if mod_prefix else '') + name
|
||||
name = mod_prefix + ("." if mod_prefix else "") + name
|
||||
yield name, val
|
||||
|
||||
|
||||
def _convert_to_coloparam(param: torch.nn.Parameter,
|
||||
device: torch.device,
|
||||
dtype=torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec: Optional[Any] = None) -> ColoParameter:
|
||||
|
||||
def _convert_to_coloparam(
|
||||
param: torch.nn.Parameter,
|
||||
device: torch.device,
|
||||
dtype=torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec: Optional[Any] = None,
|
||||
) -> ColoParameter:
|
||||
if type(param) is ColoParameter:
|
||||
return param
|
||||
# detaching tensor is necessary for optimizers.
|
||||
@@ -66,12 +67,13 @@ def ColoModulize(module):
|
||||
|
||||
|
||||
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
def __init__(self,
|
||||
device: torch.device = torch.device('cpu'),
|
||||
dtype: torch.dtype = torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec=None):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
|
||||
@@ -89,6 +91,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
def _register_colo_modules(self):
|
||||
from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
|
||||
|
||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||
register_colo_module(torch.nn.Embedding, ColoEmbedding())
|
||||
|
||||
@@ -105,25 +108,25 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
if type(param) is ColoParameter:
|
||||
continue
|
||||
|
||||
split = name.rfind('.')
|
||||
if split >= 0: # param in submodule
|
||||
split = name.rfind(".")
|
||||
if split >= 0: # param in submodule
|
||||
module_name = name[:split]
|
||||
param_name = name[split + 1:]
|
||||
param_name = name[split + 1 :]
|
||||
else:
|
||||
module_name = '' # param in current module
|
||||
module_name = "" # param in current module
|
||||
param_name = name
|
||||
name_list.append((module_name, param_name))
|
||||
|
||||
replaced_tensors = dict(
|
||||
) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
|
||||
replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
|
||||
for module_name, param_name in name_list:
|
||||
submodule = module.get_submodule(module_name)
|
||||
param = submodule.get_parameter(param_name)
|
||||
if param in replaced_tensors:
|
||||
colo_param = replaced_tensors[param]
|
||||
else:
|
||||
colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg,
|
||||
self._default_dist_spec)
|
||||
colo_param = _convert_to_coloparam(
|
||||
param, self._device, self._dtype, self._default_pg, self._default_dist_spec
|
||||
)
|
||||
replaced_tensors[param] = colo_param
|
||||
delattr(submodule, param_name)
|
||||
setattr(submodule, param_name, colo_param)
|
||||
@@ -136,11 +139,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
for param in module.parameters():
|
||||
param_number += 1
|
||||
meta_param_number += (param.device.type == 'meta')
|
||||
meta_param_number += param.device.type == "meta"
|
||||
|
||||
for buffer in module.buffers():
|
||||
buffer_number += 1
|
||||
meta_buffer_number += (buffer.device.type == 'meta')
|
||||
meta_buffer_number += buffer.device.type == "meta"
|
||||
|
||||
if meta_param_number > 0 and meta_param_number != param_number:
|
||||
raise ValueError("Meta parameters and valued parameters can not be in the same model")
|
||||
@@ -152,11 +155,13 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
buffer.data = buffer.data.to(device=self._device)
|
||||
|
||||
|
||||
def post_process_colo_init_ctx(model: torch.nn.Module,
|
||||
device: torch.device = torch.device('cpu'),
|
||||
dtype: torch.dtype = torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec=None):
|
||||
def post_process_colo_init_ctx(
|
||||
model: torch.nn.Module,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float,
|
||||
default_pg: Optional[ProcessGroup] = None,
|
||||
default_dist_spec=None,
|
||||
):
|
||||
"""post_process_colo_init_ctx
|
||||
|
||||
This function is called after `ColoInitContext`.
|
||||
@@ -178,8 +183,8 @@ def post_process_colo_init_ctx(model: torch.nn.Module,
|
||||
# print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
|
||||
torch_params.append((n, p))
|
||||
|
||||
for (n, param) in torch_params:
|
||||
name_list = n.split('.')
|
||||
for n, param in torch_params:
|
||||
name_list = n.split(".")
|
||||
module = model
|
||||
for i in range(len(name_list) - 1):
|
||||
module = module._modules[name_list[i]]
|
||||
|
||||
@@ -10,7 +10,7 @@ import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.logging import get_dist_logger
|
||||
@@ -27,10 +27,10 @@ from .utils import get_temp_total_chunk_on_cuda
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
__all__ = [
|
||||
'GeminiDDP',
|
||||
"GeminiDDP",
|
||||
]
|
||||
|
||||
|
||||
@@ -54,27 +54,28 @@ class GeminiDDP(ModelWrapper):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: torch.device = torch.device('cpu'),
|
||||
placement_policy: str = "static",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
search_range_m: int = 32, # chunk search options
|
||||
hidden_dim: Optional[int] = None, # chunk search options
|
||||
min_chunk_size_m: float = 32, # chunk search options
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None, # genimi memory stats
|
||||
verbose: bool = False) -> None:
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: torch.device = torch.device("cpu"),
|
||||
placement_policy: str = "static",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
search_range_m: int = 32, # chunk search options
|
||||
hidden_dim: Optional[int] = None, # chunk search options
|
||||
min_chunk_size_m: float = 32, # chunk search options
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None, # genimi memory stats
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
if chunk_config_dict is not None:
|
||||
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
|
||||
@@ -82,22 +83,26 @@ class GeminiDDP(ModelWrapper):
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_m is None:
|
||||
search_range_m = 32
|
||||
self.chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=chunk_init_device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
process_group=process_group,
|
||||
verbose=verbose)
|
||||
self.gemini_manager = GeminiManager(placement_policy,
|
||||
self.chunk_manager,
|
||||
memstats,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
|
||||
self.chunk_manager = init_chunk_manager(
|
||||
model=module,
|
||||
init_device=chunk_init_device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
process_group=process_group,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.gemini_manager = GeminiManager(
|
||||
placement_policy,
|
||||
self.chunk_manager,
|
||||
memstats,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||
)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
self.fp32_params: List[torch.Tensor] = list()
|
||||
@@ -126,13 +131,15 @@ class GeminiDDP(ModelWrapper):
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
param_name = m_name + "." + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
self._init_chunks(
|
||||
param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != "cuda",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
super().__init__(module)
|
||||
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
|
||||
self._cast_buffers()
|
||||
@@ -146,19 +153,18 @@ class GeminiDDP(ModelWrapper):
|
||||
def parameters(self, recurse: bool = True):
|
||||
return self.module.parameters(recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||
def named_parameters(self, prefix: str = "", recurse: bool = True):
|
||||
return self.module.named_parameters(prefix, recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True):
|
||||
def named_buffers(self, prefix: str = "", recurse: bool = True):
|
||||
return self.module.named_buffers(prefix, recurse)
|
||||
|
||||
def named_children(self):
|
||||
return self.module.named_children()
|
||||
|
||||
def named_modules(self,
|
||||
memo: Optional[Set[torch.nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
def named_modules(
|
||||
self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
|
||||
):
|
||||
return self.module.named_modules(memo, prefix, remove_duplicate)
|
||||
|
||||
@staticmethod
|
||||
@@ -184,11 +190,9 @@ class GeminiDDP(ModelWrapper):
|
||||
# as save/load state dict is overwrited, only return self
|
||||
return self
|
||||
|
||||
def _get_non_persistent_buffers_set(self,
|
||||
module,
|
||||
memo: Optional[Set[nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
def _get_non_persistent_buffers_set(
|
||||
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
memo: a memo to store the set of modules already added to the result
|
||||
@@ -204,19 +208,20 @@ class GeminiDDP(ModelWrapper):
|
||||
if remove_duplicate:
|
||||
memo.add(module)
|
||||
self_non_persistent_set = set(
|
||||
map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
|
||||
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
|
||||
)
|
||||
for name, sub_module in module._modules.items():
|
||||
if sub_module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
|
||||
remove_duplicate)
|
||||
submodule_prefix = prefix + ("." if prefix else "") + name
|
||||
child_non_persistent_set = self._get_non_persistent_buffers_set(
|
||||
sub_module, memo, submodule_prefix, remove_duplicate
|
||||
)
|
||||
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
|
||||
return self_non_persistent_set
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
"""This function is only triggered for inference."""
|
||||
access_list = list(self.chunk_manager.accessed_chunks)
|
||||
# we need to scatter all accessed chunks and move them to their original places
|
||||
for chunk in access_list:
|
||||
@@ -233,7 +238,8 @@ class GeminiDDP(ModelWrapper):
|
||||
# check whether we are in a inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
assert (
|
||||
not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup()
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
|
||||
@@ -250,8 +256,7 @@ class GeminiDDP(ModelWrapper):
|
||||
return outputs
|
||||
|
||||
def _inference_forward(self, *args, **kwargs):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
"""This function is only triggered for inference."""
|
||||
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
|
||||
if not self.scatter_after_inference:
|
||||
# gather all chunks
|
||||
@@ -287,12 +292,14 @@ class GeminiDDP(ModelWrapper):
|
||||
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
|
||||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
|
||||
f"{error_str}")
|
||||
raise RuntimeError(
|
||||
"ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
|
||||
f"{error_str}",
|
||||
)
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
|
||||
)
|
||||
self.gemini_manager.post_iter()
|
||||
|
||||
@@ -314,8 +321,10 @@ class GeminiDDP(ModelWrapper):
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
||||
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter.")
|
||||
raise RuntimeError(
|
||||
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter."
|
||||
)
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
@@ -339,12 +348,9 @@ class GeminiDDP(ModelWrapper):
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self,
|
||||
destination=None,
|
||||
prefix='',
|
||||
keep_vars=False,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
def state_dict(
|
||||
self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
@@ -391,7 +397,7 @@ class GeminiDDP(ModelWrapper):
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in chunk_to_save_data
|
||||
chunk_to_save_data[tensor] = record_tensor
|
||||
@@ -399,8 +405,9 @@ class GeminiDDP(ModelWrapper):
|
||||
del temp_chunk
|
||||
return chunk_to_save_data
|
||||
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
|
||||
dtype: torch.dtype) -> Dict:
|
||||
def _get_param_to_save_data(
|
||||
self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
|
||||
) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
@@ -459,11 +466,13 @@ class GeminiDDP(ModelWrapper):
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
# save extra states
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants. If :attr:`strict` is ``True``, then
|
||||
the keys of :attr:`state_dict` must exactly match the keys returned
|
||||
@@ -491,32 +500,38 @@ class GeminiDDP(ModelWrapper):
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
# mypy isn't aware that "_metadata" exists in state_dict
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
prefix = ''
|
||||
prefix = ""
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys)))
|
||||
0,
|
||||
"Unexpected key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in unexpected_keys)
|
||||
),
|
||||
)
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
|
||||
)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
@@ -564,19 +579,21 @@ class GeminiDDP(ModelWrapper):
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
error_msgs.append(
|
||||
"size mismatch for {}: copying a param with shape {} from checkpoint, "
|
||||
"the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape)
|
||||
)
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args)
|
||||
)
|
||||
elif strict:
|
||||
missing_keys.append(state_key)
|
||||
|
||||
@@ -600,15 +617,15 @@ class GeminiDDP(ModelWrapper):
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
@@ -622,8 +639,10 @@ class GeminiDDP(ModelWrapper):
|
||||
load(name, buf, buf.copy_)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state",
|
||||
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
|
||||
if (
|
||||
getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state)
|
||||
is not torch.nn.Module.set_extra_state
|
||||
):
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
@@ -634,7 +653,7 @@ class GeminiDDP(ModelWrapper):
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
input_name = key[len(prefix) :]
|
||||
if input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
@@ -659,18 +678,22 @@ class GeminiDDP(ModelWrapper):
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=p,
|
||||
group_type="fp16_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=fp32_p,
|
||||
group_type="fp32_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
@@ -694,7 +717,7 @@ class GeminiDDP(ModelWrapper):
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.to(self.mixed_precision)
|
||||
|
||||
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
|
||||
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None:
|
||||
"""Convert parameter to ColoParameter in-place.
|
||||
Args:
|
||||
p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
|
||||
@@ -709,12 +732,14 @@ class GeminiDDP(ModelWrapper):
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=requires_grad)
|
||||
|
||||
def state_dict_shard(self,
|
||||
prefix: str = '',
|
||||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def state_dict_shard(
|
||||
self,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
@@ -770,8 +795,10 @@ class GeminiDDP(ModelWrapper):
|
||||
yield block, block_size
|
||||
# save extra states
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = self.get_extra_state()
|
||||
block, block_size = sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
|
||||
@@ -17,7 +17,6 @@ class TrainingPhase(Enum):
|
||||
|
||||
|
||||
class GeminiZeROHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
self._gemini_manager = gemini_manager
|
||||
@@ -40,7 +39,11 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||
def post_op(self, params):
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
for p in params:
|
||||
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
|
||||
tensor_state = (
|
||||
TensorState.HOLD
|
||||
if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad
|
||||
else TensorState.HOLD_AFTER_BWD
|
||||
)
|
||||
self._chunk_manager.trans_tensor_state(p, tensor_state)
|
||||
|
||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
||||
|
||||
@@ -26,12 +26,13 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""

def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs) -> None:

def __init__(
self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs,
) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
@@ -39,8 +40,9 @@ class GeminiManager:

self._premade_memstats_ = memstats is not None
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._mem_stats_collector = (
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
@@ -62,7 +64,7 @@ class GeminiManager:

@property
def need_warmup(self) -> bool:
return self.policy_name in ('auto', 'const')
return self.policy_name in ("auto", "const")

def is_warmup(self):
return self._warmup
@@ -85,15 +87,14 @@ class GeminiManager:
self._mem_stats_collector.start_collection()

def post_iter(self):
"""This function must be called when each iteration finishes
"""
"""This function must be called when each iteration finishes"""
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
self.reset_attributes()

def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
@@ -102,11 +103,13 @@ class GeminiManager:
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start

vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
vol, evict_time = self._placement_policy.evict_tensors(
can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx,
)

self._d2h_volume += vol
self._evict_time += evict_time
@@ -118,12 +121,12 @@ class GeminiManager:
start = time()
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cuda':
if chunk.device_type == "cuda":
if chunk.is_gathered:
pass
else:
cuda_demand += chunk.chunk_mem - chunk.shard_mem
elif chunk.device_type == 'cpu':
elif chunk.device_type == "cpu":
cuda_demand += chunk.chunk_mem
else:
raise RuntimeError
@@ -159,6 +162,7 @@ class GeminiManager:
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
self._placement_policy.setup_grads_device(params, grads_device_map)
@@ -10,34 +10,35 @@ from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored

from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP

__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]

_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}


class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

def __init__(self,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
def __init__(
self,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.module = module

def check_local_overflow(self) -> bool:
@@ -77,25 +78,28 @@ class GeminiOptimizer(OptimizerWrapper):
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
"""

def __init__(self,
optim: Optimizer,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any):
def __init__(
self,
optim: Optimizer,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any,
):
super().__init__(optim)
assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
f"{_AVAIL_OPTIM_LIST}"
assert type(optim) in _AVAIL_OPTIM_LIST, (
"You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}"
)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -118,8 +122,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
for name, param in module.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
if param.requires_grad:
|
||||
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
|
||||
"You should handle its optimizer update by yourself!")
|
||||
warnings.warn(
|
||||
f"Parameter `{name}` is ignored by DDP but requires gradient! "
|
||||
"You should handle its optimizer update by yourself!"
|
||||
)
|
||||
else:
|
||||
ddp_param_list.append(param)
|
||||
|
||||
@@ -132,14 +138,16 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
self.__init__optimizer()
|
||||
|
||||
if module.mixed_precision is torch.float16:
|
||||
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(
|
||||
module,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
)
|
||||
elif module.mixed_precision is torch.bfloat16:
|
||||
self.mix_precision_mixin = BF16MixedPrecisionMixin()
|
||||
else:
|
||||
@@ -148,12 +156,15 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0"
|
||||
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
|
||||
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
|
||||
# and it must set `num_fp32_shards_per_param` correctly
|
||||
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
|
||||
optim, 'num_fp32_shards_per_param', 0) >= 2
|
||||
self._should_move_fp32_params_h2d: bool = (
|
||||
self.gemini_manager.is_cuda_margin_mem_avail
|
||||
and self.gpu_margin_mem_ratio > 0.0
|
||||
and getattr(optim, "num_fp32_shards_per_param", 0) >= 2
|
||||
)
|
||||
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
|
||||
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
|
||||
|
||||
@@ -161,7 +172,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
@@ -173,7 +184,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
def _update_fp16_params(self):
|
||||
none_tensor = torch.empty([0])
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
for fake_param in group["params"]:
|
||||
assert fake_param.grad is None
|
||||
fake_param.data = none_tensor.to(fake_param.device)
|
||||
|
||||
@@ -198,7 +209,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
group_to_norm[c16.torch_pg] = 0.0
|
||||
group_to_norm[c16.torch_pg] += c16.l2_norm
|
||||
|
||||
c16.l2_norm = None # clear l2 norm
|
||||
c16.l2_norm = None # clear l2 norm
|
||||
|
||||
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
for group, part_norm in group_to_norm.items():
|
||||
@@ -230,9 +241,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
if self.mix_precision_mixin.should_skip_step():
|
||||
if self.verbose:
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self._clear_global_norm() # clear recorded norm
|
||||
self.zero_grad() # reset all gradients
|
||||
self._logger.info(f"Found overflow. Skip step")
|
||||
self._clear_global_norm() # clear recorded norm
|
||||
self.zero_grad() # reset all gradients
|
||||
self._update_fp16_params()
|
||||
return
|
||||
|
||||
@@ -269,11 +280,11 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
fp32_params_used_cuda_margin_mem = 0
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
if chunk32.device_type == 'cuda':
|
||||
if chunk32.device_type == "cuda":
|
||||
continue
|
||||
|
||||
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
||||
@@ -284,9 +295,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
if chunk32.device_type == 'cuda':
|
||||
if chunk32.device_type == "cuda":
|
||||
state = self.optim.state[fake_param]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
@@ -294,14 +305,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
def _register_states_(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
state = self.optim.state[p]
|
||||
for val in state.values():
|
||||
if isinstance(val, torch.Tensor):
|
||||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
||||
def __init__optimizer(self):
|
||||
|
||||
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
||||
param_info = local_chunk.tensors_info[local_param]
|
||||
if local_chunk.keep_gathered:
|
||||
@@ -313,10 +323,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
param_id = -1
|
||||
for group in self.optim.param_groups:
|
||||
fake_params_list = list()
|
||||
group_backup = {k: v for k, v in group.items() if k != 'params'}
|
||||
group_backup = {k: v for k, v in group.items() if k != "params"}
|
||||
group_ids = []
|
||||
for param in group['params']:
|
||||
|
||||
for param in group["params"]:
|
||||
# Record the mapping of id to current param.
|
||||
param_id += 1
|
||||
self.id_to_real_params[param_id] = param
|
||||
@@ -337,12 +346,12 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
fake_params_list.append(fake_param)
|
||||
|
||||
# Update self.optim.param_groups as well as backup group.
|
||||
group['params'] = fake_params_list
|
||||
group_backup['params'] = group_ids
|
||||
group["params"] = fake_params_list
|
||||
group_backup["params"] = group_ids
|
||||
self.param_groups_backup.append(group_backup)
|
||||
|
||||
def get_offsets(self, param_id: int) -> tuple:
|
||||
'''
|
||||
"""
|
||||
Args:
|
||||
param_id(int): The id of parameter.
|
||||
|
||||
@@ -351,7 +360,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
shard_offset(int): Offset of its optimizer state shard
|
||||
relative to the whole optimizer state.
|
||||
shard_size(int): Length of parameter shard owned by current process.
|
||||
'''
|
||||
"""
|
||||
|
||||
if param_id not in self.id_to_fake_params:
|
||||
return -1, -1, -1
|
||||
@@ -425,11 +434,11 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
if is_collector:
|
||||
states = self.optim.state[fake_param]
|
||||
for state_name in state_names:
|
||||
if state_name == 'step':
|
||||
if state_name == "step":
|
||||
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
|
||||
collected_states[state_name] = torch.tensor(states['step'],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False).cpu()
|
||||
collected_states[state_name] = torch.tensor(
|
||||
states["step"], dtype=torch.float32, requires_grad=False
|
||||
).cpu()
|
||||
else:
|
||||
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||
@@ -441,12 +450,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
# Collector gets prepared for state collecting.
|
||||
if is_collector:
|
||||
for state_name in state_names:
|
||||
if state_name == 'step':
|
||||
if state_name == "step":
|
||||
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
|
||||
collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
|
||||
else:
|
||||
collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
|
||||
requires_grad=False).cpu()
|
||||
collected_states[state_name] = torch.zeros(
|
||||
param.numel(), dtype=torch.float32, requires_grad=False
|
||||
).cpu()
|
||||
|
||||
# Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
|
||||
compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
|
||||
@@ -465,8 +475,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
shard_size = state_shard[2]
|
||||
if compacted_states is None:
|
||||
continue
|
||||
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
|
||||
shard_size)
|
||||
self.load_from_compacted_states(
|
||||
compacted_states, collected_states, state_names, shard_offset, shard_size
|
||||
)
|
||||
|
||||
# Reshape tensors
|
||||
if is_collector:
|
||||
@@ -476,14 +487,16 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
return collected_states
|
||||
|
||||
def pack_optimizer_states_to_tensor(self,
|
||||
param_id: int,
|
||||
state_names: list,
|
||||
device: torch.device = torch.device('cuda'),
|
||||
dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
'''
|
||||
def pack_optimizer_states_to_tensor(
|
||||
self,
|
||||
param_id: int,
|
||||
state_names: list,
|
||||
device: torch.device = torch.device("cuda"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
With param id given, pack its optimizer states into a compact tensor and return.
|
||||
'''
|
||||
"""
|
||||
if param_id not in self.id_to_fake_params:
|
||||
return None
|
||||
|
||||
@@ -493,7 +506,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
shard_size = param_range[1] - param_range[0]
|
||||
compacted_size = 0
|
||||
for name in state_names:
|
||||
if name == 'step':
|
||||
if name == "step":
|
||||
compacted_size += 1
|
||||
else:
|
||||
compacted_size += shard_size
|
||||
@@ -502,7 +515,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
next_state_offset = 0
|
||||
for state_name, state_tensor in states.items():
|
||||
# State 'step' needs special operation.
|
||||
if state_name == 'step':
|
||||
if state_name == "step":
|
||||
if isinstance(state_tensor, torch.Tensor):
|
||||
compacted_states[next_state_offset] = state_tensor[0].item()
|
||||
else:
|
||||
@@ -511,47 +524,53 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
next_state_offset += 1
|
||||
else:
|
||||
assert state_tensor.numel() == shard_size
|
||||
compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor)
|
||||
compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor)
|
||||
next_state_offset += shard_size
|
||||
|
||||
return compacted_states
|
||||
|
||||
def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
|
||||
shard_start: int, shard_size: int):
|
||||
'''
|
||||
def load_from_compacted_states(
|
||||
self,
|
||||
compacted_states: torch.Tensor,
|
||||
collected_states: dict,
|
||||
state_names: list,
|
||||
shard_start: int,
|
||||
shard_size: int,
|
||||
):
|
||||
"""
|
||||
Given a tensor carrying compacted optimizer states,
|
||||
update these states to collected_states.
|
||||
'''
|
||||
"""
|
||||
shard_end = shard_start + shard_size
|
||||
next_state_offset = 0
|
||||
|
||||
for state_name in state_names:
|
||||
if state_name == 'step':
|
||||
collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
|
||||
dtype=torch.float32,
|
||||
requires_grad=False).cpu()
|
||||
if state_name == "step":
|
||||
collected_states["step"].data = torch.tensor(
|
||||
compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False
|
||||
).cpu()
|
||||
next_state_offset += 1
|
||||
else:
|
||||
target_segment = collected_states[state_name][shard_start:shard_end]
|
||||
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
|
||||
target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size])
|
||||
next_state_offset += shard_size
|
||||
|
||||
def get_param_groups_for_saving(self) -> list:
|
||||
'''
|
||||
"""
|
||||
Return the param_groups in Pytorch format when saving to checkpoint.
|
||||
'''
|
||||
"""
|
||||
|
||||
param_groups = copy.deepcopy(self.param_groups_backup)
|
||||
|
||||
# To be compatible with pytorch checkpointing,
|
||||
# store extra hyperparameters used by pytorch Adam optimizer.
|
||||
torch_special_hyperparameters = {
|
||||
'amsgrad': False,
|
||||
'maximize': False,
|
||||
'foreach': None,
|
||||
'capturable': False,
|
||||
'differentiable': False,
|
||||
'fused': False
|
||||
"amsgrad": False,
|
||||
"maximize": False,
|
||||
"foreach": None,
|
||||
"capturable": False,
|
||||
"differentiable": False,
|
||||
"fused": False,
|
||||
}
|
||||
|
||||
for group in param_groups:
|
||||
@@ -580,13 +599,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
so it should be called only when memory resources are abundant.
|
||||
"""
|
||||
state_dict = {}
|
||||
state_dict['param_groups'] = self.get_param_groups_for_saving()
|
||||
state_dict["param_groups"] = self.get_param_groups_for_saving()
|
||||
|
||||
# Collect optimizer states.
|
||||
state_dict['state'] = dict()
|
||||
state_dict["state"] = dict()
|
||||
for param_id in self.id_to_real_params.keys():
|
||||
dist.barrier()
|
||||
state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
|
||||
state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
|
||||
return state_dict
|
||||
|
||||
def load_param_groups(self, saved_param_groups: list):
|
||||
@@ -601,13 +620,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
for group in saved_param_groups:
|
||||
fake_params_list = list()
|
||||
updated_group = {k: v for k, v in group.items() if k != 'params'}
|
||||
for param_id in group['params']:
|
||||
updated_group = {k: v for k, v in group.items() if k != "params"}
|
||||
for param_id in group["params"]:
|
||||
if param_id not in self.id_to_fake_params:
|
||||
continue
|
||||
fake_param = self.id_to_fake_params[param_id]
|
||||
fake_params_list.append(fake_param)
|
||||
updated_group['params'] = fake_params_list
|
||||
updated_group["params"] = fake_params_list
|
||||
self.optim.param_groups.append(updated_group)
|
||||
|
||||
def load_single_param_states(self, param_id: int, saved_states: dict):
|
||||
@@ -621,15 +640,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
"""
|
||||
assert isinstance(value, torch.Tensor)
|
||||
ret_val = value
|
||||
if (key == "step"):
|
||||
if key == "step":
|
||||
assert value.numel() == 1
|
||||
ret_val = int(value.item())
|
||||
else:
|
||||
state_start, state_end = state_range
|
||||
ret_val = torch.zeros(state_end - state_start,
|
||||
dtype=torch.float32,
|
||||
device=param.device,
|
||||
requires_grad=False)
|
||||
ret_val = torch.zeros(
|
||||
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
|
||||
)
|
||||
ret_val.copy_(value.flatten()[state_start:state_end])
|
||||
return ret_val
|
||||
|
||||
@@ -642,7 +660,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
updated_states = dict()
|
||||
for k, v in saved_states.items():
|
||||
updated_states[k] = cast(fake_param, state_range, v, k)
|
||||
del v # clean loaded states
|
||||
del v # clean loaded states
|
||||
self.optim.state[fake_param].update(updated_states)
|
||||
|
||||
def load_param_states(self, param_states: dict):
|
||||
@@ -658,8 +676,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
def optimizer_loading_epilogue(self):
|
||||
# Epilogue when loading state_dict to pytorch optimizer.
|
||||
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
self.optim.defaults.setdefault('differentiable', False)
|
||||
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
self.optim.defaults.setdefault("differentiable", False)
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
"""Loads optimizer state from complete optimizer state_dict.
|
||||
@@ -669,16 +687,15 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
state_dict (dict): optimizer state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
assert 'param_groups' in state_dict
|
||||
assert 'state' in state_dict
|
||||
self.load_param_groups(state_dict['param_groups'])
|
||||
self.load_param_states(state_dict['state'])
|
||||
assert "param_groups" in state_dict
|
||||
assert "state" in state_dict
|
||||
self.load_param_groups(state_dict["param_groups"])
|
||||
self.load_param_states(state_dict["state"])
|
||||
self.optimizer_loading_epilogue()
|
||||
|
||||
def state_shard(self,
|
||||
prefix: str = '',
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def state_shard(
|
||||
self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing shards of optimizer states one by one.
|
||||
The max size of each dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
@@ -694,7 +711,6 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
sharder = StateDictSharder(max_shard_size)
|
||||
for param_id in self.id_to_real_params.keys():
|
||||
|
||||
dist.barrier()
|
||||
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
|
||||
|
||||
@@ -705,19 +721,20 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
yield sharder.current_block, sharder.current_block_size
|
||||
|
||||
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||
raise NotImplementedError('Gemini does not support clip_grad_by_value')
|
||||
raise NotImplementedError("Gemini does not support clip_grad_by_value")
|
||||
|
||||
def clip_grad_by_norm(self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||
def clip_grad_by_norm(
|
||||
self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
|
||||
|
||||
|
||||
class GeminiAdamOptimizer(GeminiOptimizer):
|
||||
|
||||
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
|
||||
optimizer = HybridAdam(model.parameters(), **defaults)
|
||||
super().__init__(optimizer, model, **defaults)
|
||||
|
||||
@@ -1,10 +1,14 @@
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip

__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats',
'OrderedParamGenerator'
"AsyncMemoryMonitor",
"SyncCudaMemoryMonitor",
"MemStatsCollector",
"ChunkMemStatsCollector",
"MemStats",
"OrderedParamGenerator",
]
@@ -8,7 +8,6 @@ from .memstats_collector import MemStatsCollector
|
||||
|
||||
|
||||
class ChunkMemStatsCollector(MemStatsCollector):
|
||||
|
||||
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
"""
|
||||
|
||||
@@ -27,10 +26,11 @@ class ChunkMemStatsCollector(MemStatsCollector):
|
||||
record model data volume on cuda and cpu.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||
cuda_mem = self._chunk_manager.total_mem["cuda"]
|
||||
self._memstats.record_max_cuda_model_data(cuda_mem)
|
||||
|
||||
@property
|
||||
def cuda_margin_mem(self) -> float:
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
|
||||
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
|
||||
|
||||
@@ -111,6 +111,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
|
||||
|
||||
def _measure_usage(self):
|
||||
from colossalai.legacy.utils import colo_device_memory_used
|
||||
|
||||
max_usage = 0
|
||||
while self.keep_measuring:
|
||||
max_usage = max(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -6,7 +6,6 @@ from .param_runtime_order import OrderedParamGenerator
|
||||
|
||||
|
||||
class MemStats(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Store the non model data statistics used for Gemini and GeminiOptimizer.
|
||||
@@ -92,17 +91,17 @@ class MemStats(object):
|
||||
return self._param_runtime_order
|
||||
|
||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
if device_type == "cuda":
|
||||
return self._non_model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
elif device_type == "cpu":
|
||||
return self._non_model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def max_non_model_data(self, device_type: str) -> float:
|
||||
if device_type == 'cuda':
|
||||
if device_type == "cuda":
|
||||
return max(self._non_model_data_cuda_list)
|
||||
elif device_type == 'cpu':
|
||||
elif device_type == "cpu":
|
||||
return max(self._non_model_data_cpu_list)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
@@ -40,11 +40,12 @@ class MemStatsCollector:
|
||||
Returns:
|
||||
int: max non model data memory usage of current sampling period
|
||||
"""
|
||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
|
||||
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
|
||||
assert not self._start_flag, "Cannot get mem stats info during collection phase."
|
||||
assert self._step_total > 0, "Cannot get mem stats info before collection phase."
|
||||
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, (
|
||||
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "
|
||||
f"step total {self._step_total}"
|
||||
)
|
||||
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
|
||||
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||
return next_non_model_data
|
||||
@@ -60,9 +61,9 @@ class MemStatsCollector:
|
||||
def finish_collection(self):
|
||||
self.sample_overall_data()
|
||||
# self._step_total = len(self._sampling_time)
|
||||
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
||||
self._step_total = len(self._memstats.non_model_data_list("cuda"))
|
||||
self._start_flag = False
|
||||
print(f'finish_collection {self._step_total}')
|
||||
print(f"finish_collection {self._step_total}")
|
||||
|
||||
# deprecated
|
||||
def record_model_data_volume(self) -> None:
|
||||
@@ -73,7 +74,7 @@ class MemStatsCollector:
|
||||
from colossalai.legacy.zero.gemini import StatefulTensor
|
||||
|
||||
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"]
|
||||
self._memstats.record_max_cuda_model_data(cuda_mem)
|
||||
|
||||
def sample_overall_data(self) -> None:
|
||||
|
||||
@@ -4,7 +4,6 @@ import torch
|
||||
|
||||
|
||||
class ParamGenerator(ABC):
|
||||
|
||||
def append(self, param: torch.nn.Parameter):
|
||||
pass
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from colossalai.utils import _cast_float
|
||||
|
||||
from .memory_stats import MemStats
|
||||
|
||||
__all__ = ['RuntimeMemTracer']
|
||||
__all__ = ["RuntimeMemTracer"]
|
||||
|
||||
|
||||
class RuntimeMemTracer():
|
||||
class RuntimeMemTracer:
|
||||
"""RuntimeMemTracer for the module training using ColoParameter.
|
||||
|
||||
Trace non-model memory usage during fwd+bwd process.
|
||||
|
||||
@@ -15,9 +15,9 @@ from .chunk_memstats_collector import ChunkMemStatsCollector
|
||||
|
||||
|
||||
class ModuleInfos:
|
||||
|
||||
def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
|
||||
parent_module: torch.nn.Module):
|
||||
def __init__(
|
||||
self, module: torch.nn.Module, module_name: str, module_full_name: str, parent_module: torch.nn.Module
|
||||
):
|
||||
self.module = module
|
||||
self.module_name = module_name
|
||||
self.module_full_name = module_full_name
|
||||
@@ -35,14 +35,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector):
|
||||
self.module_info_list = []
|
||||
|
||||
def init_mem_stats(self, *inputs):
|
||||
|
||||
self.register_opnodes_recursively(self.module)
|
||||
self.refactor_module()
|
||||
|
||||
self.module = self.module.cpu()
|
||||
self.module.train()
|
||||
|
||||
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
|
||||
data = [MetaTensor(torch.rand(inp.shape, device="meta"), fake_device="cpu") for inp in inputs]
|
||||
gm = symbolic_trace(self.module)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(*data)
|
||||
@@ -87,12 +86,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector):
|
||||
for modInfo in self.module_info_list:
|
||||
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
|
||||
|
||||
def register_opnodes_recursively(self,
|
||||
module: torch.nn.Module,
|
||||
name: str = "",
|
||||
full_name: str = "",
|
||||
parent_module: Optional[torch.nn.Module] = None):
|
||||
|
||||
def register_opnodes_recursively(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
name: str = "",
|
||||
full_name: str = "",
|
||||
parent_module: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
|
||||
for child_name, child in module.named_children():
|
||||
|
||||
@@ -14,7 +14,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
|
||||
"""
|
||||
if optim is None:
|
||||
return 0, 0
|
||||
assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
|
||||
assert hasattr(optim, "get_memory_usage"), f"{type(optim)} has no attr get_memory_usage()"
|
||||
return optim.get_memory_usage()
|
||||
|
||||
|
||||
@@ -35,16 +35,16 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return 0, 0
|
||||
assert isinstance(t, torch.Tensor)
|
||||
_cpu_mem_usage, _cuda_mem_usage = 0, 0
|
||||
if t.device.type == 'cpu':
|
||||
if t.device.type == "cpu":
|
||||
_cpu_mem_usage += t.numel() * t.element_size()
|
||||
elif t.device.type == 'cuda':
|
||||
elif t.device.type == "cuda":
|
||||
_cuda_mem_usage += t.numel() * t.element_size()
|
||||
return _cuda_mem_usage, _cpu_mem_usage
|
||||
|
||||
cuda_mem_usage = 0
|
||||
cpu_mem_usage = 0
|
||||
for param in model.parameters():
|
||||
if hasattr(param, 'colo_attr'):
|
||||
if hasattr(param, "colo_attr"):
|
||||
t_cuda, t_cpu = param.colo_attr.get_memory_usage()
|
||||
cuda_mem_usage += t_cuda
|
||||
cpu_mem_usage += t_cpu
|
||||
|
||||
@@ -17,10 +17,9 @@ from .memory_tracer import ChunkMemStatsCollector
|
||||
class PlacementPolicy(ABC):
|
||||
need_mem_stats: bool = False
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
|
||||
) -> None:
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
|
||||
@@ -29,23 +28,25 @@ class PlacementPolicy(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
def setup_grads_device(
|
||||
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
|
||||
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
|
||||
offload_param_frac = 0.0
|
||||
self.shard_param_frac = shard_param_frac
|
||||
self.offload_optim_frac = offload_optim_frac
|
||||
@@ -66,13 +67,14 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
for chunk in can_evict_chunks:
|
||||
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
|
||||
# real saved mem is shard_mem, for simplicity we use chunk_mem
|
||||
can_offload_chunk_mem -= chunk.chunk_mem
|
||||
return 0, 0.0
|
||||
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
def setup_grads_device(
|
||||
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
|
||||
) -> None:
|
||||
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
|
||||
|
||||
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
|
||||
@@ -85,7 +87,7 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
|
||||
device = get_current_device()
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
device = torch.device("cpu")
|
||||
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
|
||||
offloaded_optim_chunk_mem += chunk.chunk_mem
|
||||
for p in params:
|
||||
@@ -97,12 +99,14 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
class AutoPlacementPolicy(PlacementPolicy):
|
||||
need_mem_stats: bool = True
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
@@ -110,13 +114,15 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
|
||||
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
def evict_tensors(
|
||||
self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[int, float]:
|
||||
"""
|
||||
Evict tensors from CUDA device.
|
||||
|
||||
@@ -135,13 +141,13 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
|
||||
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
@@ -160,11 +166,13 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
break
|
||||
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
|
||||
freed_cuda_model_data += chunk.chunk_mem
|
||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
|
||||
raise RuntimeError(
|
||||
f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||
)
|
||||
return freed_cuda_model_data, time() - start
|
||||
|
||||
@staticmethod
|
||||
@@ -178,8 +186,9 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
def setup_grads_device(
|
||||
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
|
||||
) -> None:
|
||||
for p in params:
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
# init offload optim settings
|
||||
@@ -187,13 +196,13 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
if chunk.keep_gathered:
|
||||
grads_device_map[p] = get_current_device()
|
||||
else:
|
||||
grads_device_map[p] = torch.device('cpu')
|
||||
grads_device_map[p] = torch.device("cpu")
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'auto': AutoPlacementPolicy,
|
||||
'static': StaticPlacementPolicy,
|
||||
"auto": AutoPlacementPolicy,
|
||||
"static": StaticPlacementPolicy,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -27,16 +27,15 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
return total_temp
|
||||
|
||||
|
||||
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
|
||||
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
|
||||
"""
|
||||
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""):
|
||||
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules."""
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if module not in memo:
|
||||
for name, submodule in module._modules.items():
|
||||
if submodule is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
submodule_prefix = prefix + ("." if prefix else "") + name
|
||||
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
|
||||
yield m
|
||||
|
||||
@@ -60,10 +59,9 @@ def _get_shallow_copy_model(model: nn.Module):
|
||||
return old_to_new[model]
|
||||
|
||||
|
||||
def get_static_torch_model(zero_ddp_model,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
only_rank_0=True) -> torch.nn.Module:
|
||||
def get_static_torch_model(
|
||||
zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
|
||||
) -> torch.nn.Module:
|
||||
"""Get a static torch.nn.Module model from the given GeminiDDP module.
|
||||
You should notice that the original GeminiDDP model is not modified.
|
||||
Thus, you can use the original model in further training.
|
||||
@@ -79,6 +77,7 @@ def get_static_torch_model(zero_ddp_model,
|
||||
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
|
||||
"""
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
|
||||
assert isinstance(zero_ddp_model, GeminiDDP)
|
||||
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
|
||||
@@ -86,15 +85,17 @@ def get_static_torch_model(zero_ddp_model,
|
||||
torch_model = _get_shallow_copy_model(colo_model)
|
||||
|
||||
if not only_rank_0 or dist.get_rank() == 0:
|
||||
for (name, colo_module), (_, torch_module) in \
|
||||
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
|
||||
for (name, colo_module), (_, torch_module) in zip(
|
||||
_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)
|
||||
):
|
||||
# clean the parameter list of the new torch module
|
||||
torch_module._parameters = OrderedDict()
|
||||
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
|
||||
# get the full name of the parameter
|
||||
full_param_name = name + ('.' if name else '') + sufix_param_name
|
||||
assert full_param_name in state_dict, \
|
||||
f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
|
||||
full_param_name = name + ("." if name else "") + sufix_param_name
|
||||
assert (
|
||||
full_param_name in state_dict
|
||||
), f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
|
||||
state_param = state_dict[full_param_name]
|
||||
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
from .low_level_optim import LowLevelZeroOptimizer

__all__ = ['LowLevelZeroOptimizer']
__all__ = ["LowLevelZeroOptimizer"]
@@ -44,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
|
||||
for partition_id in range(partitions_count):
|
||||
partition_tensors = partitions[partition_id]
|
||||
for item in partition_tensors:
|
||||
tensor_index_mapping[item['index']] = len(new_tensor_list)
|
||||
new_tensor_list.append(item['tensor'])
|
||||
tensor_index_mapping[item["index"]] = len(new_tensor_list)
|
||||
new_tensor_list.append(item["tensor"])
|
||||
|
||||
return new_tensor_list, tensor_index_mapping
|
||||
|
||||
@@ -107,11 +107,13 @@ def split_by_dtype(tensor_list):
|
||||
return buckets
|
||||
|
||||
|
||||
def reduce_tensor_dp_group(tensor: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dst_local_rank: Optional[int] = None,
|
||||
dst_global_rank: Optional[int] = None,
|
||||
group: Optional[dist.ProcessGroup] = None):
|
||||
def reduce_tensor_dp_group(
|
||||
tensor: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dst_local_rank: Optional[int] = None,
|
||||
dst_global_rank: Optional[int] = None,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
):
|
||||
"""
|
||||
Reduce the tensor in the data parallel process group
|
||||
|
||||
@@ -173,7 +175,7 @@ def has_inf_or_nan(tensor):
|
||||
raise
|
||||
return True
|
||||
else:
|
||||
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
|
||||
if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -184,8 +186,7 @@ def release_param_grad(tensor_list):
|
||||
|
||||
|
||||
def calculate_global_norm_from_list(norm_list):
|
||||
""" Compute total from a list of norms
|
||||
"""
|
||||
"""Compute total from a list of norms"""
|
||||
total_norm = 0.0
|
||||
for norm in norm_list:
|
||||
total_norm += norm**2.0
|
||||
@@ -221,7 +222,7 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
|
||||
total_norm = 0.0
|
||||
for g in gradients:
|
||||
param_norm = g.data.double().norm(2)
|
||||
total_norm += param_norm.item()**2
|
||||
total_norm += param_norm.item() ** 2
|
||||
|
||||
# Sum across all model parallel GPUs.
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
@@ -230,9 +231,9 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
|
||||
if tp_group is not None:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
|
||||
|
||||
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
|
||||
total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
|
||||
|
||||
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||||
if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
@@ -3,4 +3,4 @@ from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket

__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
@@ -3,7 +3,6 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class BaseStore:
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
self._world_size = dist.get_world_size(group=torch_pg)
|
||||
self._local_rank = dist.get_rank(group=torch_pg)
|
||||
|
||||
@@ -9,7 +9,6 @@ from .base_store import BaseStore
|
||||
|
||||
|
||||
class BucketStore(BaseStore):
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
|
||||
@@ -38,8 +37,7 @@ class BucketStore(BaseStore):
|
||||
return self._num_elements_in_bucket
|
||||
|
||||
def reset_num_elements_in_bucket(self):
|
||||
"""Set the number of elements in bucket to zero.
|
||||
"""
|
||||
"""Set the number of elements in bucket to zero."""
|
||||
|
||||
self._num_elements_in_bucket = 0
|
||||
|
||||
@@ -54,7 +52,7 @@ class BucketStore(BaseStore):
|
||||
|
||||
self._param_list.append(param)
|
||||
self._padding_size.append(padding_size)
|
||||
self._num_elements_in_bucket += (param.numel() + padding_size)
|
||||
self._num_elements_in_bucket += param.numel() + padding_size
|
||||
self.current_group_id = group_id
|
||||
|
||||
# number of tensors in current bucket
|
||||
@@ -119,8 +117,7 @@ class BucketStore(BaseStore):
|
||||
return self.grad_to_param_mapping[id(grad)]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the bucket storage after reduction, only release the tensors have been reduced
|
||||
"""
|
||||
"""Reset the bucket storage after reduction, only release the tensors have been reduced"""
|
||||
cur_offset = self.offset_list.pop(0)
|
||||
self._param_list = self._param_list[cur_offset:]
|
||||
self._padding_size = self._padding_size[cur_offset:]
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
|
||||
def __init__(self, *args, partition_grad: bool = False):
|
||||
super().__init__(*args)
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,6 @@ from .base_store import BaseStore
|
||||
|
||||
|
||||
class ParameterStore(BaseStore):
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
|
||||
class TensorBucket:
|
||||
|
||||
def __init__(self, size):
|
||||
self._max_size = size
|
||||
self._current_size = 0
|
||||
@@ -26,8 +25,7 @@ class TensorBucket:
|
||||
tensor_size = tensor.numel()
|
||||
|
||||
if not allow_oversize and self.will_exceed_max_size(tensor_size):
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" \
|
||||
+ f"by tensor (size {tensor_size})"
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._bucket.append(tensor)
|
||||
|
||||
@@ -17,6 +17,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
||||
)
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
# from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
@@ -32,19 +33,21 @@ from .bookkeeping import BucketStore, GradientStore, ParameterStore
|
||||
|
||||
|
||||
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
||||
def __init__(self,
|
||||
num_working_param_groups: int,
|
||||
grad_store: GradientStore,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32) -> None:
|
||||
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
|
||||
max_scale)
|
||||
def __init__(
|
||||
self,
|
||||
num_working_param_groups: int,
|
||||
grad_store: GradientStore,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
|
||||
)
|
||||
self.num_working_param_groups = num_working_param_groups
|
||||
self.grad_store = grad_store
|
||||
|
||||
@@ -57,32 +60,31 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):


class LowLevelZeroOptimizer(OptimizerWrapper):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
"""Optimizer used for ZeRO-1 and ZeRO-2."""

def __init__(
self,
optimizer: Optimizer,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):

self,
optimizer: Optimizer,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose

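Given the signature above, wiring this optimizer into a training script looks roughly like the following. This is a sketch only: it assumes torch.distributed (and a CUDA device) is already initialized, and the toy model, batch shape, and hyperparameter values are made up.

import torch
from colossalai.zero.low_level import LowLevelZeroOptimizer

model = torch.nn.Linear(1024, 1024).half().cuda()            # toy fp16 model
base_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

# ZeRO-1 by default; partition_grad=True switches on ZeRO-2 gradient partitioning.
optim = LowLevelZeroOptimizer(
    base_optim,
    initial_scale=2**16,
    clip_grad_norm=1.0,
    reduce_bucket_size=1024 * 1024,
    overlap_communication=True,
    partition_grad=False,
)

out = model(torch.randn(8, 1024, dtype=torch.float16, device="cuda"))
loss = out.float().mean()
optim.backward(loss)   # scales the loss, then reduces/partitions gradients
optim.step()           # updates this rank's master shard and syncs working params
optim.zero_grad()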
@@ -115,7 +117,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

if forced_dtype:
for group in self.optim.param_groups:
group_params = group['params']
group_params = group["params"]
for param in group_params:
param.data = param.data.to(forced_dtype)
self._dtype = forced_dtype
@@ -134,7 +136,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = list()
for param in param_group['params']:
for param in param_group["params"]:
if param.requires_grad:
group_params.append(param)

@@ -148,7 +150,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = master_param_current_rank
param_group["params"] = master_param_current_rank

# initialize communication stream for
# communication-computation overlapping
@@ -164,15 +166,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()

@@ -185,17 +189,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups)

def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
assert torch.cuda.is_available(), "CUDA is required"
for param_group in self.optim.param_groups:
group_params = param_group['params']
group_params = param_group["params"]
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()
device = "cpu" if self._cpu_offload else get_current_device()

for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
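As a concrete example of the padding rule above: with world_size = 4, a parameter of 10 elements gets padding_size = (4 - 10 % 4) % 4 = 2, so the padded flat tensor has 12 elements and each rank owns a contiguous shard of 3. A small sketch of that split (split_for_rank is a hypothetical helper mirroring the formula, not a method of the optimizer):

import torch


def split_for_rank(param: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Pad a flattened parameter so it divides evenly, then return this rank's shard."""
    flat = param.detach().flatten()
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    return flat.split(flat.numel() // world_size)[rank]


p = torch.arange(10, dtype=torch.float32)
print(split_for_rank(p, world_size=4, rank=3))  # tensor([9., 0., 0.]) -- the last shard carries the padding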
@@ -275,8 +280,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
param_id)) < self._world_size:
if (
len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
< self._world_size
):
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
@@ -307,8 +314,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# if full, will reduce the grads already in the bucket
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
group_id != self._bucket_store.current_group_id:
if (
self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
or group_id != self._bucket_store.current_group_id
):
self._run_reduction()

padding_size = self._param_store.get_param_padding_size(param)
@@ -319,8 +328,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
################################

def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and no_sync are not compatible"
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible"

if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -339,8 +349,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.zero_grad()

def backward_by_grad(self, tensor, grad):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
@@ -380,14 +391,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
####################

def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
assert closure is None, "closure is not supported by step()"
if not self.require_grad_sync:
return

if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self._logger.info(f"Found overflow. Skip step")
self.zero_grad()
return

@@ -428,7 +439,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._grad_store.reset_grads_by_group_id(group_id)

# update the params in the optimizer
self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@@ -445,16 +456,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update working partition updated by the current rank
dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]['params']
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

#############################
# Mixed Precision Utilities #
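The update loop above all-gathers each rank's updated master shard and writes it back into the working (fp16) parameter, dropping the padding added at partition time. A self-contained sketch of that gather-and-copy step (gather_shard_into_working is a hypothetical helper; it assumes an initialized process group):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten


def gather_shard_into_working(shard: torch.Tensor, working_param: torch.Tensor, group=None) -> None:
    """All-gather the per-rank shards, drop the padding, and copy into the working parameter."""
    world_size = dist.get_world_size(group)
    buffers = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard, group=group)
    full = flatten(buffers)[: working_param.numel()]
    working_param.data.copy_(full.reshape_as(working_param))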
@@ -466,14 +477,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.mixed_precision_mixin is not None:
div_scale = self.mixed_precision_mixin.get_grad_div_scale()

if self._clip_grad_norm > 0.:
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
div_scale = clip * div_scale

for grad in grad_groups_flat:
grad.data.mul_(1. / div_scale)
grad.data.mul_(1.0 / div_scale)

############################
# Gradient Synchronization #
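Numerically, the unscale-and-clip step above folds loss-scale removal and gradient clipping into one multiplier. For instance, with div_scale = 65536, total_norm = 131072 (a true gradient norm of 2.0) and clip_grad_norm = 1.0, clip comes out at about 2.0, the effective divisor grows to about 131072, and every gradient is multiplied by roughly 1/131072, which both unscales it and brings its norm down to the limit. The same arithmetic as a standalone sketch (not the class method itself):

def unscale_and_clip_factor(total_norm: float, loss_scale: float, clip_grad_norm: float) -> float:
    """Return the factor gradients are multiplied by (illustrative, mirrors the logic above)."""
    div_scale = loss_scale
    if clip_grad_norm > 0.0:
        clip = (total_norm / div_scale + 1e-6) / clip_grad_norm
        if clip > 1:
            div_scale = clip * div_scale
    return 1.0 / div_scale


print(unscale_and_clip_factor(total_norm=131072.0, loss_scale=65536.0, clip_grad_norm=1.0))
# ~7.6e-06, i.e. about 1/131072: unscales by 65536 and clips the remaining norm from 2.0 to ~1.0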
@@ -518,18 +529,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
packed = {k: v for k, v in group.items() if k != "params"}
param_mappings.update(
{id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
{id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
)
packed["params"] = [param_mappings[id(p)] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in self.optim.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

return {'state': packed_state, 'param_groups': param_groups}
return {"state": packed_state, "param_groups": param_groups}

def state_dict(self) -> Dict:
"""Return a state_dict same as DDP
@@ -541,14 +553,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
zero_state[param][k] = param_state

states_dict = self._pack_state(zero_state)
@@ -562,16 +575,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
state_dict (dict): A pytorch form state_dict
"""
zero_state_dict = copy.deepcopy(state_dict)
for param_idx, state in zero_state_dict['state'].items():
for param_idx, state in zero_state_dict["state"].items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()

self.optim.load_state_dict(zero_state_dict)

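Saving and loading are symmetric: state_dict() all-gathers the sharded optimizer states into a DDP-style dictionary, while load_state_dict() re-pads each state tensor and keeps only the local rank's slice. A minimal round-trip sketch, assuming optim is a LowLevelZeroOptimizer built as in the earlier example and "zero_optim.pt" is a placeholder path:

import torch

full_state = optim.state_dict()                     # gathered, rank-agnostic optimizer states
torch.save(full_state, "zero_optim.pt")

loaded = torch.load("zero_optim.pt", map_location="cpu")
optim.load_state_dict(loaded)                       # each rank re-shards and keeps its own slice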
@@ -588,7 +601,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict()
ret_block_size = 0

local_states = self.optim.state_dict()['state']
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
@@ -601,11 +614,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = self._param_store.master_to_working_param[id(master_param)]

for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
current_block_size += state_tensor.numel()
current_block[k] = state_tensor


@@ -7,10 +7,9 @@ import torch.nn as nn
from .gemini import GeminiDDP


def zero_model_wrapper(model: nn.Module,
zero_stage: int = 1,
gemini_config: Optional[Dict] = None,
verbose: bool = False):
def zero_model_wrapper(
model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False
):
"""This wrapper function is used to wrap your training model for ZeRO DDP.

Example:
@@ -50,19 +49,21 @@ def zero_model_wrapper(model: nn.Module,
return wrapped_model


def zero_optim_wrapper(model: nn.Module,
optimizer: torch.optim.Optimizer,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
optim_config: Optional[Dict] = None,
verbose: bool = False):
def zero_optim_wrapper(
model: nn.Module,
optimizer: torch.optim.Optimizer,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
optim_config: Optional[Dict] = None,
verbose: bool = False,
):
"""This wrapper function is used to wrap your training optimizer for ZeRO DDP.

Args:
@@ -95,20 +96,22 @@ def zero_optim_wrapper(model: nn.Module,
else:
config_dict = copy(optim_config)

config_dict['initial_scale'] = initial_scale
config_dict['growth_factor'] = growth_factor
config_dict['backoff_factor'] = backoff_factor
config_dict['growth_interval'] = growth_interval
config_dict['hysteresis'] = hysteresis
config_dict['min_scale'] = min_scale
config_dict['max_scale'] = max_scale
config_dict["initial_scale"] = initial_scale
config_dict["growth_factor"] = growth_factor
config_dict["backoff_factor"] = backoff_factor
config_dict["growth_interval"] = growth_interval
config_dict["hysteresis"] = hysteresis
config_dict["min_scale"] = min_scale
config_dict["max_scale"] = max_scale

if zero_stage in [1, 2]:
from colossalai.zero.low_level import LowLevelZeroOptimizer
config_dict['partition_grad'] = zero_stage == 2
config_dict['clip_grad_norm'] = max_norm

config_dict["partition_grad"] = zero_stage == 2
config_dict["clip_grad_norm"] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
else:
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
config_dict['clipping_norm'] = max_norm

config_dict["clipping_norm"] = max_norm
return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)

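Putting the two wrappers together, a typical setup picks a ZeRO stage once and lets zero_optim_wrapper route to LowLevelZeroOptimizer (stages 1 and 2, with partition_grad set for stage 2) or GeminiOptimizer (stage 3). An illustrative sketch; the model, learning rate, and other values are placeholders:

import torch
import torch.nn as nn
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

zero_stage = 2   # 1 or 2 -> LowLevelZeroOptimizer, 3 -> GeminiOptimizer
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 10)).half().cuda()
model = zero_model_wrapper(model, zero_stage=zero_stage)

base_optim = torch.optim.AdamW(model.parameters(), lr=2e-4)
optim = zero_optim_wrapper(
    model,
    base_optim,
    initial_scale=2**16,
    max_norm=1.0,   # forwarded as clip_grad_norm (stages 1-2) or clipping_norm (stage 3)
)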