mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[pipeline] add bloom model pipeline (#4210)
* bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * finish bloom model * test shard gpt2 * clear cache
This commit is contained in:
@@ -51,15 +51,17 @@ output_transform_fn = lambda x: x
|
||||
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.GPT2Config(n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification")
|
||||
config = transformers.GPT2Config(
|
||||
n_layer=2,
|
||||
n_head=4,
|
||||
#n_embd=128,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification")
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_gpt',
|
||||
|
@@ -0,0 +1,84 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_bloom
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda()
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
if name == 'transformers_bloom':
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (2, 3, 64)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0].shape == (2, 3, 64)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_bloom(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bloom_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom():
|
||||
spawn(check_bloom, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bloom()
|
@@ -70,6 +70,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
Reference in New Issue
Block a user