mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk * [gemini] support no master weight for optim * [gemini] support no master weight for gemini ddp * [test] update gemini tests * [test] update gemini tests * [plugin] update gemini plugin * [test] fix gemini checkpointio test * [test] fix gemini checkpoint io
This commit is contained in:
@@ -74,6 +74,7 @@ class GeminiDDP(ModelWrapper):
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None, # genimi memory stats
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
@@ -115,6 +116,9 @@ class GeminiDDP(ModelWrapper):
|
||||
self.mixed_precision = mixed_precision
|
||||
self.dp_process_group = process_group or _get_default_group()
|
||||
|
||||
self.reuse_fp16_chunk = master_weights
|
||||
self.master_weights = master_weights
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
@@ -321,20 +325,37 @@ class GeminiDDP(ModelWrapper):
|
||||
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter."
|
||||
)
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
grad_chunk = chunk
|
||||
if not self.reuse_fp16_chunk:
|
||||
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
|
||||
# hold -> compute -> hold after bwd
|
||||
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
|
||||
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
|
||||
# fp16 param chunk: hold after bwd -> ready for reduce -> hold
|
||||
chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.tensor_trans_state(p, TensorState.HOLD)
|
||||
|
||||
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
|
||||
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
if not self.reuse_fp16_chunk:
|
||||
if chunk.keep_gathered:
|
||||
self.chunk_manager.fake_release_chunk(chunk)
|
||||
else:
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if grad_chunk.is_gathered:
|
||||
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
grad_chunk.cuda_shard.div_(chunk.pg_size)
|
||||
# check overflow elements
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping
|
||||
self.overflow_counter += grad_chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
|
||||
if chunk.l2_norm_flag:
|
||||
chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
grad_chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
|
||||
if not self.master_weights:
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
@@ -344,9 +365,7 @@ class GeminiDDP(ModelWrapper):
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(
|
||||
self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
@@ -365,7 +384,7 @@ class GeminiDDP(ModelWrapper):
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
@@ -373,7 +392,7 @@ class GeminiDDP(ModelWrapper):
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
|
||||
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get gathered chunk content.
|
||||
|
||||
@@ -386,9 +405,8 @@ class GeminiDDP(ModelWrapper):
|
||||
"""
|
||||
# save parameters
|
||||
chunk_to_save_data = dict()
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
if torch.is_floating_point(temp_chunk):
|
||||
temp_chunk = temp_chunk.to(dtype)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
@@ -401,9 +419,7 @@ class GeminiDDP(ModelWrapper):
|
||||
del temp_chunk
|
||||
return chunk_to_save_data
|
||||
|
||||
def _get_param_to_save_data(
|
||||
self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
|
||||
) -> Dict:
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
@@ -418,10 +434,10 @@ class GeminiDDP(ModelWrapper):
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||
return param_to_save_data
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
@@ -438,14 +454,18 @@ class GeminiDDP(ModelWrapper):
|
||||
|
||||
# get copies of fp32 parameters in CPU
|
||||
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
|
||||
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
||||
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
|
||||
# get the mapping between copies and fp16 parameters
|
||||
p_mapping = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
if self.reuse_fp16_chunk:
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
else:
|
||||
p_mapping = param_to_save_data
|
||||
for name, param in self.name2param.items():
|
||||
if param is not None:
|
||||
if is_ddp_ignored(param):
|
||||
@@ -593,7 +613,7 @@ class GeminiDDP(ModelWrapper):
|
||||
elif strict:
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
def load_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
@@ -607,14 +627,15 @@ class GeminiDDP(ModelWrapper):
|
||||
name = self.param2name[p]
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
||||
chunk_list = self.chunk_manager.get_chunks(params_to_load)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
load(parameter_name, tensor, partial(load_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||
@@ -624,11 +645,11 @@ class GeminiDDP(ModelWrapper):
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
if self.reuse_fp16_chunk:
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
@@ -668,12 +689,9 @@ class GeminiDDP(ModelWrapper):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_p = p.data.float()
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
# register the fp16 parameter
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=p,
|
||||
group_type="fp16_param",
|
||||
@@ -682,22 +700,27 @@ class GeminiDDP(ModelWrapper):
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=fp32_p,
|
||||
group_type="fp32_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
|
||||
if self.master_weights:
|
||||
# create a fp32 parameter
|
||||
fp32_p = p.data.float()
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=fp32_p,
|
||||
group_type="fp32_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.fp32_params.append(fp32_p)
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
|
||||
|
||||
# move master weights to corresponding device and setup paired chunks
|
||||
# if no master weights, fp32_params should be empty and this loop will be skipped
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
@@ -734,7 +757,6 @@ class GeminiDDP(ModelWrapper):
|
||||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
@@ -769,11 +791,11 @@ class GeminiDDP(ModelWrapper):
|
||||
gathered_param = param if keep_vars else param.detach()
|
||||
else:
|
||||
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
|
||||
fp32_param = fp16_to_fp32[param]
|
||||
if fp32_param not in gathered_param_buffer:
|
||||
chunk = self.chunk_manager.get_chunk(fp32_param)
|
||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||
gathered_param = gathered_param_buffer.pop(fp32_param)
|
||||
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
|
||||
if param_to_save not in gathered_param_buffer:
|
||||
chunk = self.chunk_manager.get_chunk(param_to_save)
|
||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||
gathered_param = gathered_param_buffer.pop(param_to_save)
|
||||
|
||||
block, block_size = sharder.append_param(prefix + name, gathered_param)
|
||||
if block is not None:
|
||||
|
Reference in New Issue
Block a user