Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-30 20:55:17 +00:00
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction of DimSpec objects: the difference_dict is identical for every DimSpec instance, so a single shared copy is enough (a sketch of the pattern follows this entry).
* Fix documentation of DimSpec's difference method
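A minimal sketch of the pattern behind this speedup, assuming the table really is instance-independent (the class and method names here are hypothetical, not ColossalAI's actual DimSpec implementation):

    class DimSpecSketch:
        # shared lookup table, built once for the whole class instead of
        # being deep-copied into every instance
        _difference_dict = None

        def __init__(self, shard_list):
            self.shard_list = shard_list
            if DimSpecSketch._difference_dict is None:
                DimSpecSketch._difference_dict = self._build_difference_dict()

        @staticmethod
        def _build_difference_dict():
            # stand-in for the expensive one-time construction
            return {("S0", "R"): "all-gather", ("R", "S0"): "shard"}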
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
458 lines | 16 KiB | Python
import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Type

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor


def build_model(
    model_fn,
    enable_fused_normalization=True,
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    use_lazy_init: bool = False,
    dtype=torch.float32,
):
    # create the original model (optionally under lazy init)
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard a deep copy of the original model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        enable_flash_attention=enable_flash_attention,
        enable_jit_fused=enable_jit_fused,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)


def build_pipeline_model(
    model_fn,
    stage_manager=None,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    use_lazy_init: bool = False,
    policy: Optional[Policy] = None,
):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        pipeline_stage_manager=stage_manager,
    )

    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
    return org_model.cuda(), sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # switch to train mode
    original_model.train()
    sharded_model.train()

    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f"{name} {k} not in sharded model"
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}"
        assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}"
        assert torch.equal(v, shard_v), f"{name} {k} value mismatch"


def build_model_from_hybrid_plugin(
    model_fn: Callable,
    loss_fn: Callable,
    test_config: Dict[str, Any],
    optim_class=Adam,
    sharded_optim_class=Adam,
    pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
):
    use_lazy_init = test_config.pop("use_lazy_init", False)

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    if optim_class == GaLoreAdamW8bit:
        # Disable clipping and block-wise quantization
        org_optimizer = optim_class(
            get_galore_param_groups(org_model, weight_decay=0, rank=4),
            lr=1e-3,
            percentile_clipping=101,
            block_wise=False,
            min_8bit_size=1e10,
        )
        sharded_optimizer = sharded_optim_class(
            get_galore_param_groups(sharded_model, weight_decay=0, rank=4),
            lr=1e-3,
            percentile_clipping=101,
            block_wise=False,
            min_8bit_size=1e10,
        )
    else:
        org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
        sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)

    criterion = loss_fn

    plugin = pluggin_cls(**test_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return (
        org_model,
        org_optimizer,
        sharded_model,
        sharded_optimizer,
        criterion,
        booster,
    )


def build_model_from_low_level_zero_plugin(
    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
    use_lazy_init = test_config.pop("use_lazy_init", False)

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
    sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    plugin = LowLevelZeroPlugin(**test_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster


def run_forward_backward_with_hybrid_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()

    shard_test_data = {k: v.clone() for k, v in data.items()}
    unshard_test_data = {k: v.clone() for k, v in data.items()}

    sharded_model.train()
    if booster.plugin.stage_manager is not None:
        # pipeline schedules need a batch dim large enough to split into
        # microbatches, so tile every tensor input to batch size 4
        for k, v in shard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                shard_test_data[k] = v.to("cuda").repeat(*new_shape)

        data_iter = iter([shard_test_data])
        sharded_output = booster.execute_pipeline(
            data_iter,
            sharded_model,
            _criterion,
            sharded_optimizer,
            return_loss=True,
            return_outputs=True,
        )
        sharded_loss = sharded_output["loss"]
    else:
        shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
        sharded_output = sharded_model(**shard_test_data)
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

    org_model.train()
    if booster.plugin.stage_manager is not None:
        # apply the same tiling to the unsharded inputs so the losses are comparable
        for k, v in unshard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
    unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()
    return org_loss, org_output, sharded_loss, sharded_output


def run_forward_backward_with_low_level_zero_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    sharded_model.train()
    sharded_output = sharded_model(**data)
    sharded_loss = criterion(sharded_output)
    sharded_optimizer.backward(sharded_loss)

    org_model.train()
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output


def check_output_hidden_state(
    org_output: BaseModelOutputWithPast,
    sharded_output: BaseModelOutputWithPast,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    shard_config: Optional[ShardConfig] = None,
):
    org_hidden_state = org_output.last_hidden_state

    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

    # Check if the output sequence is gathered before cross entropy
    if shard_config is not None:
        seq_dim = 1
        sp_group = shard_config.sequence_parallel_process_group
        sp_size = shard_config.sequence_parallel_size
        if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
            # the sequence is still split across the SP group, so compare this rank's chunk only
            org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]

    assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
    assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)


def check_weight(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_weight = getattr_(org_model, suffix).weight
        sharded_weight = getattr_(sharded_model, suffix).weight

        # skip if layer is not held by this process
        if sharded_weight is None:
            continue

        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
            sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)

        if is_padded_tensor(sharded_weight):
            sharded_weight = to_unpadded_tensor(sharded_weight)

        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

        assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol)


def get_grad_tensors_for_check(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: ProcessGroup = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
    name: str = None,
):
    grad_to_check = {}
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        try:
            if shard_grad.shape[0] > org_grad.shape[0]:
                shard_grad = shard_grad[: org_grad.shape[0], :]
        except Exception:
            # shapes may be unavailable for some layers; keep the gradients as-is
            pass
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        grad_to_check[suffix] = {
            "org_grad": org_grad.float(),
            "shard_grad": shard_grad.float(),
            "rtol": rtol,
            "atol": atol,
        }

    return grad_to_check


# used by sam/blip2
def check_grad(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: ProcessGroup = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[: org_grad.shape[0], :]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol)


def unwrap_model(
    module: Module,
    base_model_class_name: Optional[str] = None,
    base_model_attribute_name: Optional[str] = None,
):
    if isinstance(module, HybridParallelModule):
        module = module.unwrap()
    if base_model_class_name is None:
        return module
    if module.__class__.__name__ == base_model_class_name:
        return module
    return getattr(module, base_model_attribute_name, None)


def check_all_grad_tensors(check_tensors):
    """Compare every gradient collected by `get_grad_tensors_for_check`.

    Each entry maps a layer suffix to a dict with:
        "org_grad": gradient tensor from the original model
        "shard_grad": gradient tensor from the sharded model
    plus the "atol"/"rtol" tolerances to use for the comparison.
    """
    for suffix, check_info in check_tensors.items():
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]
        atol = check_info["atol"]
        assert_close(org_grad, shard_grad, atol=atol, rtol=rtol)
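

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): one way the helpers
# above are typically composed in a shardformer test. The fixture callables
# (model_fn, data_gen_fn, output_transform_fn, loss_fn) and the plugin config
# keys are assumptions for illustration, and a distributed environment must
# already be initialized before calling this.
# ---------------------------------------------------------------------------
def _example_hybrid_parallel_check(model_fn, data_gen_fn, output_transform_fn, loss_fn):
    test_config = {"tp_size": 2, "pp_size": 1, "precision": "fp32"}  # hypothetical config
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )
    # the sharded run should reproduce the unsharded loss within tolerance
    check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)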