Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 07:00:37 +00:00
[shardformer] tests for 3d parallel (#4493)
@@ -56,7 +56,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # unwrap model
    llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
    shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')

    # check grad
    row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
    col_layer_for_check = ['layers[0].self_attn.o_proj']
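
For context, a plausible sketch of what the unwrap_model helper used above might do; this is an assumption for illustration, not ColossalAI's actual implementation. The idea is to peel a task-specific wrapper (e.g. LlamaForCausalLM) down to the base LlamaModel so that parameter paths like 'layers[0].self_attn.q_proj' resolve on both the original and the sharded model:

# Hypothetical stand-in for the test suite's unwrap_model helper (assumption).
def unwrap_model(model, base_model_class_name: str, base_model_attribute_name: str):
    # Already the base model: nothing to unwrap.
    if model.__class__.__name__ == base_model_class_name:
        return model
    # Otherwise assume the wrapper exposes the base model as an attribute
    # (e.g. LlamaForCausalLM.model is a LlamaModel).
    return getattr(model, base_model_attribute_name)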
@@ -156,12 +155,40 @@ def run_llama_test(test_config):
    torch.cuda.empty_cache()


@parameterize('test_config', [
    {
        'tp_size': 2,
        'pp_size': 2,
        'num_microbatches': 4,
        'enable_all_optimization': False,
        'use_lazy_init': False,
        'precision': 'fp32',
        'initial_scale': 1,
    },
])
def run_llama_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


def check_llama_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -169,5 +196,13 @@ def test_llama():
    spawn(check_llama, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama_3d():
    spawn(check_llama_3d, 8)


if __name__ == "__main__":
    test_llama()
    test_llama_3d()
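
Why 8 ranks for the 3D test: the test_config pins tp_size=2 and pp_size=2, and the remaining data-parallel degree follows from the world size, so spawn(check_llama_3d, 8) leaves 8 // (2 * 2) = 2 data-parallel replicas, i.e. 2 x 2 x 2 tensor/pipeline/data parallelism. A minimal sketch of that arithmetic (infer_dp_size is a hypothetical helper, not part of the commit):

# Hypothetical helper illustrating the rank factorization (assumption).
def infer_dp_size(world_size: int, tp_size: int, pp_size: int) -> int:
    # The ranks must factor evenly across the tensor- and pipeline-parallel axes.
    assert world_size % (tp_size * pp_size) == 0
    return world_size // (tp_size * pp_size)

print(infer_dp_size(world_size=8, tp_size=2, pp_size=2))  # -> 2, the third axis of the "3d" test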