[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
This commit is contained in:
Hongxin Liu
2023-10-12 10:39:08 +08:00
committed by GitHub
parent c1fab951e7
commit df63564184
15 changed files with 222 additions and 114 deletions

View File

@@ -160,6 +160,8 @@ class Chunk:
self.l2_norm_flag = False
self.l2_norm = None
self.grad_chunk = None
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
@@ -414,7 +416,9 @@ class Chunk:
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
def copy_tensor_to_chunk_slice(
self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True
) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
@@ -427,7 +431,8 @@ class Chunk:
tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
if update_ptr:
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload."""
@@ -577,3 +582,46 @@ class Chunk:
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
return "".join(output)
def init_grad_chunk(self) -> "Chunk":
"""Init grad chunk. This should be called in grad handler.
Returns:
Chunk: Grad chunk
"""
if self.grad_chunk is None:
# grad chunk is not initialized
grad_chunk = Chunk(
chunk_size=self.chunk_size,
process_group=self.torch_pg,
dtype=self.dtype,
keep_gathered=self.keep_gathered,
pin_memory=self.pin_memory,
)
grad_chunk.num_tensors = self.num_tensors
grad_chunk.utilized_size = self.utilized_size
grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors
for tensor, state in self.tensors_info.items():
grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end)
grad_chunk.valid_end = self.valid_end
if grad_chunk.chunk_temp.device.type == "cpu":
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
else:
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
grad_chunk.chunk_temp = None
if grad_chunk.pin_memory:
grad_chunk.cpu_shard = torch.empty(
grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory
)
self.grad_chunk = grad_chunk
else:
# grad chunk is initialized, just reallocate cuda global chunk
self.grad_chunk.cuda_shard = None
self.grad_chunk.is_gathered = True
alloc_storage(self.grad_chunk.cuda_global_chunk)
return self.grad_chunk

View File

@@ -245,3 +245,13 @@ class ChunkManager:
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem
def init_grad_chunk(self, chunk: Chunk) -> Chunk:
if chunk.grad_chunk is not None:
self.__sub_memory_usage(chunk.grad_chunk.memory_usage)
grad_chunk = chunk.init_grad_chunk()
self.__add_memory_usage(grad_chunk.memory_usage)
if grad_chunk not in self.accessed_chunks:
self.accessed_chunks.add(grad_chunk)
self.accessed_mem += grad_chunk.chunk_mem
return grad_chunk

View File

@@ -74,6 +74,7 @@ class GeminiDDP(ModelWrapper):
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
master_weights: bool = True,
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -115,6 +116,9 @@ class GeminiDDP(ModelWrapper):
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@@ -321,20 +325,37 @@ class GeminiDDP(ModelWrapper):
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter."
)
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
# fp16 param chunk: hold after bwd -> ready for reduce -> hold
chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if chunk.is_gathered:
chunk.cuda_global_chunk.div_(chunk.pg_size)
if not self.reuse_fp16_chunk:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
else:
self.chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
grad_chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
self.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not self.master_weights:
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
@@ -344,9 +365,7 @@ class GeminiDDP(ModelWrapper):
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(
self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
):
def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
@@ -365,7 +384,7 @@ class GeminiDDP(ModelWrapper):
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
@@ -373,7 +392,7 @@ class GeminiDDP(ModelWrapper):
destination = hook_result
return destination
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
"""
get gathered chunk content.
@@ -386,9 +405,8 @@ class GeminiDDP(ModelWrapper):
"""
# save parameters
chunk_to_save_data = dict()
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
if torch.is_floating_point(temp_chunk):
temp_chunk = temp_chunk.to(dtype)
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
@@ -401,9 +419,7 @@ class GeminiDDP(ModelWrapper):
del temp_chunk
return chunk_to_save_data
def _get_param_to_save_data(
self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
) -> Dict:
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
@@ -418,10 +434,10 @@ class GeminiDDP(ModelWrapper):
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -438,14 +454,18 @@ class GeminiDDP(ModelWrapper):
# get copies of fp32 parameters in CPU
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
if self.reuse_fp16_chunk:
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
else:
p_mapping = param_to_save_data
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
@@ -593,7 +613,7 @@ class GeminiDDP(ModelWrapper):
elif strict:
missing_keys.append(state_key)
def load_fp32_parameter(chunk_slice, data):
def load_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
for name, param in self.named_parameters():
@@ -607,14 +627,15 @@ class GeminiDDP(ModelWrapper):
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
chunk_list = self.chunk_manager.get_chunks(params_to_load)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
load(parameter_name, tensor, partial(load_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -624,11 +645,11 @@ class GeminiDDP(ModelWrapper):
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.payload.copy_(chunk_32.payload)
if self.reuse_fp16_chunk:
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.payload.copy_(chunk_32.payload)
for name, buf in persistent_buffers.items():
if buf is not None:
@@ -668,12 +689,9 @@ class GeminiDDP(ModelWrapper):
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
continue
# create a fp32 parameter
fp32_p = p.data.float()
# create a fp16 parameter
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
# register the fp16 parameter
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
@@ -682,22 +700,27 @@ class GeminiDDP(ModelWrapper):
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
if self.master_weights:
# create a fp32 parameter
fp32_p = p.data.float()
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
self.fp32_params.append(fp32_p)
self.chunk_manager.close_all_groups()
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
# move master weights to corresponding device and setup paired chunks
# if no master weights, fp32_params should be empty and this loop will be skipped
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -734,7 +757,6 @@ class GeminiDDP(ModelWrapper):
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
dtype: torch.dtype = torch.float16,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -769,11 +791,11 @@ class GeminiDDP(ModelWrapper):
gathered_param = param if keep_vars else param.detach()
else:
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
fp32_param = fp16_to_fp32[param]
if fp32_param not in gathered_param_buffer:
chunk = self.chunk_manager.get_chunk(fp32_param)
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
if param_to_save not in gathered_param_buffer:
chunk = self.chunk_manager.get_chunk(param_to_save)
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
gathered_param = gathered_param_buffer.pop(param_to_save)
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:

View File

@@ -105,7 +105,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.param_to_chunk16: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
@@ -130,7 +130,7 @@ class GeminiOptimizer(OptimizerWrapper):
else:
ddp_param_list.append(param)
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
for p in ddp_param_list:
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
@@ -174,13 +174,15 @@ class GeminiOptimizer(OptimizerWrapper):
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
chunk16 = self.param_to_chunk16[fake_param]
begin, end = self.param_to_range[fake_param]
chunk16 = chunk32.paired_chunk
fake_param.data = chunk16.payload[begin:end]
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
fake_param.data = grad_chunk16.payload[begin:end]
fake_param.grad = fake_param.data
fake_param.data = chunk32.payload[begin:end]
to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16
fake_param.data = to_update_chunk.payload[begin:end]
def _update_fp16_params(self):
none_tensor = torch.empty([0])
@@ -194,23 +196,25 @@ class GeminiOptimizer(OptimizerWrapper):
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
assert c16.l2_norm is not None
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
assert grad_chunk.l2_norm is not None
if c16.is_gathered:
norm_sqr += c16.l2_norm
if grad_chunk.is_gathered:
norm_sqr += grad_chunk.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if c16.torch_pg not in group_to_norm:
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm
if grad_chunk.torch_pg not in group_to_norm:
group_to_norm[grad_chunk.torch_pg] = 0.0
group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm
c16.l2_norm = None # clear l2 norm
grad_chunk.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
@@ -237,7 +241,8 @@ class GeminiOptimizer(OptimizerWrapper):
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
if self.module.master_weights:
self._maybe_move_fp32_params()
self._set_grad_ptr()
if self.mix_precision_mixin.should_skip_step():
@@ -245,7 +250,8 @@ class GeminiOptimizer(OptimizerWrapper):
self._logger.info(f"Found overflow. Skip step")
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
if self.module.reuse_fp16_chunk:
self._update_fp16_params()
return
# get combined scale. combined scale = loss scale * clipping norm
@@ -255,7 +261,8 @@ class GeminiOptimizer(OptimizerWrapper):
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
self.zero_grad()
self._update_fp16_params()
if self.module.master_weights:
self._update_fp16_params()
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
@@ -282,8 +289,8 @@ class GeminiOptimizer(OptimizerWrapper):
for group in self.param_groups:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
chunk16 = chunk32.paired_chunk
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
continue
@@ -297,7 +304,8 @@ class GeminiOptimizer(OptimizerWrapper):
for group in self.param_groups:
for fake_param in group["params"]:
chunk32 = self.param_to_chunk32[fake_param]
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
state = self.optim.state[fake_param]
for k, v in state.items():
@@ -341,7 +349,7 @@ class GeminiOptimizer(OptimizerWrapper):
continue
grad_device = self.module.grads_device[param]
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
self.param_to_chunk16[fake_param] = chunk16
self.param_to_range[fake_param] = range_pair
self.id_to_fake_params[param_id] = fake_param
fake_params_list.append(fake_param)
@@ -366,7 +374,7 @@ class GeminiOptimizer(OptimizerWrapper):
if param_id not in self.id_to_fake_params:
return -1, -1, -1
fake_param = self.id_to_fake_params[param_id]
chunk = self.param_to_chunk32[fake_param].paired_chunk
chunk = self.param_to_chunk16[fake_param]
param = self.id_to_real_params[param_id]
param_info = chunk.tensors_info[param]

View File

@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
from .chunk import Chunk
def get_temp_total_chunk_on_cuda(chunk: Chunk):
def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
if chunk.is_gathered:
return chunk.cuda_global_chunk
@@ -20,7 +20,9 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
else:
shard_temp = chunk.cpu_shard.to(get_current_device())
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
shard_temp = shard_temp.to(dtype)
total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)