mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[pipeline] refactor test pipeline and remove useless utils in pipeline (#4324)
* refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process * refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process
This commit is contained in:
@@ -1,64 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.bert import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BertForPreTraining
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bert_for_pretraining_policy():
|
||||
configuration = BertConfig()
|
||||
model = BertForPreTraining(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
||||
model_policy = BertForPreTrainingPolicy()
|
||||
model_policy.set_model(model)
|
||||
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
model_policy.set_shard_config(model_config)
|
||||
layers = model_policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 6 + 1
|
||||
else:
|
||||
assert len(layers) == 6 + 2
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bert_for_pretraining_policy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bert_for_pretraining_policy():
|
||||
spawn(run_dist_policy, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bert for pretraining model forward and bert for pretraining model policy"""
|
||||
test_bert_for_pretraining_policy()
|
@@ -1,64 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.bert import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BertLMHeadModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bert_lmhead_policy():
|
||||
configuration = BertConfig()
|
||||
model = BertLMHeadModel(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
||||
model_policy = BertLMHeadModelPolicy()
|
||||
model_policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
model_policy.set_shard_config(model_config)
|
||||
layers = model_policy.get_held_layers()
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 6 + 1
|
||||
else:
|
||||
assert len(layers) == 6 + 2
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bert_lmhead_policy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bert_lmhead_policy():
|
||||
spawn(run_dist_policy, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bert for lm head model policy"""
|
||||
test_bert_lmhead_policy()
|
@@ -1,66 +0,0 @@
|
||||
'''
|
||||
In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model
|
||||
'''
|
||||
|
||||
import pytest
|
||||
import torch.distributed as dist
|
||||
from transformers.models.bert.modeling_bert import BertModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertModelPolicy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bert_model_policy():
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
||||
model_policy = BertModelPolicy()
|
||||
model_policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
model_policy.set_shard_config(model_config)
|
||||
|
||||
layers = model_policy.get_held_layers()
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 6 + 1
|
||||
else:
|
||||
assert len(layers) == 6 + 1
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bert_model_policy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bert_model_policy():
|
||||
spawn(run_dist_policy, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bert model policy"""
|
||||
test_bert_model_policy()
|
@@ -1,63 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.bloom import BloomConfig, BloomModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bloom import BloomModelPolicy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bloom_model_policy():
|
||||
# create a BloomModel
|
||||
configuration = BloomConfig()
|
||||
model = BloomModel(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
||||
model_policy = BloomModelPolicy()
|
||||
model_policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
model_policy.set_shard_config(model_config)
|
||||
layers = model_policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 1 + 2
|
||||
else:
|
||||
assert len(layers) == 1 + 1
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bloom_model_policy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bloom_model_policy():
|
||||
spawn(run_dist_policy, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bloom model policy"""
|
||||
test_bloom_model_policy()
|
@@ -2,7 +2,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
|
@@ -5,6 +5,8 @@ import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
@@ -17,9 +19,55 @@ from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
pass
|
||||
def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
|
||||
stage_manager = stage_manager
|
||||
policy = get_autopolicy(model)
|
||||
policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
policy.set_shard_config(model_config)
|
||||
layers = policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 1 + 1
|
||||
else:
|
||||
if name == "transformers_bert":
|
||||
assert len(layers) == 1 + 1
|
||||
elif name in [
|
||||
"transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
|
||||
"transformers_bert_for_mcq"
|
||||
]:
|
||||
assert len(layers) == 1 + 3
|
||||
else:
|
||||
assert len(layers) == 1 + 2
|
||||
|
||||
|
||||
def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
|
||||
if name == 'transformers_bert_for_mcq':
|
||||
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
assert output['hidden_states'].shape == (6, 3, 128)
|
||||
else:
|
||||
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
|
||||
output = sharded_model(input_ids=x,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
assert output[0].shape == (2, 3)
|
||||
else:
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
# one batch, 2 single sentences, each sentence has 3 tokens
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
assert output[0].shape[0] == 2
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@@ -27,55 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_bert
|
||||
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
|
||||
if name == 'transformers_bert_for_mcq':
|
||||
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
assert output['hidden_states'].shape == (6, 3, 128)
|
||||
else:
|
||||
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
|
||||
output = sharded_model(input_ids=x,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
assert output[0].shape == (2, 3)
|
||||
else:
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
# one batch, 2 single sentences, each sentence has 3 tokens
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
assert output[0].shape[0] == 2
|
||||
check_bert_model_policy(name, org_model, stage_manager)
|
||||
check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -90,7 +100,7 @@ def check_bert(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bert():
|
||||
spawn(check_bert, 4)
|
||||
spawn(check_bert, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -5,7 +5,9 @@ import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
@@ -18,9 +20,37 @@ from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
pass
|
||||
def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
|
||||
policy = get_autopolicy(model)
|
||||
policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
policy.set_shard_config(model_config)
|
||||
layers = policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 0 + 2
|
||||
else:
|
||||
if name == 'transformers_bloom':
|
||||
assert len(layers) == 1 + 1
|
||||
elif name == 'transformers_bloom_for_token_classification':
|
||||
assert len(layers) == 1 + 3
|
||||
else:
|
||||
assert len(layers) == 1 + 2
|
||||
|
||||
|
||||
def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
|
||||
if stage_manager.stage == 0:
|
||||
x = torch.randint(0, 1000, (1, 3)).cuda()
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (1, 3, 64)
|
||||
else:
|
||||
attention_mask = torch.ones((1, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0].shape[0] == 1
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@@ -28,40 +58,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_bloom
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
x = torch.randint(0, 1000, (1, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (1, 3, 64)
|
||||
else:
|
||||
attention_mask = torch.ones((1, 3)).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0].shape[0] == 1
|
||||
check_bloom_model_policy(name, org_model, stage_manager)
|
||||
check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -76,7 +83,7 @@ def check_bloom(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom():
|
||||
spawn(check_bloom, 4)
|
||||
spawn(check_bloom, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -5,7 +5,9 @@ import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
@@ -18,9 +20,35 @@ from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
pass
|
||||
def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
|
||||
policy = get_autopolicy(model)
|
||||
policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
policy.set_shard_config(model_config)
|
||||
layers = policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 2 + 1
|
||||
else:
|
||||
if name == "transformers_llama":
|
||||
assert len(layers) == 2 + 1
|
||||
else:
|
||||
assert len(layers) == 2 + 2
|
||||
|
||||
|
||||
def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0] is not None
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@@ -28,40 +56,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_llama
|
||||
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0] is not None
|
||||
check_llama_model_policy(name, org_model, stage_manager)
|
||||
check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -76,7 +82,7 @@ def check_llama(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama():
|
||||
spawn(check_llama, 4)
|
||||
spawn(check_llama, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user