[pipeline] refactor test pipeline and remove useless utils in pipeline (#4324)

* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process

* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process
This commit is contained in:
Jianghai
2023-08-01 10:35:17 +08:00
committed by Hongxin Liu
parent d3c6cd66f3
commit f13954cd58
14 changed files with 138 additions and 1246 deletions

View File

@@ -1,64 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_for_pretraining_policy():
configuration = BertConfig()
model = BertForPreTraining(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertForPreTrainingPolicy()
model_policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
model_policy.set_shard_config(model_config)
layers = model_policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 6 + 1
else:
assert len(layers) == 6 + 2
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_for_pretraining_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_for_pretraining_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert for pretraining model forward and bert for pretraining model policy"""
test_bert_for_pretraining_policy()

View File

@@ -1,64 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertLMHeadModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_lmhead_policy():
configuration = BertConfig()
model = BertLMHeadModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertLMHeadModelPolicy()
model_policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
model_policy.set_shard_config(model_config)
layers = model_policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 6 + 1
else:
assert len(layers) == 6 + 2
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_lmhead_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_lmhead_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert for lm head model policy"""
test_bert_lmhead_policy()

View File

@@ -1,66 +0,0 @@
'''
In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model
'''
import pytest
import torch.distributed as dist
from transformers.models.bert.modeling_bert import BertModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertModelPolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_model_policy():
model = BertModel.from_pretrained('bert-base-uncased')
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertModelPolicy()
model_policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
model_policy.set_shard_config(model_config)
layers = model_policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 6 + 1
else:
assert len(layers) == 6 + 1
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_model_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_model_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert model policy"""
test_bert_model_policy()

View File

@@ -1,63 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bloom import BloomConfig, BloomModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bloom import BloomModelPolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bloom_model_policy():
# create a BloomModel
configuration = BloomConfig()
model = BloomModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BloomModelPolicy()
model_policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
model_policy.set_shard_config(model_config)
layers = model_policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 1 + 2
else:
assert len(layers) == 1 + 1
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bloom_model_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bloom_model_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bloom model policy"""
test_bloom_model_policy()

View File

@@ -2,7 +2,10 @@ import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,

View File

@@ -5,6 +5,8 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
@@ -17,9 +19,55 @@ from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
pass
def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
stage_manager = stage_manager
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 1 + 1
else:
if name == "transformers_bert":
assert len(layers) == 1 + 1
elif name in [
"transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
"transformers_bert_for_mcq"
]:
assert len(layers) == 1 + 3
else:
assert len(layers) == 1 + 2
def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
if name == 'transformers_bert_for_mcq':
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
if stage_manager.stage == 0:
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (6, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
output = sharded_model(input_ids=x,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape == (2, 3)
else:
x = torch.randint(0, 1000, (2, 3)).cuda()
# one batch, 2 single sentences, each sentence has 3 tokens
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape[0] == 2
@parameterize('enable_fused_normalization', [False])
@@ -27,55 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bert
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if name == 'transformers_bert_for_mcq':
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
if stage_manager.stage == 0:
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (6, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
output = sharded_model(input_ids=x,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape == (2, 3)
else:
x = torch.randint(0, 1000, (2, 3)).cuda()
# one batch, 2 single sentences, each sentence has 3 tokens
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape[0] == 2
check_bert_model_policy(name, org_model, stage_manager)
check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
@@ -90,7 +100,7 @@ def check_bert(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 4)
spawn(check_bert, 2)
if __name__ == "__main__":

View File

@@ -5,7 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
@@ -18,9 +20,37 @@ from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
pass
def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 0 + 2
else:
if name == 'transformers_bloom':
assert len(layers) == 1 + 1
elif name == 'transformers_bloom_for_token_classification':
assert len(layers) == 1 + 3
else:
assert len(layers) == 1 + 2
def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
if stage_manager.stage == 0:
x = torch.randint(0, 1000, (1, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (1, 3, 64)
else:
attention_mask = torch.ones((1, 3)).cuda()
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape[0] == 1
@parameterize('enable_fused_normalization', [False])
@@ -28,40 +58,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bloom
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
x = torch.randint(0, 1000, (1, 3)).cuda()
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (1, 3, 64)
else:
attention_mask = torch.ones((1, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape[0] == 1
check_bloom_model_policy(name, org_model, stage_manager)
check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
@@ -76,7 +83,7 @@ def check_bloom(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 4)
spawn(check_bloom, 2)
if __name__ == "__main__":

View File

@@ -5,7 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
@@ -18,9 +20,35 @@ from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
pass
def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 2 + 1
else:
if name == "transformers_llama":
assert len(layers) == 2 + 1
else:
assert len(layers) == 2 + 2
def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
x = torch.randint(0, 1000, (2, 3)).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (2, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0] is not None
@parameterize('enable_fused_normalization', [False])
@@ -28,40 +56,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_llama
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0] is not None
check_llama_model_policy(name, org_model, stage_manager)
check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
@@ -76,7 +82,7 @@ def check_llama(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 4)
spawn(check_llama, 2)
if __name__ == "__main__":