mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-23 10:11:37 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassgin shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control stratge
* update source add
* add reorder in mem estimator
* improve reorder efficeincy
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones like, dont gen chunk if dim size is 1
* fix bug again
* update min memory stratege, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* seperate flow tracer
* seperate input node dim search
* seperate prepose_nodes
* seperate non chunk input
* seperate reorder
* rename
* ad reorder graph
* seperate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [worfklow] added coverage test (#2399)
* [worfklow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatiblity test (#2453)
* [workflow] automated the compatiblity test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
support the full evoformer tracer, which is a main module of alphafold. Previously we only supported a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen will be set automatically in __init__. But other codegen can't be set if so.
So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db
.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
667 lines
30 KiB
Python
667 lines
30 KiB
Python
import itertools
|
|
from collections import OrderedDict
|
|
from functools import partial
|
|
from typing import Dict, Iterable, List, Optional, Set
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
|
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
from colossalai.gemini.memory_tracer import OrderedParamGenerator
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
|
from colossalai.tensor import ReplicaSpec
|
|
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
|
from colossalai.utils import get_current_device, is_ddp_ignored
|
|
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
|
|
|
|
from .reducer import Reducer
|
|
from .utils import get_static_torch_model
|
|
|
|
try:
|
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
|
except ImportError:
|
|
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
|
|
|
|
|
def free_storage(data: torch.Tensor) -> None:
    """Shrink the storage backing ``data`` to zero elements.

    Nothing happens when the storage is already empty. Otherwise the tensor must
    start at offset 0 of its storage (asserted), because the storage itself is
    resized in place.

    Args:
        data (torch.Tensor): tensor whose underlying storage is released.
    """
    storage = data.storage()
    if storage.size() <= 0:
        return
    # The storage is modified directly, so the tensor has to be its sole occupant.
    assert data.storage_offset() == 0
    storage.resize_(0)
|
|
|
|
|
|
def _cast_float(args, dtype: torch.dtype):
|
|
if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
|
|
args = args.to(dtype)
|
|
elif isinstance(args, (list, tuple)):
|
|
args = type(args)(_cast_float(t, dtype) for t in args)
|
|
elif isinstance(args, dict):
|
|
args = {k: _cast_float(v, dtype) for k, v in args.items()}
|
|
return args
|
|
|
|
|
|
class ColoDDP(torch.nn.Module):
    """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.

    A hook (:meth:`grad_handle`) is registered on every trainable, non-ignored parameter.
    Non-CPU gradients are stashed on ``p._saved_grad`` (after an all-reduce when the data
    parallel world size is larger than 1) and installed as ``p.grad`` by :meth:`backward`.

    Example:
        >>> from colossalai.core import global_context as gpc
        >>> from colossalai.context import ParallelMode
        >>> model = torch.nn.Linear(20, 1)
        >>> pg = ProcessGroup(tp_degree = world_size//2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)

    Args:
        module (torch.nn.Module): Module to apply DDP. Must not already be a ``ColoDDP``.
        process_group (ColoProcessGroup): The process group which DDP uses. It is required;
            a falsy value fails an assertion.
        bucket_cap_mb (int, optional): Passed to ``Reducer``; presumably the bucket capacity
            in MB (TODO confirm against ``Reducer``). Defaults to 25.
        rebuild_bucket (bool, optional): If True, ``reducer.free()`` is called at the end of
            every :meth:`backward`. Defaults to True.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: 'ColoProcessGroup',
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
        assert not isinstance(module, ColoDDP)
        super().__init__()
        self.module = module
        # dedicated stream on which gradient communication is issued
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        assert process_group

        self.process_group = process_group
        self.dp_world_size = self.process_group.dp_world_size()

        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
        for p in module.parameters():
            if is_ddp_ignored(p):
                continue
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    def parameters(self, recurse: bool = True):
        """Return the parameters of the wrapped module."""
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """Return the named parameters of the wrapped module (names carry no wrapper prefix)."""
        return self.module.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        """Return the named buffers of the wrapped module."""
        return self.module.named_buffers(prefix, recurse)

    def named_children(self):
        """Return the named children of the wrapped module."""
        return self.module.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        """Return the named modules of the wrapped module."""
        return self.module.named_modules(memo, prefix, remove_duplicate)

    def forward(self, *args, **kwargs):
        """Clear the wrapped module's gradients (set to None) and run its forward pass."""
        self.module.zero_grad(set_to_none=True)
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Run backward on ``loss``, flush pending reductions and install the saved gradients.

        Args:
            loss (torch.Tensor): the tensor to call ``backward()`` on.
        """
        loss.backward()
        with torch.cuda.stream(self.comm_stream):
            self.reducer.flush()
        # make the compute stream wait for the communication issued above
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        if self.rebuild_bucket:
            self.reducer.free()
        for p in self.module.parameters():
            if is_ddp_ignored(p):
                continue
            # Parameters that received no gradient (frozen or unused in this step)
            # have ``p.grad is None`` and nothing saved for them.
            if p.grad is None:
                continue
            if p.grad.device.type != "cpu":
                p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        """Gradient hook registered on every trainable, non-ignored parameter.

        Args:
            p: the parameter the gradient belongs to.
            grad (torch.Tensor): the gradient produced by autograd.

        Returns:
            torch.Tensor: for non-CPU gradients, a storage-less placeholder (the real
            gradient is kept in ``p._saved_grad``); for CPU gradients, ``grad`` itself
            after an in-place all-reduce.
        """
        if grad.device.type != "cpu":
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                grad = grad / self.dp_world_size
                # communication must not start before the gradient has been computed
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    self.reducer.all_reduce_async(grad,
                                                  group=self.process_group.dp_process_group(),
                                                  callback_fn=partial(self._save_grad, p))
                grad.record_stream(self.comm_stream)
            else:
                ColoDDP._save_grad(p, grad)
            return empty_grad

        else:
            # TODO(jiaruifang) fixme
            # NOTE(review): unlike the branch above, the CPU gradient is not divided by
            # ``dp_world_size`` here -- confirm whether that is intended.
            self.process_group.set_cpu_groups()
            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
            return grad

    @staticmethod
    def _save_grad(p, grad):
        """Accumulate ``grad`` into ``p._saved_grad``, creating it when absent or None."""
        saved = getattr(p, '_saved_grad', None)
        if saved is None:
            # ``zero_grad(set_to_none=True)`` leaves the attribute in place with value None,
            # so checking for the attribute's existence alone is not enough.
            p._saved_grad = grad
        else:
            saved.add_(grad)

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Reset gradients.

        ``p.grad`` of the wrapped module is always set to None; ``set_to_none`` only
        controls whether ``p._saved_grad`` is dropped (True) or zeroed in place (False).
        """
        self.module.zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, '_saved_grad', None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()

    @staticmethod
    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
        """Sets parameters to be ignored by DDP.
        This method must be called before initializing ColoDDP.

        Example:
            >>> params_to_ignore = []
            >>> for p in module.parameters():
            >>>     if should_ignore(p):
            >>>         params_to_ignore.append(p)
            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
            >>> module = ColoDDP(module, process_group)

        Args:
            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
        """
        for p in params_to_ignore:
            p._ddp_to_ignore = True

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict of the wrapped module."""
        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        """Load ``state_dict`` into the wrapped module and return its result."""
        return self.module.load_state_dict(state_dict, strict)
|
|
|
|
|
|
class ZeroDDP(ColoDDP):
|
|
"""ZeRO DDP for ColoTensor.
|
|
Warning: Nested ZeroDDP is not supported now.
|
|
It is designed to be used with ChunkManager and GeminiManager.
|
|
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
|
|
|
Args:
|
|
module (torch.nn.Module): Module to apply ZeRO-DP.
|
|
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
|
|
For more details, see the API reference of ``GeminiManager``.
|
|
pin_memory (bool): Chunks on CPU Memory use pin-memory.
|
|
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
|
|
Defaults to False.
|
|
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
|
|
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
|
|
"""
|
|
|
|
def __init__(self,
             module: torch.nn.Module,
             gemini_manager: GeminiManager,
             pin_memory: bool = False,
             force_outputs_fp32: bool = False,
             strict_ddp_mode: bool = False) -> None:
    """Register every non-ignored parameter of ``module`` with the chunk manager.

    Each such parameter is converted to fp16 in place and paired with an fp32 copy;
    both are registered in their own chunk group and the resulting chunks are paired.
    See the class docstring for the meaning of the arguments.
    """
    super().__init__(module, process_group=ColoProcessGroup())
    self.gemini_manager = gemini_manager
    self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
    self.force_outputs_fp32 = force_outputs_fp32
    self.param_op_hook = GeminiZeROHook(gemini_manager)
    self.fp32_params: List[ColoTensor] = []
    self.overflow_counter = 0
    self.grads_device: Dict[torch.Tensor, torch.device] = {}

    # every placement policy except 'cuda' registers tensors with cpu_offload enabled
    offload_to_cpu = self.gemini_manager.policy_name != 'cuda'

    if not self.gemini_manager._premade_memstats_:
        # Build chunks in parameter initialization order.
        # Note: this way unused params cannot be filtered out during runtime.
        param_order = OrderedParamGenerator()
        for param in module.parameters():
            param_order.append(param)
    else:
        # Build chunks in the order the parameters were visited at runtime.
        param_order = self.gemini_manager.memstats()._param_runtime_order

    for param in param_order.generate():
        assert isinstance(param, ColoParameter)

        if strict_ddp_mode and not param.is_replicate():
            param.set_dist_spec(ReplicaSpec())

        if is_ddp_ignored(param):
            # ignored parameters are not chunk-managed; only move and cast them
            param.data = param.data.to(device=get_current_device(), dtype=torch.float16)
            continue

        fp32_copy = ColoTensor(param.data.float(), spec=ColoTensorSpec(param.process_group))
        param.data = param.data.half()
        dp_size = param.process_group.dp_world_size()
        for group_name, member in (('fp16_param', param), ('fp32_param', fp32_copy)):
            self.chunk_manager.register_tensor(tensor=member,
                                               group_type=group_name,
                                               config_key=dp_size,
                                               cpu_offload=offload_to_cpu,
                                               pin_memory=pin_memory)
        self.fp32_params.append(fp32_copy)
        self.grads_device[param] = self.gemini_manager.default_device

    self.chunk_manager.close_all_groups()
    self._cast_buffers()

    managed_params = [candidate for candidate in param_order.generate() if not is_ddp_ignored(candidate)]
    for param, fp32_copy in zip(managed_params, self.fp32_params):
        fp16_chunk = self.chunk_manager.get_chunk(param)
        fp32_chunk = self.chunk_manager.get_chunk(fp32_copy)
        fp32_chunk.init_pair(fp16_chunk)

        # chunks flagged keep_gathered keep their gradients on the current device
        if fp16_chunk.keep_gathered:
            self.grads_device[param] = get_current_device()

    self._logger = get_dist_logger()
|
|
|
|
def _post_forward(self):
|
|
"""This function is only triggered for inference.
|
|
"""
|
|
access_list = list(self.chunk_manager.accessed_chunks)
|
|
# we need to scatter all accessed chunks and move them to their original places
|
|
for chunk in access_list:
|
|
assert chunk.can_release
|
|
self.chunk_manager.release_chunk(chunk)
|
|
first_param = next(iter(chunk.tensors_info))
|
|
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
|
assert self.chunk_manager.accessed_mem == 0
|
|
# reset all recorded attributes
|
|
self.gemini_manager.reset_attributes()
|
|
|
|
def forward(self, *args, **kwargs):
    """Run the wrapped module with floating point inputs cast to fp16.

    When gradients are disabled (inference), the gemini manager must no longer be
    in warmup, and the accessed chunks are cleaned up after the forward pass.

    Returns:
        The module outputs, with floating point tensors cast to fp32 when
        ``force_outputs_fp32`` is set.
    """
    inference = not torch.is_grad_enabled()
    if inference:
        assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"

    args = _cast_float(args, torch.half)
    kwargs = _cast_float(kwargs, torch.half)
    self.module.zero_grad(set_to_none=True)
    self.gemini_manager.pre_iter(*args)
    with ColoParamOpHookManager.use_hooks(self.param_op_hook):
        outputs = self.module(*args, **kwargs)
    if inference:
        # no backward pass will follow to release the chunks
        self._post_forward()

    return _cast_float(outputs, torch.float) if self.force_outputs_fp32 else outputs
|
|
|
|
def _setup_grads_ptr(self):
    """Set ``grad`` to None on every parameter that is not ignored by DDP."""
    for param in self.module.parameters():
        if not is_ddp_ignored(param):
            param.grad = None
|
|
|
|
def _post_backward(self):
|
|
if self.chunk_manager.accessed_mem != 0:
|
|
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
|
"The most possible reason is that the model is not compatible with ZeroDDP.")
|
|
self._setup_grads_ptr()
|
|
self._logger.debug(
|
|
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
|
)
|
|
self.gemini_manager.post_iter()
|
|
|
|
def backward(self, loss: torch.Tensor):
    """Run ``loss.backward()`` with the parameter hook switched to backward mode.

    Args:
        loss (torch.Tensor): the tensor to back-propagate from.
    """
    with self.param_op_hook.switch_to_backward():
        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
            loss.backward()
    self._post_backward()
|
|
|
|
def backward_by_grad(self, tensor, grad):
    """Back-propagate from ``tensor`` using the externally supplied gradient ``grad``.

    Args:
        tensor: forwarded to ``torch.autograd.backward`` as the tensor(s) to differentiate.
        grad: forwarded to ``torch.autograd.backward`` as the matching gradient(s).
    """
    with self.param_op_hook.switch_to_backward():
        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
            torch.autograd.backward(tensor, grad)
    self._post_backward()
|
|
|
|
def grad_handle(self, p, grad):
    """Gradient hook: copy ``grad`` into the chunk of ``p`` and reduce the chunk when possible.

    Args:
        p: the parameter the gradient belongs to.
        grad (torch.Tensor): the gradient produced by autograd.

    Returns:
        torch.Tensor: a storage-less placeholder shaped like ``grad``.
    """
    placeholder = torch.empty_like(grad)
    free_storage(placeholder)
    with torch._C.DisableTorchFunction():
        chunk = self.chunk_manager.get_chunk(p)
        assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
        self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
        chunk.copy_tensor_to_chunk_slice(p, grad)
        if self.chunk_manager.reduce_chunk(chunk):
            # divide the reduced payload by the process group size
            payload = chunk.cuda_global_chunk if chunk.is_gathered else chunk.cuda_shard
            payload.div_(chunk.pg_size)
            # count chunks containing inf/nan elements
            self.overflow_counter += chunk.has_inf_or_nan
            # record the l2 norm for gradient clipping
            if chunk.l2_norm_flag:
                chunk.set_l2_norm()
            self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
    return placeholder
|
|
|
|
def zero_grad(self, set_to_none: bool = False) -> None:
    """Set the gradients of the wrapped module to None.

    ``set_to_none`` is accepted for interface compatibility; the gradients are
    always set to None regardless of its value.
    """
    wrapped = self.module
    wrapped.zero_grad(set_to_none=True)
|
|
|
|
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
    """Record ``device`` in ``grads_device`` for every tensor returned by ``chunk.get_tensors()``.

    Args:
        chunk (Chunk): the chunk whose tensors are updated.
        device (torch.device): the device to record for those tensors.
    """
    self.grads_device.update({tensor: device for tensor in chunk.get_tensors()})
|
|
|
|
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
    """Return the state of the module as a dictionary.

    Args:
        destination: forwarded to the underlying ``state_dict`` implementation.
        prefix (str): forwarded to the underlying ``state_dict`` implementation.
        keep_vars (bool): must be False when ``strict`` is True (asserted).
        only_rank_0 (bool): forwarded to ``get_static_torch_model`` (strict) or to
            ``_non_strict_state_dict`` (non strict).
        strict (bool): whether to return the whole model state as the pytorch `Module.state_dict()`.
            If False, ``_non_strict_state_dict`` is used instead.

    Returns:
        dict:
            a dictionary containing a whole state of the module

    Example:

        >>> module.state_dict().keys()
        ['bias', 'weight']
    """
    if not strict:
        return self._non_strict_state_dict(destination=destination,
                                           prefix=prefix,
                                           keep_vars=keep_vars,
                                           only_rank_0=only_rank_0)

    assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
    # build a static torch model from this wrapper and delegate to its state_dict
    static_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
    return static_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
|
|
|
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
|
"""Returns a dictionary containing a whole state of the module.
|
|
|
|
Both parameters and persistent buffers (e.g. running averages) are included.
|
|
Keys are corresponding parameter and buffer names.
|
|
Parameters and buffers set to ``None`` are not included.
|
|
|
|
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
|
|
are shared with other parameters which have been included in the dictionary.
|
|
When you need to load the state dict, you should set the argument `strict` to False.
|
|
|
|
Returns:
|
|
dict:
|
|
a dictionary containing a whole state of the module
|
|
"""
|
|
if destination is None:
|
|
destination = OrderedDict()
|
|
destination._metadata = OrderedDict()
|
|
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
|
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
|
|
|
for hook in self._state_dict_hooks.values():
|
|
hook_result = hook(self, destination, prefix, local_metadata)
|
|
if hook_result is not None:
|
|
destination = hook_result
|
|
return destination
|
|
|
|
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
    """Collect the payload of every tensor stored in the chunks that hold ``param_list``.

    Args:
        param_list (List[torch.nn.Parameter]): tensors whose chunks are looked up through
            ``chunk_manager.get_chunks``.
        only_rank_0 (bool): if True, real data is recorded only on the process whose rank
            in the chunk's ``torch_pg`` group is 0; other ranks record an empty tensor.

    Returns:
        Dict: maps each tensor object listed in the visited chunks' ``tensors_info`` (not
        its name) to a CPU tensor with that tensor's shape, or to an empty placeholder
        tensor when nothing is recorded on this rank.
    """
    param_to_save_data = dict()
    for chunk in self.chunk_manager.get_chunks(param_list):
        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
        # The decision depends only on the chunk, so take it once per chunk instead of
        # once per tensor; ``or`` skips the rank query when every rank records.
        record_flag = (not only_rank_0) or dist.get_rank(chunk.torch_pg) == 0

        for tensor, tensor_info in chunk.tensors_info.items():
            if record_flag:
                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
            else:
                record_tensor = torch.empty([0])

            assert tensor not in param_to_save_data
            param_to_save_data[tensor] = record_tensor

        # drop the temporary chunk before the next one is fetched
        del temp_chunk
    return param_to_save_data
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
    r"""Saves module state to `destination` dictionary.

    Written, in this order: DDP-ignored parameters (detached), chunk-managed
    parameters (taken from the data collected for ``fp32_params``), buffers, and
    the extra state when the class overrides ``get_extra_state``.

    NOTE(review): managed parameters are paired with ``fp32_params`` positionally
    via ``zip``; this presumes ``named_parameters()`` yields them in the order in
    which ``fp32_params`` was filled -- verify against ``__init__``.

    Args:
        destination (dict): a dict where state will be stored
        prefix (str): the prefix for parameters and buffers used in this
            module
        keep_vars (bool): must be False (asserted).
        only_rank_0 (bool): forwarded to ``_get_param_to_save_data``.
    """
    assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

    param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
    managed = []
    for name, param in self.named_parameters():
        if not is_ddp_ignored(param):
            managed.append((name, param))
            continue
        # DDP-ignored parameters are stored directly
        destination[prefix + name] = param if keep_vars else param.detach()

    for (name, param), fp32_param in zip(managed, self.fp32_params):
        if param is None:
            continue
        assert fp32_param in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
        destination[prefix + name] = param_to_save_data[fp32_param]

    # buffers
    for name, buf in self.named_buffers():
        if buf is None or name in self._non_persistent_buffers_set:
            continue
        destination[prefix + name] = buf if keep_vars else buf.detach()

    # extra state, only when get_extra_state is overridden
    default_getter = torch.nn.Module.get_extra_state
    if getattr(self.__class__, "get_extra_state", default_getter) is not default_getter:
        destination[prefix + _EXTRA_STATE_KEY_SUFFIX] = self.get_extra_state()
|
|
|
|
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
    r"""Copies parameters and buffers from :attr:`state_dict` into this module.

    The work is delegated to ``_load_from_state_dict`` with an empty prefix, operating
    on a shallow copy of :attr:`state_dict`. If :attr:`strict` is ``True``, missing and
    unexpected keys are reported as errors.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys expected by this module.
            Default: ``True``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        RuntimeError: if any error message was collected while loading.
    """
    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # work on a copy so _load_from_state_dict can modify it; carry the metadata over
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata    # type: ignore[attr-defined]

    prefix = ''
    if metadata is None:
        local_metadata = {}
    else:
        local_metadata = metadata.get(prefix[:-1], {})
    self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)

    if strict:
        # both messages go to the front, so missing keys end up listed first
        if unexpected_keys:
            quoted = ', '.join('"{}"'.format(k) for k in unexpected_keys)
            error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(quoted))
        if missing_keys:
            quoted = ', '.join('"{}"'.format(k) for k in missing_keys)
            error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(quoted))

    if error_msgs:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                          error_msgs):
    r"""Copy parameters and persistent buffers from :attr:`state_dict` into this module.

    The registered load-state-dict pre-hooks run first. Parameters for which
    ``is_ddp_ignored`` is true are copied in place. All remaining parameters
    are written into the chunks returned by ``self.chunk_manager`` for
    ``self.fp32_params``; once every chunk has been written back,
    ``optim_update()`` is called on each chunk's ``paired_chunk``. Persistent
    buffers and the optional extra state are handled last.

    .. note::
        :meth:`load_state_dict` passes a shallow copy of the caller's
        :attr:`state_dict`, so pre-hooks may modify it.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix prepended to every parameter and buffer
            name when looking it up in :attr:`state_dict`.
        local_metadata (dict): a dict containing the metadata for this module.
            It is only forwarded to the pre-hooks here.
        strict (bool): whether to record keys that are missing from
            :attr:`state_dict`, and keys in :attr:`state_dict` starting with
            :attr:`prefix` that match no local parameter or buffer.
        missing_keys (list of str): if ``strict=True``, missing keys are
            appended to this list.
        unexpected_keys (list of str): if ``strict=True``, unexpected keys
            are appended to this list.
        error_msgs (list of str): shape mismatches and copy failures are
            appended to this list instead of being raised.
    """
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {
        buf_name: buf for buf_name, buf in self.named_buffers() if buf_name not in self._non_persistent_buffers_set
    }
    # Only used for membership tests when collecting unexpected keys.
    known_entries = {
        entry_name: entry
        for entry_name, entry in itertools.chain(self.named_parameters(), persistent_buffers.items())
        if entry is not None
    }

    size_mismatch_template = ('size mismatch for {}: copying a param with shape {} from checkpoint, '
                              'the shape in current model is {}.')
    copy_failure_template = ('While copying the parameter named "{}", '
                             'whose dimensions in the model are {} and '
                             'whose dimensions in the checkpoint are {}, '
                             'an exception occurred : {}.')

    def load(param_name, dest_tensor, copy_func):
        # Look up one entry and hand it to copy_func; problems are recorded, not raised.
        state_key = prefix + param_name
        if state_key not in state_dict:
            if strict:
                missing_keys.append(state_key)
            return

        input_param = state_dict[state_key]
        # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
        if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
            input_param = input_param[0]

        # The checkpoint shape has to match the local one before copying.
        if input_param.shape != dest_tensor.shape:
            error_msgs.append(size_mismatch_template.format(state_key, input_param.shape, dest_tensor.shape))
            return

        try:
            with torch.no_grad():
                copy_func(input_param)
        except Exception as ex:
            error_msgs.append(
                copy_failure_template.format(state_key, dest_tensor.size(), input_param.size(), ex.args))

    def copy_flat_into(chunk_slice, data):
        # The chunk slice is one-dimensional, so the incoming tensor is flattened first.
        chunk_slice.copy_(data.flatten())

    managed_params = []
    for name, param in self.named_parameters():
        if not is_ddp_ignored(param):
            managed_params.append((name, param))
            continue
        # DDP-ignored parameters are loaded directly into the parameter itself.
        load(name, param, param.copy_)

    # Pair each remaining parameter name with the entry of self.fp32_params at the same position.
    fp32_to_name = {
        fp32_p: name for (name, p), fp32_p in zip(managed_params, self.fp32_params) if p is not None
    }

    chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
    for chunk in chunk_list:
        temp_chunk = get_temp_total_chunk_on_cuda(chunk)

        for tensor, tensor_info in chunk.tensors_info.items():
            parameter_name = fp32_to_name[tensor]
            parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
            load(parameter_name, tensor, partial(copy_flat_into, parameter_slice))

        # Write the filled temporary chunk back to wherever the chunk data currently lives.
        if chunk.is_gathered:
            chunk.cuda_global_chunk.copy_(temp_chunk)
        else:
            shard_target = chunk.cuda_shard if chunk.cuda_shard is not None else chunk.cpu_shard
            shard_target.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])

        del temp_chunk

    # Runs only after every chunk above has been written back.
    for loaded_chunk in chunk_list:
        paired = loaded_chunk.paired_chunk
        assert paired is not None
        paired.optim_update()

    for name, buf in persistent_buffers.items():
        if buf is None:
            continue
        load(name, buf, buf.copy_)

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    class_setter = getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state)
    if class_setter is not torch.nn.Module.set_extra_state:
        # The class overrides set_extra_state, so an extra-state entry is expected.
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return

    prefix_length = len(prefix)
    for key in state_dict:
        if key == extra_state_key or not key.startswith(prefix):
            continue
        if key[prefix_length:] not in known_entries:
            unexpected_keys.append(key)
|
|
|
|
def _cast_buffers(self):
    """Move every buffer of ``self.module`` to CUDA and convert floating-point buffers to half precision.

    Each buffer is updated in place through its ``.data`` attribute;
    non-floating-point buffers are moved but keep their dtype.
    """
    for buf in self.module.buffers():
        buf.data = buf.cuda()
        if not torch.is_floating_point(buf):
            continue
        buf.data = buf.half()
|