Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 01:06:00 +00:00)
[pipeline] All bert models (#4233)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
  This reverts commit 8dee68a0a2.
  This policy should be reverted and copied to feature/bloom.
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update the set pipeline forward
* shared params
* fix bugs
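
The common thread in the diffs below is that every test now derives an explicit stage_index from the policy helpers before calling the stage-aware forward. As a rough sketch of what that derivation computes, here is an assumed, illustrative reimplementation of the two helpers the tests call (Policy.distribute_layers and Policy.get_stage_index); the bodies below are my reading of their behaviour, not the library code:

from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Assumed behaviour: split layers as evenly as possible, handing any
    # remainder (one extra layer each) to the later stages.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Assumed behaviour: turn the per-stage layer counts into the [start, end)
    # slice of encoder layers owned by the given stage.
    start = sum(layers_per_stage[:stage])
    return (start, start + layers_per_stage[stage])


if __name__ == "__main__":
    # bert-base has 12 encoder layers; with 2 pipeline stages each stage owns 6,
    # so stage 1 runs layers [6, 12).
    counts = distribute_layers(12, 2)
    print(counts)                        # [6, 6]
    print(get_stage_index(counts, 1))    # (6, 12)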
@@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)
+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)

     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_for_pretraining_forward(self=model,
-                                              input_ids=x,
-                                              attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_for_pretraining_forward(
+            self=model,
+            input_ids=x,
+            attention_mask=attention_mask,
+            stage_manager=stage_manager,
+            stage_index=stage_index,
+        )
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
@@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward():
         output = bert_for_pretraining_forward(self=model,
                                               hidden_states=hidden_states,
                                               attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output[0].shape)
+                                              stage_manager=stage_manager,
+                                              stage_index=stage_index)
         assert output[0].shape == (2, 3, 30522)
         # assert output[1].shape == (2, 768)

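The shape constants asserted above come from the default BertConfig: 768 is hidden_size, the tensor carried between pipeline stages, and 30522 is vocab_size, the width of the pretraining head's logits on the last stage. A small hypothetical helper (name and signature are mine, not part of the test) that expresses the same checks without the literals:

import torch
from transformers import BertConfig


def check_stage_output(output, config: BertConfig, batch: int, seq_len: int, is_last_stage: bool) -> None:
    # Intermediate stages hand the running hidden states to the next stage;
    # the last stage returns the prediction-head logits over the vocabulary.
    if is_last_stage:
        assert output[0].shape == (batch, seq_len, config.vocab_size)                   # (2, 3, 30522)
    else:
        assert output['hidden_states'].shape == (batch, seq_len, config.hidden_size)    # (2, 3, 768)


if __name__ == "__main__":
    cfg = BertConfig()
    check_stage_output((torch.zeros(2, 3, cfg.vocab_size),), cfg, batch=2, seq_len=3, is_last_stage=True)
    check_stage_output({'hidden_states': torch.zeros(2, 3, cfg.hidden_size)}, cfg, batch=2, seq_len=3, is_last_stage=False)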
@@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


-def check_bert_lmhead_forward():
+def check_bert_lm_head_model_forward():
     configuration = BertConfig()
     model = BertLMHeadModel(configuration)
     DP_DIM, PP_DIM = 0, 1
@@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)

+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_lmhead_forward(self=model,
-                                     input_ids=x,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
+        output = bert_lm_head_model_forward(self=model,
+                                            input_ids=x,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output['hidden_states'].shape)
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
         attention_mask = torch.ones((2, 3))
-        output = bert_lmhead_forward(self=model,
-                                     hidden_states=hidden_states,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
+        output = bert_lm_head_model_forward(self=model,
+                                            hidden_states=hidden_states,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 30522)

@@ -93,7 +98,7 @@ def check_bert_lmhead_policy():

 def run_dist_model(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_lmhead_forward()
+    check_bert_lm_head_model_forward()


 def run_dist_policy(rank, world_size, port):
@@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):

 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_bert_lmhead_forward():
+def test_bert_lm_head_model_forward():
     spawn(run_dist_model, 4)


@@ -115,5 +120,5 @@ def test_bert_lmhead_policy():

 if __name__ == "__main__":
     """test the bert for pretraining model forward and bert for pretraining model policy"""
-    test_bert_lmhead_forward()
+    test_bert_lm_head_model_forward()
     test_bert_lmhead_policy()

@@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_bert_model_forward():
+    # this test may crash for internet reasons
     model = BertModel.from_pretrained('bert-base-uncased')
     DP_DIM, PP_DIM = 0, 1
     DP_SIZE, PP_SIZE = 2, 2
@@ -34,20 +36,25 @@ def check_bert_model_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)

+    layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_model_forward(self=model,
+                                    input_ids=x,
+                                    attention_mask=attention_mask,
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         assert output['hidden_states'].shape == (2, 3, 768)
     else:
         attention_mask = torch.ones((2, 3))
         output = bert_model_forward(self=model,
                                     hidden_states=hidden_states,
                                     attention_mask=attention_mask,
-                                    stage_manager=stage_manager)
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 768)

@@ -112,4 +119,3 @@ if __name__ == "__main__":
     """test the bert model forward and bert model policy"""
     #test_bert_model_forward()
     test_bert_model_policy()
-    # this test need config to run
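
Outside the tests, the commit items "bind argument to policy" and "update the set pipeline forward" describe the policy attaching these stage-aware forwards to the model so that callers keep the plain forward interface. A minimal sketch of that binding idea, assuming a live PipelineStageManager; bind_pipeline_forward is a hypothetical helper built with functools.partial, not the policy's actual API:

from functools import partial

from transformers import BertModel

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import bert_model_forward


def bind_pipeline_forward(model: BertModel, stage_manager: PipelineStageManager, num_stages: int = 2) -> None:
    # Compute this rank's [start, end) layer slice and freeze the pipeline
    # arguments into the replacement forward, so the model can still be
    # called as model(input_ids=..., attention_mask=...).
    layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), num_stages)
    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
    model.forward = partial(bert_model_forward,
                            self=model,
                            stage_manager=stage_manager,
                            stage_index=stage_index)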