mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)
* rewrite opt tests * rewrite llama tests * rewrite bloom & vit tests * rewrite chatglm tests * fix LinearCol for classfiers * add judge for other tp layers, fix lazy init in util
This commit is contained in:
committed by
Hongxin Liu
parent
21e0a42fd1
commit
7711bd524a
@@ -1,60 +1,127 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'ViTModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'ViTModel':
|
||||
vit_model = org_model
|
||||
shard_vit_model = sharded_model
|
||||
shard_vit_model = sharded_model.unwrap()
|
||||
else:
|
||||
vit_model = org_model.vit
|
||||
shard_vit_model = sharded_model.vit
|
||||
shard_vit_model = sharded_model.unwrap().vit
|
||||
|
||||
# check grad
|
||||
col_layer_for_check = ['encoder.layer[0].attention.attention.query']
|
||||
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False)
|
||||
check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False)
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-5,
|
||||
rtol=1e-3,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-5,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=5e-3,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('enable_jit_fused', [True, False])
|
||||
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_vit_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
# TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
enable_flash_attention, enable_jit_fused)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -68,7 +135,7 @@ def check_vit(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_vit():
|
||||
spawn(check_vit, 2)
|
||||
spawn(check_vit, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user