ColossalAI/tests/test_optimizer/test_dist_came.py
Haze188 416580b314
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
2024-06-28 14:00:08 +08:00

459 lines
16 KiB
Python

import copy

import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.came import CAME
from colossalai.nn.optimizer.distributed_came import DistributedCAME
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_dist_grad, check_dist_optim_state, check_dist_param, check_optim_states
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    build_model_from_low_level_zero_plugin,
    run_forward_backward_with_hybrid_plugin,
    run_forward_backward_with_low_level_zero_plugin,
    unwrap_model,
)

HEIGHT = 128
WIDTH = 128
_TP_SPEC = DimSpec([0])
_SEED = 0


def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float32:
        rtol = 5e-04
        atol = 5e-04
    elif dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
    assert_close(tensor1, tensor2, rtol=rtol, atol=atol)


# setup param groups; (For zero test optim)
def setup_param_groups_zero(model: nn.Module) -> list:
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


# setup param groups; (For base optim)
def setup_param_groups(model: nn.Module) -> list:
    optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
    return optimizer_grouped_parameters


# setup flatten param groups, sharding spec and shape; (For dist optim)
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> tuple:
    flatten_optimizer_grouped_parameters = []
    sharding_spec = {}  # {id(flatten param): get_sharding_spec(p)}
    param_shape = {}  # {id(flatten param): get_layout(p).global_shape}
    for n, p in model.named_parameters():
        flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
        flatten_optimizer_grouped_parameters.append(flatten_p)
        if is_distributed_tensor(p):
            sharding_spec[id(flatten_p)] = get_sharding_spec(p)
            param_shape[id(flatten_p)] = get_layout(p).global_shape
        else:
            sharding_spec[id(flatten_p)] = None
            param_shape[id(flatten_p)] = p.shape
    return flatten_optimizer_grouped_parameters, sharding_spec, param_shape


