mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[feature] new zero implementation (#1623)
This commit is contained in:
@@ -3,16 +3,18 @@ import itertools
|
||||
import torch.distributed as dist
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.gemini.chunk import TensorState, Chunk
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
from colossalai.logging import get_dist_logger
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from .reducer import Reducer
|
||||
|
||||
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
|
||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
@@ -208,28 +210,34 @@ class ZeroDDP(ColoDDP):
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False) -> None:
|
||||
super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.fp32_params: List[ColoParameter] = []
|
||||
self.fp32_params: List[ColoTensor] = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
||||
self.chunk_manager.create_group('fp32_param')
|
||||
|
||||
# TODO: get param order and filter unused params
|
||||
for p in module.parameters():
|
||||
assert isinstance(p, ColoParameter)
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
p.data = p.half()
|
||||
continue
|
||||
fp32_p = p.float().detach()
|
||||
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
fp32_data = p.float().data
|
||||
p.data = p.half()
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param')
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self._cast_buffers()
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
@@ -248,10 +256,7 @@ class ZeroDDP(ColoDDP):
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
continue
|
||||
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
|
||||
p.grad = None
|
||||
else:
|
||||
p.grad = p.data
|
||||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
@@ -276,21 +281,22 @@ class ZeroDDP(ColoDDP):
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
if self.dp_world_size > 1:
|
||||
grad = grad / self.dp_world_size
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if reduced and not chunk.is_empty:
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
@@ -311,14 +317,11 @@ class ZeroDDP(ColoDDP):
|
||||
['bias', 'weight']
|
||||
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
record_flag = (not only_rank_0) or is_rank_0
|
||||
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
@@ -326,7 +329,7 @@ class ZeroDDP(ColoDDP):
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
@@ -339,30 +342,30 @@ class ZeroDDP(ColoDDP):
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
# record the original device of the chunk
|
||||
org_chunk_dev_typ = chunk.device_type
|
||||
self.chunk_manager.access_chunk(chunk)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor in chunk.get_tensors():
|
||||
rec_p = torch.empty([0])
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
rec_p = tensor.cpu() # move the whole tensor to CPU mem
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = rec_p
|
||||
# release the actual memory of the chunk
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
rec_p = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = record_parameter
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
@@ -466,40 +469,61 @@ class ZeroDDP(ColoDDP):
|
||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(name, dest_tensor, copy_func):
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
def load(param_name, dest_tensor, copy_func):
|
||||
state_key = prefix + param_name
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(key, input_param.shape,
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
|
||||
ex.args))
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_p(fp32_p, data):
|
||||
if fp32_p.storage().size() > 0:
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
load(name, fp32_p, partial(load_fp32_p, fp32_p))
|
||||
self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
20
colossalai/nn/parallel/utils.py
Normal file
20
colossalai/nn/parallel/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
if chunk.is_gathered:
|
||||
return chunk.chunk_total
|
||||
|
||||
if chunk.cuda_shard is not None:
|
||||
shard_temp = chunk.cuda_shard
|
||||
else:
|
||||
shard_temp = chunk.cpu_shard.to(get_current_device())
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
|
||||
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
|
||||
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
|
||||
|
||||
return total_temp
|
Reference in New Issue
Block a user