mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)
* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test * merge tests of PP/DP/TP combinations into one test file * fix bug when sync grad for dp in HybridPlugin * update supported precisions for 3DPlugin/fix bug when shifting tp_degree * improve the passing of lazy_init * modify lazy_init/use sync_shared_params
This commit is contained in:
committed by
Hongxin Liu
parent
f13954cd58
commit
0ceec8f9a9
@@ -160,7 +160,6 @@ def check_llama(rank, world_size, port):
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.skip('This test will fail')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@@ -1,85 +1,180 @@
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch.optim import Adam
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
from colossalai.tensor.d_tensor.api import (
|
||||
clear_layout_converter,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
# do backward
|
||||
use_lazy_init = False
|
||||
if 'use_lazy_init' in test_config:
|
||||
use_lazy_init = test_config.pop('use_lazy_init')
|
||||
|
||||
if use_lazy_init:
|
||||
ctx = LazyInitContext()
|
||||
else:
|
||||
ctx = nullcontext()
|
||||
|
||||
# prepare booster
|
||||
plugin = HybridParallelPlugin(**test_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
stage_manager = plugin.stage_manager
|
||||
|
||||
# prepare models and optimizers
|
||||
with ctx:
|
||||
org_model = model_fn().cuda()
|
||||
sharded_model = copy.deepcopy(org_model)
|
||||
|
||||
if use_lazy_init:
|
||||
org_model = ctx.materialize(org_model)
|
||||
|
||||
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
||||
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
||||
criterion = loss_fn
|
||||
|
||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
outputs = output_transform_fn(outputs)
|
||||
loss = criterion(outputs)
|
||||
return loss
|
||||
|
||||
# do forward and backward
|
||||
data = data_gen_fn()
|
||||
sharded_model.train()
|
||||
if stage_manager:
|
||||
data = {
|
||||
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||
for k, v in data.items()
|
||||
}
|
||||
data_iter = iter([data])
|
||||
sharded_output = booster.execute_pipeline(data_iter,
|
||||
sharded_model,
|
||||
_criterion,
|
||||
sharded_optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True)
|
||||
sharded_loss = sharded_output['loss']
|
||||
else:
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
sharded_output = sharded_model(**data)
|
||||
sharded_loss = criterion(sharded_output)
|
||||
sharded_loss.backward()
|
||||
|
||||
org_model.train()
|
||||
org_output = org_model(**data)
|
||||
org_loss = criterion(org_output)
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
# check last hidden state
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
if stage_manager is None:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
|
||||
dim=0)
|
||||
|
||||
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||
|
||||
# check loss
|
||||
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_model = org_model
|
||||
sharded_model = sharded_model
|
||||
sharded_model = sharded_model.unwrap()
|
||||
else:
|
||||
org_model = org_model.transformer
|
||||
sharded_model = sharded_model.transformer
|
||||
sharded_model = sharded_model.unwrap().transformer
|
||||
|
||||
# check mlp grad
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
# check weights and gradients
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
|
||||
# check embedding weights
|
||||
org_grad = org_model.wte.weight.grad
|
||||
shard_grad = sharded_model.wte.weight.grad
|
||||
shard_weight = sharded_model.wte.weight
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)]
|
||||
dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group)
|
||||
shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
|
||||
org_weight = org_model.h[0].mlp.c_fc.weight
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
|
||||
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
|
||||
shard_weight = torch.cat(shard_weight_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
|
||||
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
def run_gpt2_test(test_config):
|
||||
|
||||
# TODO: add plugin_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -93,7 +188,7 @@ def check_gpt2(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2():
|
||||
spawn(check_gpt2, 2)
|
||||
spawn(check_gpt2, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,72 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_gpt2
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
input_ids = inputs['input_ids']
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = sharded_model.config.n_embd
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
# change inputs if not the first stage
|
||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
if name == 'transformers_gpt':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_gpt2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt2_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2():
|
||||
spawn(check_gpt2, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt2()
|
Reference in New Issue
Block a user