mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-04 21:29:41 +00:00
* [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. 
(#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove 
`return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer 
distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. 
Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name 
* support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when 
enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. 
(#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell 
<binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
354 lines
12 KiB
Python
354 lines
12 KiB
Python
import copy
|
|
from contextlib import nullcontext
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch import Tensor
|
|
from torch import distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
from torch.nn import Module
|
|
from torch.optim import Adam, Optimizer
|
|
from torch.testing import assert_close
|
|
|
|
from colossalai.booster import Booster
|
|
from colossalai.booster.plugin import HybridParallelPlugin
|
|
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
|
from colossalai.checkpoint_io.utils import gather_distributed_param
|
|
from colossalai.lazy import LazyInitContext
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
from colossalai.shardformer._utils import getattr_
|
|
from colossalai.shardformer.policies.auto_policy import Policy
|
|
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
|
from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
|
|
|
|
|
|
def build_model(
    model_fn,
    enable_fused_normalization=True,
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    use_lazy_init: bool = False,
    dtype=torch.float32,
):
    """Build a reference model and a ShardFormer-sharded copy of it.

    Args:
        model_fn: zero-argument callable that constructs a fresh model.
        enable_fused_normalization: forwarded to ``ShardConfig``.
        enable_tensor_parallelism: forwarded to ``ShardConfig``.
        enable_flash_attention: forwarded to ``ShardConfig``.
        enable_jit_fused: forwarded to ``ShardConfig``.
        enable_sequence_parallelism: forwarded to ``ShardConfig``.
        use_lazy_init: construct the model under ``LazyInitContext`` and
            materialize it before copying.
        dtype: dtype both returned models are cast to.

    Returns:
        ``(org_model, sharded_model)``, both moved to CUDA and cast to ``dtype``.
    """
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
    if use_lazy_init:
        ctx.materialize(org_model)

    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        enable_flash_attention=enable_flash_attention,
        enable_jit_fused=enable_jit_fused,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    # Copy the already-materialized original so both models start from the same
    # weights. (A copy taken inside the init context used to be made here too, but
    # it was always overwritten by this one, so it only wasted time and memory.)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, _ = shard_former.optimize(model_copy)
    return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)
|
|
|
|
|
|
def build_pipeline_model(
    model_fn,
    stage_manager=None,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    use_lazy_init: bool = False,
    policy: Optional[Policy] = None,
):
    """Build a reference model and a pipeline-sharded copy of it.

    The copy handed to ShardFormer is taken inside the (optionally lazy) init
    context. Only the reference model is materialized in this function; the copy
    is passed to ``ShardFormer.optimize`` as-is, together with ``policy``.

    Returns:
        ``(org_model, sharded_model)``, both moved to CUDA.
    """
    if use_lazy_init:
        init_ctx = LazyInitContext()
    else:
        init_ctx = nullcontext()

    with init_ctx:
        reference = model_fn()
        to_shard = copy.deepcopy(reference)

    if use_lazy_init:
        init_ctx.materialize(reference)

    sharder = ShardFormer(
        shard_config=ShardConfig(
            enable_fused_normalization=enable_fused_normalization,
            enable_tensor_parallelism=enable_tensor_parallelism,
            pipeline_stage_manager=stage_manager,
        )
    )
    sharded, _shared_params = sharder.optimize(to_shard, policy=policy)
    return reference.cuda(), sharded.cuda()
|
|
|
|
|
|
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Run one training-mode forward pass of both models on the same batch.

    ``data_gen_fn()`` must return a mapping of keyword arguments. Every value is
    moved with ``.cuda()`` and the same batch is fed to both models.

    Returns:
        ``(org_output, org_loss, shard_output, shard_loss)``. Each output has been
        passed through ``output_transform_fn``, and each loss is ``loss_fn``
        applied to that transformed output.
    """
    batch = {name: tensor.cuda() for name, tensor in data_gen_fn().items()}

    for model in (original_model, sharded_model):
        model.train()

    def _forward(model):
        transformed = output_transform_fn(model(**batch))
        return transformed, loss_fn(transformed)

    org_output, org_loss = _forward(original_model)
    shard_output, shard_loss = _forward(sharded_model)
    return org_output, org_loss, shard_output, shard_loss
|
|
|
|
|
|
def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
    """Assert that every entry of ``org_model``'s state dict appears unchanged in ``sharded_model``'s.

    For each key of the original state dict, the sharded state dict must contain
    the same key with equal shape, dtype and values. Keys that exist only in the
    sharded model are not checked. ``name`` prefixes every failure message.
    """
    expected = org_model.state_dict()
    actual = sharded_model.state_dict()
    for key, ref in expected.items():
        assert key in actual, f"{name} {key} not in sharded model"
        got = actual[key]
        assert ref.shape == got.shape, f"{name} {key} shape mismatch, {ref.shape} vs {got.shape}"
        assert ref.dtype == got.dtype, f"{name} {key} dtype mismatch, {ref.dtype} vs {got.dtype}"
        assert torch.equal(ref, got), f"{name} {key} value mismatch"
|
|
|
|
|
|
def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
    """Build a reference model/optimizer and a HybridParallelPlugin-boosted copy.

    Args:
        model_fn: zero-argument callable that constructs a fresh model.
        loss_fn: criterion, passed through ``booster.boost``.
        test_config: keyword arguments for ``HybridParallelPlugin``. The optional
            ``"use_lazy_init"`` key is consumed here and not forwarded to the
            plugin. The caller's dict is left unmodified.

    Returns:
        ``(org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster)``.
        ``org_model`` is on CUDA. The sharded model, optimizer and criterion are
        the objects returned by ``booster.boost``.
    """
    # Work on a shallow copy. Popping from the caller's dict would silently drop
    # "use_lazy_init" for every later call that reuses the same config object
    # (e.g. when one config is used for every model in a model-zoo loop).
    plugin_config = dict(test_config)
    use_lazy_init = plugin_config.pop("use_lazy_init", False)

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    plugin = HybridParallelPlugin(**plugin_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return (
        org_model,
        org_optimizer,
        sharded_model,
        sharded_optimizer,
        criterion,
        booster,
    )
|
|
|
|
|
|
def _repeat_batch_for_pipeline(batch: Dict[str, Any], repeats: int = 4) -> Dict[str, Any]:
    """Return a copy of ``batch`` with tensor-like values moved to CUDA and tiled ``repeats`` times along dim 0.

    Values that are neither ``torch`` tensors nor objects whose class name contains
    "Tensor" are passed through unchanged.
    """
    tiled = {}
    for key, value in batch.items():
        if torch.is_tensor(value) or "Tensor" in value.__class__.__name__:
            new_shape = [1] * value.dim()
            new_shape[0] = repeats
            tiled[key] = value.to("cuda").repeat(*new_shape)
        else:
            tiled[key] = value
    return tiled


def run_forward_backward_with_hybrid_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    """Run one forward/backward step on the original and the boosted model.

    With a pipeline stage manager, the batch is tiled 4x along dim 0 for both
    models. The sharded model is then run via ``booster.execute_pipeline``
    (loss = ``criterion(output_transform_fn(outputs))``). Without one, the
    sharded model runs a plain forward pass followed by
    ``sharded_optimizer.backward``. The original model always runs a plain
    forward pass and ``loss.backward()``.

    Returns:
        ``(org_loss, org_output, sharded_loss, sharded_output)``. In the pipeline
        case ``sharded_output`` is the object returned by ``execute_pipeline``.
    """
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        # ``inputs`` is unused; the argument is required by execute_pipeline's criterion signature.
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()
    # Separate clones so the two runs never share input tensors.
    shard_test_data = {k: v.clone() for k, v in data.items()}
    unshard_test_data = {k: v.clone() for k, v in data.items()}
    use_pipeline = booster.plugin.stage_manager is not None

    sharded_model.train()
    if use_pipeline:
        # NOTE(review): the 4x tiling presumably gives the pipeline several
        # micro-batches to schedule; confirm against the plugin's microbatch settings.
        shard_test_data = _repeat_batch_for_pipeline(shard_test_data)
        sharded_output = booster.execute_pipeline(
            iter([shard_test_data]),
            sharded_model,
            _criterion,
            sharded_optimizer,
            return_loss=True,
            return_outputs=True,
        )
        sharded_loss = sharded_output["loss"]
    else:
        shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
        sharded_output = sharded_model(**shard_test_data)
        # NOTE(review): unlike the pipeline branch, output_transform_fn is not
        # applied before the criterion here (nor for org_model below).
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

    org_model.train()
    if use_pipeline:
        unshard_test_data = _repeat_batch_for_pipeline(unshard_test_data)
    unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output
|
|
|
|
|
|
def check_output_hidden_state(
    org_output: Tensor,
    sharded_output: Tensor,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
):
    """Compare ``last_hidden_state`` of the original and sharded outputs.

    When a ``stage_manager`` is given and this is the last pipeline stage, the
    sharded hidden state is read from ``sharded_output["outputs"]["last_hidden_state"]``.
    Presumably this is the dict produced by ``booster.execute_pipeline``. Otherwise
    it is read as the ``last_hidden_state`` attribute. Both sides are upcast to
    float32 before ``assert_close``.

    NOTE(review): despite the ``Tensor`` annotations, both outputs are used as
    objects exposing ``last_hidden_state``, not as tensors.
    """
    reference = org_output.last_hidden_state

    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
        candidate = sharded_output["outputs"]["last_hidden_state"]
    else:
        candidate = sharded_output.last_hidden_state

    assert_close(reference.float(), candidate.float(), atol=atol, rtol=rtol)
|
|
|
|
|
|
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
    """Assert that the original and sharded losses agree within ``atol``/``rtol``.

    Both losses are upcast to float32 before being compared with ``torch.allclose``.

    Raises:
        AssertionError: if the values differ beyond tolerance. The message reports
            both values and the tolerances used.
    """
    org_value = org_loss.float()
    sharded_value = sharded_loss.float()
    assert torch.allclose(org_value, sharded_value, atol=atol, rtol=rtol), (
        f"loss mismatch: original {org_value} vs sharded {sharded_value} (atol={atol}, rtol={rtol})"
    )
|
|
|
|
|
|
def check_weight(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    """Compare the ``.weight`` of each named sub-module between the two models.

    For every dotted path in ``layer_suffix``:
    - if the sharded weight is a (customized) distributed tensor, it is gathered
      back with ``gather_distributed_param``;
    - if it is a padded tensor, the padding is stripped;
    - it is then compared to the original weight with ``assert_close`` in float32.

    Layers whose sharded weight is ``None`` (not held by this process) are skipped.
    ``tp_group`` and ``dim`` are accepted but not used by this function.
    """
    for path in layer_suffix:
        full_weight = getattr_(org_model, path).weight
        local_weight = getattr_(sharded_model, path).weight

        if local_weight is None:
            # skip if layer is not held by this process
            continue

        needs_gather = is_distributed_tensor(local_weight) or is_customized_distributed_tensor(local_weight)
        if needs_gather:
            local_weight = gather_distributed_param(local_weight, keep_vars=False)

        if is_padded_tensor(local_weight):
            local_weight = to_unpadded_tensor(local_weight)

        if verbose and dist.get_rank() == 0:
            print(f"'{path}' weight: {full_weight}, {local_weight}")

        assert_close(full_weight.float(), local_weight.float(), atol=atol, rtol=rtol)
|
|
|
|
|
|
def get_grad_tensors_for_check(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
    name: Optional[str] = None,
):
    """Collect original/sharded weight gradients for a later comparison.

    For each dotted path in ``layer_suffix``, if the sharded weight is a
    (customized) distributed tensor, its gradient is all-gathered over
    ``tp_group`` and concatenated along ``dim``. The sharded gradient is then
    trimmed along dim 0 to the original length, since embeddings may be padded
    under tensor parallelism.

    Args:
        org_model: reference (unsharded) model.
        sharded_model: sharded model.
        layer_suffix: dotted sub-module paths whose ``.weight.grad`` is collected.
        tp_group: process group used for the all-gather.
        dim: dimension along which gathered shards are concatenated.
        atol: absolute tolerance stored with each entry.
        rtol: relative tolerance stored with each entry.
        verbose: print both gradients on global rank 0.
        name: unused; kept for backward compatibility.

    Returns:
        Dict mapping each suffix to ``{"org_grad", "shard_grad", "rtol", "atol"}``.
        Both gradients are upcast to float32, in the format expected by
        ``check_all_grad_tensors``.
    """
    grad_to_check = {}
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        shard_grad = shard_weight.grad
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            # Slice dim 0 only, so 1-D weights do not hit an IndexError.
            shard_grad = shard_grad[: org_grad.shape[0]]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        grad_to_check[suffix] = {
            "org_grad": org_grad.float(),
            "shard_grad": shard_grad.float(),
            "rtol": rtol,
            "atol": atol,
        }

    return grad_to_check
|
|
|
|
|
|
# used by sam/blip2
def check_grad(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    """Assert that weight gradients of the original and sharded models match.

    Delegates to ``get_grad_tensors_for_check`` (gather/trim/upcast) and then
    ``check_all_grad_tensors`` (``assert_close`` with ``atol``/``rtol``).
    Gradients for *all* layers are gathered before any comparison. As a result,
    every rank issues the same sequence of ``all_gather`` calls even if an early
    layer mismatches; asserting inside the gather loop could leave other ranks
    blocked in a later collective.
    """
    grads = get_grad_tensors_for_check(
        org_model,
        sharded_model,
        layer_suffix,
        tp_group=tp_group,
        dim=dim,
        atol=atol,
        rtol=rtol,
        verbose=verbose,
    )
    check_all_grad_tensors(grads)
|
|
|
|
|
|
def unwrap_model(
    module: Module,
    base_model_class_name: Optional[str] = None,
    base_model_attribute_name: Optional[str] = None,
):
    """Strip a ``HybridParallelModule`` wrapper and optionally descend to the base model.

    The (unwrapped) module itself is returned when ``base_model_class_name`` is
    ``None`` or equals the module's class name. Otherwise the result is
    ``getattr(module, base_model_attribute_name, None)``.
    """
    inner = module.unwrap() if isinstance(module, HybridParallelModule) else module
    if base_model_class_name is None or inner.__class__.__name__ == base_model_class_name:
        return inner
    return getattr(inner, base_model_attribute_name, None)
|
|
|
|
|
|
def check_all_grad_tensors(check_tensors):
    """Assert that every collected gradient pair matches within its own tolerances.

    Args:
        check_tensors: mapping of layer suffix to a dict with the keys
            "org_grad" (tensor from the original model),
            "shard_grad" (tensor from the sharded model),
            "rtol" and "atol" (tolerances passed to ``assert_close``).
    """
    for info in check_tensors.values():
        expected, actual = info["org_grad"], info["shard_grad"]
        tolerances = {"rtol": info["rtol"], "atol": info["atol"]}
        assert_close(expected, actual, **tolerances)
|