mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-06 15:38:26 +00:00
ColossalAI/colossalai/zero/low_level/low_level_optim.py
Haze188 416580b314
[MoE/ZeRO] Moe refactor with zero refactor ()
* [moe] removed openmoe-coupled code and rectify mixstral code ()

* [Feauture] MoE refractor; Intergration with Mixtral  ()

* cherry pick from refractor-moe branch

* tests passed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support ep + zero

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add mixtral auto policy & move pipeline forward code to modeling folder

* [moe refactor] modify kernel test without Route Class

* [moe refactor] add moe tensor test path environment variable to github workflow

* fix typos

* fix moe test bug due to the code rebase

* [moe refactor] fix moe zero test, and little bug in low level zero

* fix typo

* add moe tensor path to github workflow

* remove some useless code

* fix typo & unify global variable XX_AXIS logic without using -1

* fix typo & prettifier the code

* remove print code & support zero 2 test

* remove useless code

* reanme function

* fix typo

* fix typo

* Further improve the test code

* remove print code

* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test

* [moe refactor] skip some unit test which will be refactored later

* [moe refactor] fix unit import error

* [moe refactor] fix circular import issues

* [moe refactor] remove debug code

* [moe refactor] update github workflow

* [moe/zero] refactor low level optimizer ()

* [zero] refactor low level optimizer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] MoE refactor with newest version of ZeRO ()

* [zero] remove redundant members in BucketStore ()

* [zero] align api with previous version

* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug ()

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* [hotfix]Solve the compatibility issue of zero refactor ()

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* Modify function parameter names to resolve compatibility issues

* [zero] fix missing hook removal ()

* [MoE] Resolve .github conflict ()

* [Fix/Example] Fix Llama Inference Loading Data Type ()

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3

* [release] update version ()

* [release] update version

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [test] fix ddp plugin test

* [test] fix gptj and rpc test

* [devops] fix cuda ext compatibility

* [inference] fix flash decoding test

* [inference] fix flash decoding test

* fix ()

* [test] Fix/fix testcase ()

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [Hotfix] Add missing init file in inference.executor ()

* [CI/tests] simplify some test case to reduce testing time ()

* [ci/tests] simplify some test case to reduce testing time

* [ci/tests] continue to remove test case to reduce ci time cost

* restore some test config

* [ci/tests] continue to reduce ci time cost

* [misc] update dockerfile ()

* [misc] update dockerfile

* [misc] update dockerfile

* [devops] fix docker ci ()

* [Inference]Add Streaming LLM ()

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to the opinions of review

* add Citation

* change _block_tables tolist

* [hotfix] fix llama flash attention forward ()

* [misc] Accelerate CI for zero and dist optim ()

* remove fp16 from lamb

* remove d2h copy in checking states

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Test/CI] remove test cases to reduce CI duration ()

* [test] smaller gpt2 test case

* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py

* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py

* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py

* Revert "[test] smaller gpt2 test case"

Some tests might depend on the size of model (num of chunks)

This reverts commit df705a5210.

* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py

* [CI] smaller test model for two mwo the two modifid cases

* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there

* [hotfix] fix testcase in test_fx/test_tracer ()

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [fix] fix test_deepfm_model & test_dlrf_model;

* [fix] fix test_hf_albert & test_hf_gpt;

* [gemini] optimize reduce scatter d2h copy ()

* [gemini] optimize reduce scatter d2h copy

* [fix] fix missing reduce variable

* [refactor] remove legacy async reduce scatter code

* [gemini] missing sync

* Revert "[refactor] remove legacy async reduce scatter code"

This reverts commit 58ad76d466.

* [gemini] further optimize with async all reduce

* [fix] pass flag from manager to chunk

* Allow building cuda extension without a device. ()

Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are.

* [misc] fix dist logger ()

* [install]fix setup ()

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update requirements ()

* [shardformer] fix import ()

* upgrade colossal-chat support tp_group>1, add sp for sft

* upgrade ppo dpo rm script

* run pre-commit

* moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy

* fix training script

* fix ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix transformers version

* remove duplicated test

* fix datasets version

* remove models that require huggingface auth from ci

* remove local data path

* update ci

* remove baichuan from template test due to transformer version conflict

* merge

* Refactor modeling by adding attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix tests and naming

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Clean up

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* replace the customized dataloader setup with the build-in one

* replace the customized dataloader setup with the build-in one

* Remove flash attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* fix readme

* Fix test import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* update sft trainning script

* [Inference]refactor baichuan ()

* refactor baichuan

* remove unused code and add TODO for lazyinit

* [test] fix chatglm test kit ()

* [shardformer] fix modeling of bloom and falcon ()

* [test] fix qwen2 pytest distLarge ()

* [Inference] Fix flash-attn import and add model test ()

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* [Gemini] Use async stream to prefetch and h2d data moving ()

* use async stream to prefetch and h2d data moving

* Remove redundant code

* [gemini] quick fix on possible async operation ()

* [gemini] quick fix on possible async operation

* [gemini] quick fix on possible async operation

* [shardformer] upgrade transformers to 4.39.3 ()

* [shardformer]upgrade transformers for gpt2/gptj/whisper ()

* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* [shardformer]upgrade transformers for mistral ()

* upgrade transformers for mistral

* fix

* fix

* [shardformer]upgrade transformers for llama ()

* update transformers

fix

* fix

* fix

* [inference] upgrade transformers ()

* update transformers

fix

* fix

* fix

* fix

* fix

* [gemini] update transformers for gemini ()

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* Support 4d parallel + flash attention ()

* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>

* [zero] fix hook bug

* [zero] add low level optimizer back ()

* [zero] fix param & refactor

* [zero] add back original low level opt

* [zero] remove moe related

* [zero] pass zero tests

* [zero] refactor

* [chore] add del func back

* [zero] comments and naming ()

* [zero] modify api ()

* [zero] modify api

* [test] remove _grad_store access in tests

* [test] fix ()

* [CI] skip openmoe CI check

* [CI] fox pre-commit

* [zero] remove redundant memebr init ()

* [misc] remove useless code, modify the pg mesh implementation

* [misc] remove useless code, modify the pg mesh implementation

* [misc] use tempfile

* resolve conflict with main branch

* [misc] use tempfile in test_moe_checkpoint.py

* [misc] remove useless code, add assertion about sequence parallel, move logger into function

* [misc] remove useless code

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
2024-06-28 14:00:08 +08:00
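In the listing below, `_create_master_param_current_rank` shards each parameter. It zero-pads the flattened data to a multiple of the data-parallel world size and keeps one equal slice per rank. `_run_reduction` then either all-reduces a bucket and splits it per rank (ZeRO-1), or reduce-scatters it (ZeRO-2).

The following is a minimal sketch of that arithmetic under a few assumptions:

- Plain Python lists stand in for tensors.
- A local sum stands in for the `torch.distributed` collectives.
- The bucket is treated as one flat vector, whereas the real bucket layout is more involved.

`padding_size`, `shard_for_rank`, `simulate_all_reduce` and `simulate_reduce_scatter` are hypothetical helpers, not ColossalAI APIs.

```python
from typing import List


def padding_size(numel: int, world_size: int) -> int:
    # Same formula as _create_master_param_current_rank: pad up to the next
    # multiple of world_size (0 when numel is already divisible).
    return (world_size - numel % world_size) % world_size


def shard_for_rank(flat_param: List[float], world_size: int, rank: int) -> List[float]:
    # Zero-pad the flattened parameter, split it into world_size equal chunks
    # and keep the chunk owned by `rank`.
    padded = flat_param + [0.0] * padding_size(len(flat_param), world_size)
    chunk = len(padded) // world_size
    return padded[rank * chunk : (rank + 1) * chunk]


def simulate_all_reduce(per_rank_grads: List[List[float]]) -> List[float]:
    # ZeRO-1 path (hypothetical stand-in for dist.all_reduce): every rank ends up
    # with the element-wise sum of all flat gradients.
    return [sum(vals) for vals in zip(*per_rank_grads)]


def simulate_reduce_scatter(per_rank_grads: List[List[float]], rank: int) -> List[float]:
    # ZeRO-2 path (hypothetical stand-in for dist.reduce_scatter): rank r only
    # receives the summed chunk r.
    world_size = len(per_rank_grads)
    chunk = len(per_rank_grads[0]) // world_size
    summed = simulate_all_reduce(per_rank_grads)
    return summed[rank * chunk : (rank + 1) * chunk]


if __name__ == "__main__":
    world_size = 4
    param = [float(i) for i in range(10)]
    print(padding_size(len(param), world_size))  # 2
    print(shard_for_rank(param, world_size, 3))  # [9.0, 0.0, 0.0]

    # Each rank divides its local flat gradient by world_size before communicating,
    # so the sum over ranks is the data-parallel average.
    local = [[1.0] * 12, [2.0] * 12, [3.0] * 12, [6.0] * 12]
    scaled = [[g / world_size for g in grads] for grads in local]
    print(simulate_all_reduce(scaled)[:3])  # [3.0, 3.0, 3.0]
    print(simulate_reduce_scatter(scaled, rank=1))  # [3.0, 3.0, 3.0]
```

Both paths leave rank r with the same averaged chunk. Reduce-scatter simply never materializes the other ranks' chunks, which is consistent with the ZeRO-2 branch of the listing storing a single partition (`partition_num` of 1).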

888 lines
36 KiB
Python
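In the listing below, `_compute_grad_norm` obtains a p-norm over sharded gradients in three steps:

1. Sum |g|^p over the local shard.
2. All-reduce the partial sums over the data-parallel group.
3. Take the p-th root.

For the infinity norm, it uses a MAX all-reduce instead. `_unscale_and_clip_grads` then folds the loss-scale divisor and the clipping coefficient into a single `div_scale`.

The following is a minimal sketch of both steps with plain floats. It assumes the all-reduce can be modeled as a local sum or max, and that `loss_scale` stands in for whatever `get_grad_div_scale()` returns. `sharded_p_norm` and `unscale_and_clip` are hypothetical names, not ColossalAI APIs.

```python
import math
from typing import List


def sharded_p_norm(shards: List[List[float]], p: float = 2.0) -> float:
    # Each rank sums |g|^p over its own shard; the SUM all-reduce across the
    # data-parallel group is simulated by adding the per-rank partial sums.
    if math.isinf(p):
        # inf-norm: each rank takes its local max, then a MAX all-reduce.
        return max(max(abs(g) for g in shard) for shard in shards)
    partial_sums = [sum(abs(g) ** p for g in shard) for shard in shards]
    return sum(partial_sums) ** (1.0 / p)


def unscale_and_clip(grads: List[float], total_norm: float, loss_scale: float, max_norm: float) -> List[float]:
    # total_norm was measured on scaled gradients, so it is really norm * loss_scale.
    div_scale = loss_scale
    if max_norm > 0.0:
        clip = (total_norm / div_scale + 1e-6) / max_norm
        if clip > 1:
            # Fold the clipping coefficient into the same divisor.
            div_scale = clip * div_scale
    return [g / div_scale for g in grads]


if __name__ == "__main__":
    loss_scale = 1024.0
    # Unscaled gradient [3, 4] (norm 5) split across two ranks, then loss-scaled.
    shards = [[3.0 * loss_scale], [4.0 * loss_scale]]
    norm = sharded_p_norm(shards)  # about 5 * 1024
    flat = [g for shard in shards for g in shard]
    clipped = unscale_and_clip(flat, norm, loss_scale, max_norm=1.0)
    print(norm, clipped)  # clipped is approximately [0.6, 0.8]
```

The clipping coefficient multiplies the divisor instead of being applied in a second pass, so each gradient is rescaled exactly once.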

# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
from weakref import proxy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import (
    BF16MixedPrecisionMixin,
    FP16MixedPrecisionMixin,
    MixedPrecisionMixin,
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    def __init__(
        self,
        num_working_param_groups: int,
        pg_to_grad_store: Dict[ProcessGroup, GradientStore],
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
    ) -> None:
        super().__init__(
            initial_scale,
            min_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            hysteresis,
            max_scale,
        )
        self.num_working_param_groups = num_working_param_groups
        self.pg_to_grad_store = pg_to_grad_store

    def check_local_overflow(self) -> bool:
        for store in self.pg_to_grad_store.values():
            for group_id in range(self.num_working_param_groups):
                for avg_grad in store.get_working_grads_by_group_id(group_id):
                    if avg_grad is not None and has_inf_or_nan(avg_grad):
                        return True
        return False


class LowLevelZeroOptimizer(OptimizerWrapper):
    """Optimizer used for ZeRO-1 and ZeRO-2."""

    def __init__(
        self,
        optimizer: Optimizer,
        pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
        initial_scale: int = 2**16,  # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,  # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,  # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = False,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        dp_process_group: Optional[ProcessGroup] = None,
        forced_dtype: Optional[torch.dtype] = None,
        master_weights: bool = True,  # master weights
    ):
        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

        self._dtype = self.optim.param_groups[0]["params"][0].dtype
        self._logger = get_dist_logger()
        self._verbose = verbose

        if dp_process_group is not None and pg_to_param_list is not None:
            raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")

        if pg_to_param_list is None:
            unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
            pg_to_param_list = {unique_dp_group: []}
            for group in self.optim.param_groups:
                pg_to_param_list[unique_dp_group].extend(group["params"])

        self.pg_to_param_list = pg_to_param_list
        param_to_pg = {}
        for grp, param_list in pg_to_param_list.items():
            for p in param_list:
                assert isinstance(p, nn.Parameter), f"got {type(p)}"
                param_to_pg[p] = grp
        self.param_to_pg = param_to_pg

        # stage 2
        self._partition_grads = partition_grad

        self._cpu_offload = cpu_offload

        # grad accumulation
        self.require_grad_sync = True

        # working and master params for mixed precision training
        self._working_param_groups = dict()
        self._master_param_groups_of_current_rank = dict()

        # communication params
        self._overlap_communication = overlap_communication
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

        # master weights copy
        self._master_weights = master_weights

        if forced_dtype:
            for group in self.optim.param_groups:
                group_params = group["params"]
                for param in group_params:
                    param.data = param.data.to(forced_dtype)
            self._dtype = forced_dtype

        # check argument conflict
        self._sanity_checks()

        # ParameterStore will manage the tensor buffers used for zero
        # it will not manage the tensors used by mixed precision training

        # record the padding size of each param
        self._padding_map = dict()

        # mapping working param and master param
        self.master_to_working_param = dict()
        self.working_to_master_param = dict()

        # NOTE: the order of process groups must be the same across all ranks
        # process_group <---> xxx_store
        # process_group <---> [param1 param2 ...]
        # each process group has its own stores
        # params belonging to a process group use that group's stores
        self.pg_to_grad_store = {
            pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list
        }
        # map param id to its grad store; id(param) is the key because the stores are keyed by param id
        self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg}
        self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list}
        # map param id to its bucket store, keyed by id(param) for the same reason
        self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg}

        # iterate over the param groups in the optimizer
        # partition these param groups for data parallel training
        # and add buffers to parameter store for future access
        for group_id, param_group in enumerate(self.optim.param_groups):
            group_params = list()
            for param in param_group["params"]:
                if param.requires_grad:
                    group_params.append(param)

            # add the working params to working_param_groups for bookkeeping
            self._working_param_groups[group_id] = group_params

            master_param_current_rank = self._create_master_param_current_rank(group_params)
            self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

            # need to replace the params in the `params` field in the optimizer
            # so that when the optimizer calls step(), it only updates the tensors
            # managed by this data parallel rank
            param_group["params"] = master_param_current_rank

        # reduction hook is only used if overlapping communication
        # or stage 2 is used
        # if it is stage 1 without overlapping, no hook will be attached
        self.grad_handles = []
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        # initialize mixed precision mixin
        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
        if self._dtype is torch.float16:
            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
                self.num_param_groups,
                self.pg_to_grad_store,
                initial_scale=initial_scale,
                min_scale=min_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                max_scale=max_scale,
            )
        elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()

    def __del__(self):
        for hook in self.grad_handles:
            hook.remove()

    @property
    def dtype(self):
        return self._dtype

    @property
    def num_param_groups(self):
        return len(self._working_param_groups)

    def _sanity_checks(self):
        assert get_accelerator().name in ["cuda", "npu"], "ZeRO requires a CUDA or NPU accelerator"
        for param_group in self.optim.param_groups:
            group_params = param_group["params"]
            for param in group_params:
                if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
                    assert (
                        param.dtype == self._dtype
                    ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

    def _create_master_param_current_rank(self, param_list):
        # split each param evenly by world size
        params_current_rank = []
        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()

        for param in param_list:
            padding_size = (
                self.pid_to_bucket_store[id(param)].world_size
                - param.numel() % self.pid_to_bucket_store[id(param)].world_size
            ) % self.pid_to_bucket_store[id(param)].world_size
            self.record_param_padding_size(param, padding_size)

            with torch.no_grad():
                if padding_size > 0:
                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
                    # without master weights, point the working param at the padded buffer
                    # so that it shares storage with its shard
                    if self._master_weights == False:
                        param.data = padding_param[: param.numel()].view(param.shape)
                else:
                    padding_param = param.data.view(-1)

                splited_params = padding_param.split(
                    padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
                )
                splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank]

                # use fp32 when master_weights is True
                if self._master_weights is True:
                    splited_param_current_rank = splited_params.detach().float().to(device)
                else:
                    splited_param_current_rank = splited_params

                params_current_rank.append(splited_param_current_rank)
                self.link_master_and_working_param(splited_param_current_rank, param)

        return params_current_rank

    ###########################
    # Backward Reduction Hook #
    ###########################

    def _attach_reduction_hook(self):
        # we iterate over the working params
        # on each param, we register a post-accumulate-grad hook,
        # which fires once the gradient has been accumulated into param.grad
        self_weakref = proxy(self)

        def _grad_handler(param, group_id):
            # inside a no_sync context, gradients are not synchronized during backward
            if self_weakref.require_grad_sync:
                self_weakref._add_to_bucket(param, group_id)

        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
                    self.grad_handles.append(
                        param.register_post_accumulate_grad_hook(partial(_grad_handler, group_id=group_id))
                    )

    #######################
    # Reduction Functions #
    #######################

    def _run_reduction(self):
        for bucket_store in self.pg_to_bucket_store.values():
            if bucket_store.num_elements_in_bucket() <= 0:
                continue

            bucket_store.build_grad_in_bucket()

            flat_grads = bucket_store.get_flatten_grad()
            flat_grads /= bucket_store.world_size

            # ready to add other tensors to bucket
            bucket_store.reset_num_elements_in_bucket()

            if self._overlap_communication:
                stream = bucket_store.comm_stream
                # in case the memory is reused by the default stream
                flat_grads.record_stream(stream)
                # wait for the ops in the default stream to finish
                stream.wait_stream(get_accelerator().current_stream())
            else:
                stream = get_accelerator().current_stream()

            with get_accelerator().stream(stream):
                group_id = bucket_store.current_group_id

                grad_dtype = flat_grads.dtype
                if self._communication_dtype is not None:
                    flat_grads = flat_grads.to(self._communication_dtype)

                if not self._partition_grads:
                    dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
                    if flat_grads.dtype != grad_dtype:
                        flat_grads = flat_grads.to(grad_dtype)

                    flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.world_size)
                    grad_in_bucket = bucket_store.get_grad()
                    self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
                else:
                    flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
                    recieved_grad = torch.zeros_like(flat_grads_list[0])
                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)

                    if recieved_grad.dtype != grad_dtype:
                        recieved_grad = recieved_grad.to(grad_dtype)

                    grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
                    self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1)

                bucket_store.reset()

    def _update_unpartitoned_grad(
        self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
    ) -> None:
        for rank, grad_list in enumerate(origin_grad_list):
            sync_tensor(flat_grad_list[rank], grad_list)
            for grad in grad_list:
                param_id = bucket_store.get_param_id_of_grad(grad)
                self._add_grad(grad, bucket_store.world_size, group_id, param_id, rank)

    def _update_partitoned_grad(
        self,
        bucket_store: BucketStore,
        origin_grad_list: List,
        flat_grad: torch.Tensor,
        group_id: int,
        partition_num: int,
    ) -> None:
        sync_tensor(flat_grad, origin_grad_list)
        for grad in origin_grad_list:
            param_id = bucket_store.get_param_id_of_grad(grad)
            self._add_grad(grad, partition_num, group_id, param_id)

    def _add_grad(
        self,
        grad: torch.Tensor,
        partition_num: int,
        group_id: int,
        param_id: int,
        rank: int = 0,
    ) -> None:
        if (
            len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id))
            < partition_num
        ):
            self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id)
        else:
            self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id)

    def _add_to_bucket(self, param, group_id):
        param_size = param.numel()

        # if adding this grad would overflow the bucket, or the grad belongs
        # to a different param group than the grads already in the bucket,
        # reduce the bucket first; after reduction, the bucket is empty
        if (
            self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size
            or group_id != self.pid_to_bucket_store[id(param)].current_group_id
        ):
            self._run_reduction()

        padding_size = self.get_param_padding_size(param)
        self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size)

    ################################
    # torch.optim.Optimizer methods
    ################################

    def backward(self, loss, retain_graph=False):
        assert not (
            self._partition_grads and not self.require_grad_sync
        ), "ZeRO2(partition_grads) and no_sync are not compatible"

        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

        loss.backward(retain_graph=retain_graph)

        if not self.require_grad_sync:
            return

        self._reduce_grad(self._partition_grads)

        # wait for any reduction still running on the communication stream
        if self._overlap_communication:
            get_accelerator().synchronize()

    def backward_by_grad(self, tensor, grad):
        assert not (
            self._partition_grads and not self.require_grad_sync
        ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

        if not self.require_grad_sync:
            return
        self._reduce_grad(self._partition_grads)

        # wait for any reduction still running on the communication stream
        if self._overlap_communication:
            get_accelerator().synchronize()

    def zero_bucket_stores(self):
        for bucket_store in self.pg_to_bucket_store.values():
            bucket_store.reset_all()

    def zero_grad_stores(self):
        for grad_store in self.pg_to_grad_store.values():
            grad_store.reset_all_gradients()

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none is True, gradients
        are set to None instead to save memory.

        :param set_to_none: Whether to set the gradients to None. Default value is True.
        :type set_to_none: bool
        """
        if self.mixed_precision_mixin is not None:
            self.mixed_precision_mixin.pre_zero_grad()
        for _, param_group in self._working_param_groups.items():
            for param in param_group:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach()
                        param.grad.zero_()
        self.zero_grad_stores()
        self.zero_bucket_stores()

    ####################
    # Update Parameter #
    ####################

    def step(self, closure=None):
        assert closure is None, "closure is not supported by step()"
        if not self.require_grad_sync:
            return

        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            if self._verbose:
                self._logger.info("Found overflow. Skip step")
            self.zero_grad()
            return

        # record all grads for unscale and clip
        grad_partition_groups = []
        norm_groups = []

        # sometimes not all params are 'really' working
        # for instance, with layer drop, the dropped layer has no grad
        # and should not be updated
        real_working_params = dict()
        real_master_params = dict()

        for group_id in range(self.num_param_groups):
            master_params = self._master_param_groups_of_current_rank[group_id]
            working_params = self._working_param_groups[group_id]
            real_working_params[group_id] = []
            real_master_params[group_id] = []
            working_grads = []
            for working_param, master_param in zip(working_params, master_params):
                # if a working param requires grad but has no grad,
                # it is not 'really' working, e.g. a dropped layer;
                # otherwise the split grad should be attached to the split param
                grad_store = self.pid_to_grad_store[id(working_param)]
                grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                grad_index = 0 if self._partition_grads else grad_store.local_rank
                if len(grads) > 0:
                    real_working_params[group_id].append(working_param)
                    grad = grads[grad_index]
                    # no need to copy fp32 grad if master_weights is False
                    if self._master_weights:
                        grad = grad.to(master_param.dtype).to(master_param.device)
                    master_param.grad = grad
                    grad_partition_groups.append(grad)
                    real_master_params[group_id].append(master_param)

            # compute norm
            norm_group = 0
            for grad_store in self.pg_to_grad_store.values():
                working_grads = grad_store.get_working_grads_by_group_id(group_id)
                norm_group += self._compute_grad_norm(dp_pg=grad_store.torch_pg, gradients=working_grads)

            norm_groups.append(norm_group)

            # update the params in the optimizer
            self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(grad_partition_groups, global_norm)

        # update the parameters
        self.optim.step()

        # release the grad
        grad_partition_groups = []
        for group_id in range(self.num_param_groups):
            release_param_grad(self._master_param_groups_of_current_rank[group_id])

        self.pg_to_tensor_bucket = {
            pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
        }

        # all-gather the partitions updated by each rank back into the working params
        device = get_accelerator().get_current_device()
        for group_id in range(self.num_param_groups):
            master_working_param = self.optim.param_groups[group_id]["params"]
            for idx, master_param in enumerate(master_working_param):
                working_param = real_working_params[group_id][idx]
                param_to_gather = master_param.to(device).to(self._dtype)
                pg = self.param_to_pg[working_param]
                try:
                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                except RuntimeError:
                    self.pg_to_tensor_bucket[pg].all_gather(pg)
                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
            if not tensor_bucket.is_empty():
                tensor_bucket.all_gather(pg)


    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

        Args:
            dp_pg (ProcessGroup): The process group across which the local norms are reduced
            gradients (List[Tensor]): The gradients to compute the norm of
            norm_type (int, optional): Type of the p-norm to use. Can be ``'inf'`` for infinity norm. Defaults to 2.

        Returns:
            float: The total norm of the given gradients
        """
        if len(gradients) == 0:
            return 0.0

        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)
            total_norm_cuda = torch.tensor(
                [float(total_norm)],
                device=get_accelerator().get_current_device(),
                dtype=torch.float,
            )
            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
            total_norm = total_norm_cuda.item()
        else:
            total_norm_exponentiated = 0.0
            for grad in gradients:
                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
                total_norm_exponentiated += grad_norm_exponentiated

            # Sum across all ranks of the data-parallel process group.
            total_norm_exponentiated_cuda = torch.tensor(
                [float(total_norm_exponentiated)],
                device=get_accelerator().get_current_device(),
                dtype=torch.float,
            )
            torch.distributed.all_reduce(
                total_norm_exponentiated_cuda,
                op=torch.distributed.ReduceOp.SUM,
                group=dp_pg,
            )
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

        return total_norm
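
    # Worked example for ``_compute_grad_norm`` above (illustrative numbers only,
    # assuming the shards held by the ranks together form the full gradient, as
    # under ZeRO partitioning): with norm_type = 2 and two ranks whose local shards
    # hold [3.0] and [4.0], each rank computes its local sum of squares (9.0 and
    # 16.0), all_reduce(SUM) yields 25.0 on every rank, and the returned norm is
    # 25.0 ** (1 / 2) = 5.0, the 2-norm of the unsharded gradient [3.0, 4.0].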

    #############################
    # Mixed Precision Utilities #
    #############################

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute the combined unscale-and-clip divisor
        div_scale = 1.0
        if self.mixed_precision_mixin is not None:
            div_scale = self.mixed_precision_mixin.get_grad_div_scale()

        if self._clip_grad_norm > 0.0:
            # total_norm was computed on loss-scaled gradients, so it equals the true norm * div_scale
            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                div_scale = clip * div_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1.0 / div_scale)
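
    # Worked example for ``_unscale_and_clip_grads`` above (illustrative numbers only):
    # with a loss scale div_scale = 1024, a measured total_norm = 4096 (a true norm
    # of 4.0) and clip_grad_norm = 1.0, clip = (4.0 + 1e-6) / 1.0 > 1, so div_scale
    # becomes about 4.0 * 1024 = 4096 and every gradient is multiplied by ~1/4096.
    # A single multiplication therefore removes the loss scale and brings the true
    # norm down to ~1.0.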

    ############################
    # Gradient Synchronization #
    ############################

    # manually add every available gradient to the buckets and reduce them
    def _sync_grad(self):
        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad and param.grad is not None:
                    self._add_to_bucket(param, group_id)

        self._run_reduction()

    def _reduce_grad(self, partition_grad):
        # with ZeRO-1 and no communication overlap, no reduction hook is attached,
        # so the gradients have to be reduced manually here
        if not partition_grad and not self._overlap_communication:
            self._sync_grad()
        else:
            self._run_reduction()
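
    # adapted from the ``no_sync`` context manager of PyTorch DDP: disables gradient
    # synchronization inside the context and restores the previous setting on exit
    @contextmanager
    def no_sync(self):
        old_require_grad_sync = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old_require_grad_sync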
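
    # Usage sketch for ``no_sync`` above (illustrative only: ``optimizer``, ``model``,
    # ``loss_fn`` and ``micro_batches`` are hypothetical names, and the optimizer is
    # assumed to expose ``backward(loss)``). Gradients of all but the last micro-batch
    # are accumulated without synchronization. This assumes ZeRO-1; gradient
    # partitioning (ZeRO-2) may not support skipping synchronization.
    #
    #   with optimizer.no_sync():
    #       for batch in micro_batches[:-1]:
    #           optimizer.backward(loss_fn(model(batch)))
    #   # outside the context, require_grad_sync is restored
    #   optimizer.backward(loss_fn(model(micro_batches[-1])))
    #   optimizer.step()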

    ##############
    # State Dict #
    ##############

    def _pack_state(self, state: Dict) -> Dict:
        # comes from pytorch optimizer.state_dict()
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.optim.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

        return {"state": packed_state, "param_groups": param_groups}
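
    # Example of the packed format produced by ``_pack_state`` above (illustrative):
    # with two param groups holding params [p0, p1] and [p2], every param that has
    # optimizer state is re-keyed by its running index across the groups:
    #   {"state": {0: state[p0], 1: state[p1], 2: state[p2]},
    #    "param_groups": [{..., "params": [0, 1]}, {..., "params": [2]}]}
    # where "..." stands for the remaining hyperparameters of each group (lr, etc.).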

    def state_dict(self) -> Dict:
        """Return a state_dict in the same (unsharded) format as a PyTorch optimizer used with DDP.

        Each sharded state tensor (except "step") is all-gathered within its process group,
        stripped of padding, reshaped to the shape of the working param and moved to CPU.

        Returns:
            Dict: the state_dict in PyTorch format
        """
        zero_state = dict()
        device = get_accelerator().get_current_device()
        for param, state in self.optim.state.items():
            zero_state[param] = copy.deepcopy(state)
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    working_param = self.master_to_working_param[id(param)]
                    pg = self.param_to_pg[working_param]
                    gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
                    dist.all_gather(gather_tensor, v.to(device), group=pg)
                    param_state = (
                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                    )
                    zero_state[param][k] = param_state

        states_dict = self._pack_state(zero_state)

        return states_dict
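
    # Worked example of the gather-and-unpad step in ``state_dict`` above
    # (illustrative numbers only): a working param of shape (2, 5) has numel 10;
    # with pg.size() == 4 it was padded to 12 elements, so each rank holds a shard
    # of 3. all_gather returns 4 shards, torch.stack(...).view(-1) yields 12
    # elements, [:10] drops the 2 padding elements, and reshape_as restores (2, 5).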

    def load_state_dict(self, state_dict: Dict):
        """Load a state dict. The state_dict must be in (unsharded) PyTorch format.

        Args:
            state_dict (dict): A PyTorch-format state_dict
        """
        zero_state_dict = copy.deepcopy(state_dict)
        idx2master = {}
        cnt = 0
        for param_group in self.optim.param_groups:
            for param in param_group["params"]:
                idx2master[cnt] = param
                cnt += 1
        for param_idx, state in zero_state_dict["state"].items():
            pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
                    with torch.no_grad():
                        v = v.flatten()
                        if padding_size > 0:
                            v = torch.nn.functional.pad(v, [0, padding_size])
                        v_list = v.split(v.numel() // pg.size())
                        zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()

        self.optim.load_state_dict(zero_state_dict)
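
    # Worked example of the padding arithmetic in ``load_state_dict`` above
    # (illustrative numbers only): for a state tensor with numel 10 and
    # pg.size() == 4, padding_size = (4 - 10 % 4) % 4 = 2, so the flattened tensor
    # is padded to 12 elements and split into 4 shards of 3; rank r keeps elements
    # [3r, 3r + 3). When numel is already a multiple of pg.size() (e.g. 12), the
    # outer "% pg.size()" makes padding_size 0 rather than pg.size().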

    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
        """Return dictionaries containing the whole optimizer state, one shard at a time.
        The maximum size of each shard is specified by ``max_shard_size``.
        Only the 'state' entry of the state_dict is included.

        Args:
            max_shard_size (int, optional): max size of a state shard, compared against the total
                number of elements (``numel``) of the gathered state tensors. Defaults to 1024.

        Yields:
            Iterator[Tuple[Dict, int]]: A generator of (state dict shard, shard size) pairs
        """
        ret_block = dict()
        ret_block_size = 0

        device = get_accelerator().get_current_device()
        local_states = self.optim.state_dict()["state"]

        idx2master = {}
        cnt = 0
        for param_group in self.optim.param_groups:
            for param in param_group["params"]:
                idx2master[cnt] = param
                cnt += 1
        for param_idx, states in local_states.items():
            current_block_size = 0
            current_block = copy.deepcopy(states)

            master_param = idx2master[param_idx]
            working_param = self.master_to_working_param[id(master_param)]
            pg = self.param_to_pg[working_param]

            for k, v in states.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
                    dist.all_gather(state_tensor, v.to(device), group=pg)
                    state_tensor = (
                        torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                    )
                    current_block_size += state_tensor.numel()
                    current_block[k] = state_tensor

            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                yield ret_block, ret_block_size
                ret_block = dict()
                ret_block_size = 0

            ret_block[param_idx] = current_block
            ret_block_size += current_block_size

        yield ret_block, ret_block_size
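
    # Worked example of the packing rule in ``state_dict_shard`` above (illustrative
    # numbers only): with max_shard_size = 1024 and per-param state sizes of 600, 500
    # and 300 elements, param 0 starts the first shard (600). Adding param 1 would
    # give 1100 > 1024, so the first shard (600) is yielded and param 1 starts a new
    # one. Param 2 fits (500 + 300 = 800), and the final yield emits a shard of 800.
    # A single param larger than max_shard_size still ends up alone in its own shard.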

    def update_master_params(self, model: nn.Module) -> None:
        """Update the master params from the working params

        Args:
            model (nn.Module): The model whose (working) params are copied into the master params
        """
        for p in model.parameters():
            p_id = id(p)
            # only params managed by this optimizer have a master copy and a process group
            if p_id in self.working_to_master_param:
                pg = self.param_to_pg[p]
                master_param = self.working_to_master_param[p_id]
                padding_size = self.get_param_padding_size(p)
                working_param = p.data.view(-1)
                if padding_size > 0:
                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                master_param.copy_(working_param.chunk(pg.size())[pg.rank()])

    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
        """Return the mapping from id(working param) to master param"""
        return self.working_to_master_param

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        """Return the mapping from id(master param) to working param"""
        return self.master_to_working_param

    def get_param_padding_map(self) -> Dict[int, int]:
        """Return the mapping from param id to padding size"""
        return self._padding_map

    def record_param_padding_size(self, param: Tensor, padding_size: int):
        """Record the padding size of a param

        Args:
            param (Tensor): The parameter
            padding_size (int): The padding size of the parameter
        """

        self._padding_map[id(param)] = padding_size

    def get_param_padding_size(self, param: Tensor) -> int:
        """Return the padding size of the parameter

        Args:
            param (Tensor): The parameter

        Returns:
            int: the padding size of the parameter
        """

        return self._padding_map[id(param)]

    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
        """Map the master parameter and the working parameter to each other

        Args:
            master_param (Tensor): The parameter copy in the optimizer
            working_param (Tensor): The parameter of the model
        """

        self.master_to_working_param[id(master_param)] = working_param
        self.working_to_master_param[id(working_param)] = master_param

    def get_padding_map(self) -> Dict[int, int]:
        """Return the padding map

        Returns:
            Dict[int, int]: The padding map, from param id to padding size
        """

        return self._padding_map

    def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
        """Return the full gradient of ``working_param``, or None if it has no gradient.

        The local gradient partitions are all-gathered, stripped of padding and reshaped
        to the shape of ``working_param``.
        """
        grad_store = self.pid_to_grad_store[id(working_param)]
        partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
        if partial_grad is None:
            return None
        tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
        dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
        grad_flat = torch.cat(tensor_list, dim=0)
        return grad_flat[: working_param.numel()].reshape_as(working_param)

    def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
        """Collect the working gradients of a param group from every gradient store"""
        working_grads = []
        for grad_store in self.pg_to_grad_store.values():
            working_grads.extend(grad_store.get_working_grads_by_group_id(group_id))
        return working_grads

    def get_param_id_for_grad(self, grad: Tensor) -> int:
        """Return the id of the param that owns ``grad``, or None if no gradient store knows it

        Raises:
            ValueError: if more than one gradient store maps ``grad`` to a param
        """
        param_id = None
        for grad_store in self.pg_to_grad_store.values():
            id_maybe_none = grad_store.get_param_id_for_grad(grad)
            if id_maybe_none is not None:
                if param_id is not None:
                    raise ValueError("The grad mapping is not unique")
                param_id = id_maybe_none
        return param_id

    def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
        grad_store = self.pid_to_grad_store[param_id]
        return grad_store.get_working_grad_by_param_id(param_id)

    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
        grad_store = self.pid_to_grad_store[param_id]
        return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)