def set_dist_grad(
    dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
) -> None:
    """
    Set split grads for Tensor Parallel or ZeRO DP.
    We do not need a separate treatment for ZeRO,
    as the wrapper takes care of reduce-scattering grads.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
        if torch_p.grad is None:
            torch_p.grad = torch.zeros_like(torch_p)

        is_distributed = hasattr(p, "dist_layout")
        if is_distributed:
            sharding = p.dist_layout.sharding_spec.sharding_sequence
            split_dim = sharding.index(_TP_SPEC)
            shape = torch_p.split(world_size, dim=split_dim)[rank].shape

            indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
            # Generate grads only for the correctly split chunk
            torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
        else:
            shape = torch_p.shape
            torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)

        # avoid inconsistent grad and param dtype error
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


def set_master_param_to_shard_param(master_param_list) -> dict:
    master_param_to_shard_param = {id(p): p for p in master_param_list}
    return master_param_to_shard_param


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(HEIGHT, WIDTH)
        self.linear2 = nn.Linear(WIDTH, HEIGHT)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TPModel(nn.Module):
    def __init__(self, linear1, linear2, tp_group=None):
        super().__init__()
        self.linear1 = Linear1D_Col.from_native_module(
            linear1, process_group=tp_group, gather_output=False, overlap=True
        )
        self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


@parameterize("dtype", [torch.float32])  # torch.float32, torch.float16, torch.bfloat16
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])  # (4, 1), (1, 4)
def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    tp_size, zero_size = tp_zero_size
    use_zero = zero_size > 1
    local_rank = dist.get_rank()

    clear_layout_converter()

    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
    tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)

    torch.set_default_dtype(dtype)
    # set_seed(42)

    # ==============================
    # Model Init
    # ==============================
    base_model = MlpModel().to(local_rank)
    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)

    base_param_group = setup_param_groups(base_model)
    tp_param_group = setup_param_groups(tp_model)
    tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)

    # ==============================
    # Optimizer Init
    # ==============================
    base_optim = CAME(base_param_group, lr=1e-3)
    dist_optim = DistributedCAME(tp_param_group, lr=1e-3)

    # Setup distributed optimizer
    if zero_size > 1:
        dist_optim = LowLevelZeroOptimizer(
            dist_optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )
        shard_to_param = dist_optim.master_to_working_param  # {id(): param tensor} but flattened
        dist_optim.optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )
    else:
        shard_to_param = set_master_param_to_shard_param(tp_param_group)
        dist_optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )

    # ==============================
    # Correctness Verify
    # ==============================
    seed_all(1024)
    x = torch.randn(WIDTH, HEIGHT, device=local_rank)

    out = base_model(x)
    out_tp = tp_model(x)

    if zero_size > 1:
        dist_optim.backward(out_tp.sum())
        out.sum().backward()
    else:
        out_tp.sum().backward()
        out.sum().backward()

    base_optim.step()
    dist_optim.step()

    base_optim.zero_grad()
    dist_optim.zero_grad()

    for p, tp_p in zip(base_param_group, tp_param_group):
        param_is_distributed = is_distributed_tensor(tp_p)
        if param_is_distributed:
            shard_spec = get_sharding_spec(tp_p)
            if len(shard_spec.sharding_sequence) >= 2:
                # Col Parallel
                if shard_spec.sharding_sequence[0] == "R":
                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
                # ROW Parallel
                if shard_spec.sharding_sequence[-1] == "R":
                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
            else:
                # TP bias
                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
        else:
            # No TP bias
            pass
        correctness_verify(p.data, tp_p.data, dtype)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print("Fwd/Bwd Test Passed")


@parameterize(
    "test_config",
    [
        {
            "stage": 1,
            "precision": "bf16",
        },
        {
            "stage": 2,
            "precision": "bf16",
        },
    ],
)
def exam_bert_test_on_lowlevelzero_plugin(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    test_config["use_lazy_init"] = False
    test_config["initial_scale"] = 2**10
    # check weights
    if test_config["precision"] == "bf16":
        atol, rtol = 5e-4, 5e-4
    else:
        atol, rtol = 5e-4, 5e-4
    # test_config["initial_scale"] = 1
    model_list = [
        "transformers_bert",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
    seed_all(_SEED)
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name in model_list:
            (
                org_model,
                org_optimizer,
                sharded_model,
                sharded_optimizer,
                criterion,
                booster,
            ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)

            org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
                org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
            )

            # assert same output
            # assert_close(org_output, org_output, atol=atol, rtol=rtol)

            weight_layer_for_check = [
                "bert.encoder.layer.1.intermediate.dense",
                # TODO: error in layer:
                # "bert.encoder.layer.0.output.dense",
                # "bert.encoder.layer.1.output.dense",
            ]

            # assert same weight before step; pass
            check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)

            # assert loss; pass
            assert_close(org_loss, sharded_loss)

            # assert same grad before step
            # TODO: error here; backward grads differ; only transformers_bert passes
            check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_for_check, atol, rtol)

            org_optimizer.step()
            sharded_optimizer.step()

            # assert same weight after step
            check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)
            check_optim_states(org_optimizer, sharded_optimizer.optim)

    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print("LowLevelZeroPlugin + Bert Model Zoo Test Passed")


@parameterize(
    "test_config",
    [
        {
            "tp_size": 1,
            "num_microbatches": 4,
            "zero_stage": 2,
            "precision": "bf16",
        },
        {
            "tp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 2,
            "precision": "bf16",
        },
        {
            "tp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 2,
            "precision": "bf16",
        },
        {
            "tp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
        },
        {
            "tp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 0,
            "precision": "bf16",
        },
    ],
)
def exam_bert_test_on_hybrid_plugin(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    test_config["use_lazy_init"] = False
    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
    ]

    # pass "transformers_bert",
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
    # check weights
    if test_config["precision"] == "bf16":
        atol, rtol = 5e-3, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name in model_list:
            (
                org_model,
                org_optimizer,
                sharded_model,
                sharded_optimizer,
                criterion,
                booster,
            ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)

            org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
                org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
            )

            stage_manager = booster.plugin.stage_manager
            booster.plugin.tp_group

            bert = unwrap_model(org_model, "BertModel", "bert")
            sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

            # TODO: model
            # "encoder.layer.0.output.dense.weight", "encoder.layer.1.output.dense.weight" not match
            # "encoder.layer[0].output.dense", "encoder.layer[1].output.dense" not match
            weight_layer_for_check = ["embeddings.word_embeddings"]  # [30522, 128]

            # # assert same weight before step; all pass
            # check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)

            # # assert loss; all pass
            # assert_close(org_loss, sharded_loss)

            # # assert same grad before step; all pass
            # check_dist_grad(org_model, sharded_model, weight_layer_for_check, atol, rtol)

            org_optimizer.step()
            sharded_optimizer.step()

            if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
                check_dist_param(bert, sharded_bert, weight_layer_for_check, atol, rtol)
                # check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)

            # check optim states
            check_dist_optim_state(org_optimizer, sharded_optimizer.optim)

    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print("HybridParallelPlugin + Bert Model Zoo Test Passed")


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_bert_test_on_lowlevelzero_plugin()  # err in TODO layer
    exam_bert_test_on_hybrid_plugin()  # pass
    exam_dist_came_base()  # pass


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_came():
    spawn(run_dist, nprocs=4)


if __name__ == "__main__":
    test_dist_came()