mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 19:55:29 +00:00
[zero] alleviate memory usage in ZeRODDP state_dict (#1398)
This commit is contained in:
parent
4f5f8f77d1
commit
4e98e938ce
@ -6,12 +6,13 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|||||||
from colossalai.gemini.chunk import TensorState, Chunk
|
from colossalai.gemini.chunk import TensorState, Chunk
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Dict, Iterable, List, Optional
|
from typing import Dict, Iterable, List, Optional, Set
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||||
from .reducer import Reducer
|
from .reducer import Reducer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -84,6 +85,18 @@ class ColoDDP(torch.nn.Module):
|
|||||||
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||||
return self.module.named_parameters(prefix, recurse)
|
return self.module.named_parameters(prefix, recurse)
|
||||||
|
|
||||||
|
def named_buffers(self, prefix: str = '', recurse: bool = True):
|
||||||
|
return self.module.named_buffers(prefix, recurse)
|
||||||
|
|
||||||
|
def named_children(self):
|
||||||
|
return self.module.named_children()
|
||||||
|
|
||||||
|
def named_modules(self,
|
||||||
|
memo: Optional[Set[torch.nn.Module]] = None,
|
||||||
|
prefix: str = '',
|
||||||
|
remove_duplicate: bool = True):
|
||||||
|
return self.module.named_modules(memo, prefix, remove_duplicate)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
self.module.zero_grad(set_to_none=True)
|
self.module.zero_grad(set_to_none=True)
|
||||||
return self.module(*args, **kwargs)
|
return self.module(*args, **kwargs)
|
||||||
@ -274,7 +287,7 @@ class ZeroDDP(ColoDDP):
|
|||||||
for tensor in chunk.get_tensors():
|
for tensor in chunk.get_tensors():
|
||||||
self.grads_device[tensor] = device
|
self.grads_device[tensor] = device
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||||
r"""Returns a dictionary containing a whole state of the module.
|
r"""Returns a dictionary containing a whole state of the module.
|
||||||
|
|
||||||
Both parameters and persistent buffers (e.g. running averages) are
|
Both parameters and persistent buffers (e.g. running averages) are
|
||||||
@ -291,18 +304,22 @@ class ZeroDDP(ColoDDP):
|
|||||||
['bias', 'weight']
|
['bias', 'weight']
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||||
|
record_flag = (not only_rank_0) or is_rank_0
|
||||||
|
|
||||||
if destination is None:
|
if destination is None:
|
||||||
destination = OrderedDict()
|
destination = OrderedDict()
|
||||||
destination._metadata = OrderedDict()
|
destination._metadata = OrderedDict()
|
||||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||||
self._save_to_state_dict(destination, prefix, keep_vars)
|
self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
|
||||||
|
|
||||||
for hook in self._state_dict_hooks.values():
|
for hook in self._state_dict_hooks.values():
|
||||||
hook_result = hook(self, destination, prefix, local_metadata)
|
hook_result = hook(self, destination, prefix, local_metadata)
|
||||||
if hook_result is not None:
|
if hook_result is not None:
|
||||||
destination = hook_result
|
destination = hook_result
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
|
||||||
r"""Saves module state to `destination` dictionary, containing a state
|
r"""Saves module state to `destination` dictionary, containing a state
|
||||||
of the module, but not its descendants. This is called on every
|
of the module, but not its descendants. This is called on every
|
||||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||||
@ -315,22 +332,36 @@ class ZeroDDP(ColoDDP):
|
|||||||
prefix (str): the prefix for parameters and buffers used in this
|
prefix (str): the prefix for parameters and buffers used in this
|
||||||
module
|
module
|
||||||
"""
|
"""
|
||||||
chunks = self.chunk_manager.get_chunks(self.fp32_params)
|
# save parameters
|
||||||
chunks_orig_device_type = []
|
param_to_save_data = dict()
|
||||||
for chunk in chunks:
|
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||||
chunks_orig_device_type.append(chunk.device_type)
|
for chunk in chunk_list:
|
||||||
|
# record the original device of the chunk
|
||||||
|
org_chunk_dev_typ = chunk.device_type
|
||||||
self.chunk_manager.access_chunk(chunk)
|
self.chunk_manager.access_chunk(chunk)
|
||||||
|
|
||||||
|
for tensor in chunk.get_tensors():
|
||||||
|
rec_p = torch.empty([0])
|
||||||
|
if record_flag:
|
||||||
|
rec_p = tensor.cpu() # move the whole tensor to CPU mem
|
||||||
|
assert tensor not in param_to_save_data
|
||||||
|
param_to_save_data[tensor] = rec_p
|
||||||
|
# release the actual memory of the chunk
|
||||||
|
self.chunk_manager.release_chunk(chunk)
|
||||||
|
if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
|
||||||
|
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||||
|
|
||||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||||
if p is not None:
|
if p is not None:
|
||||||
rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
|
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||||
|
rec_p = param_to_save_data[fp32_p]
|
||||||
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
||||||
for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
|
|
||||||
self.chunk_manager.release_chunk(chunk)
|
# save all buffers
|
||||||
if not chunk.is_empty and orig_dvice_type == 'cpu':
|
|
||||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
|
||||||
for name, buf in self.named_buffers():
|
for name, buf in self.named_buffers():
|
||||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||||
|
# save extra states
|
||||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||||
if getattr(self.__class__, "get_extra_state",
|
if getattr(self.__class__, "get_extra_state",
|
||||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||||
@ -368,7 +399,7 @@ class ZeroDDP(ColoDDP):
|
|||||||
state_dict = state_dict.copy()
|
state_dict = state_dict.copy()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
# mypy isn't aware that "_metadata" exists in state_dict
|
# mypy isn't aware that "_metadata" exists in state_dict
|
||||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||||
|
|
||||||
prefix = ''
|
prefix = ''
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import copy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import colossalai
|
import colossalai
|
||||||
import torch
|
import torch
|
||||||
@ -11,9 +13,9 @@ from functools import partial
|
|||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Callable
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from colossalai.tensor import ProcessGroup, ColoParameter
|
from colossalai.tensor import ProcessGroup, ColoParameter
|
||||||
|
from colossalai.testing import parameterize
|
||||||
|
|
||||||
|
|
||||||
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
|
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
|
||||||
@ -25,7 +27,27 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
|
|||||||
else:
|
else:
|
||||||
temp_t2 = t2
|
temp_t2 = t2
|
||||||
|
|
||||||
assert torch.equal(t1, temp_t2)
|
assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True):
|
||||||
|
for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
|
||||||
|
assert na == nb
|
||||||
|
|
||||||
|
if not allow_empty:
|
||||||
|
assert pa.storage().size() > 0
|
||||||
|
assert pb.storage().size() > 0
|
||||||
|
else:
|
||||||
|
if pa.storage().size() == 0 or pb.storage().size() == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if same_dtype:
|
||||||
|
assert pa.dtype == pb.dtype
|
||||||
|
temp_pb = pb
|
||||||
|
else:
|
||||||
|
temp_pb = pb.to(pa.dtype)
|
||||||
|
|
||||||
|
assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb)
|
||||||
|
|
||||||
|
|
||||||
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||||
@ -33,22 +55,26 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
|||||||
return ColoDDP(module, process_group=pg)
|
return ColoDDP(module, process_group=pg)
|
||||||
|
|
||||||
|
|
||||||
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
|
def init_ddpv2(module: torch.nn.Module,
|
||||||
|
use_chunk: bool = False,
|
||||||
|
use_zero: bool = False,
|
||||||
|
placement_policy: str = 'cuda') -> ZeroDDP:
|
||||||
pg = ProcessGroup()
|
pg = ProcessGroup()
|
||||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
|
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
|
||||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
return ZeroDDP(module, gemini_manager)
|
return ZeroDDP(module, gemini_manager)
|
||||||
|
|
||||||
|
|
||||||
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
|
def run_ddp_state_dict():
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('nested_model')
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
torch_model = model_builder().cuda()
|
torch_model = model_builder().cuda()
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = model_builder()
|
model = model_builder()
|
||||||
model = ddp_init_func(model)
|
model = init_ddp(model)
|
||||||
torch_state_dict = torch_model.state_dict()
|
torch_state_dict = torch_model.state_dict()
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if isinstance(param, ColoParameter):
|
if isinstance(param, ColoParameter):
|
||||||
assert param.get_process_group() is not None
|
assert param.get_process_group() is not None
|
||||||
@ -62,13 +88,44 @@ def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
|
|||||||
check_state_dict_equal(torch_state_dict, state_dict)
|
check_state_dict_equal(torch_state_dict, state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('use_chunk', [False, True])
|
||||||
|
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||||
|
@parameterize('use_zero', [False, True])
|
||||||
|
@parameterize('only_rank_0', [False, True])
|
||||||
|
def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0):
|
||||||
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
torch_model = model_builder().cuda()
|
||||||
|
org_torch_model = copy.deepcopy(torch_model)
|
||||||
|
torch_state_dict = torch_model.state_dict()
|
||||||
|
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model = model_builder()
|
||||||
|
model = init_ddpv2(model, use_chunk, use_zero, placement_policy)
|
||||||
|
|
||||||
|
for param in model.parameters():
|
||||||
|
if isinstance(param, ColoParameter):
|
||||||
|
assert param.get_process_group() is not None
|
||||||
|
|
||||||
|
model.load_state_dict(torch_state_dict, strict=False)
|
||||||
|
check_model_equal(model, torch_model, allow_empty=True, same_dtype=False)
|
||||||
|
|
||||||
|
for param in model.parameters():
|
||||||
|
if isinstance(param, ColoParameter):
|
||||||
|
assert param.get_process_group() is not None
|
||||||
|
|
||||||
|
pg = ProcessGroup()
|
||||||
|
state_dict = model.state_dict(only_rank_0=only_rank_0)
|
||||||
|
if not only_rank_0 or pg.dp_local_rank() == 0:
|
||||||
|
torch_model.load_state_dict(state_dict, strict=False)
|
||||||
|
check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_state_dict(init_ddp)
|
run_ddp_state_dict()
|
||||||
run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False))
|
run_zero_state_dict()
|
||||||
run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True))
|
|
||||||
run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False))
|
|
||||||
run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
Loading…
Reference in New Issue
Block a user