[pipeline] All bert models (#4233)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* remove unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* finish some bert models

* finish all bert models

* finish bert tests

* fix bugs

* fix bugs

* fix test pipeline

* fix data gen for qa

* update the set pipeline forward

* shared params

* fix bugs

Author: Jianghai, 2023-07-17 16:12:20 +08:00 (committed by Hongxin Liu)
parent a14d352088
commit e7cc62d735
13 changed files with 988 additions and 144 deletions

@@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any],
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
if torch.__version__ >= "1.13.0":
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
else:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
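
The hunk above gates the call to the private c10d._object_to_tensor helper on the installed torch version, because the device argument only exists from torch 1.13 onward. Below is a minimal standalone sketch of the same gate (the wrapper name is hypothetical, and private signatures can shift between releases); packaging.version is used instead of a raw string comparison, which happens to work for 1.12/1.13/2.x but would misorder older releases such as 1.9.0.

import torch
import torch.distributed.distributed_c10d as c10d
from packaging import version


def _object_to_tensor_compat(obj, device=None):
    """Serialize obj via c10d, passing `device` only on torch >= 1.13."""
    # Hypothetical compatibility wrapper, shown for illustration only.
    if version.parse(torch.__version__) >= version.parse("1.13.0"):
        return c10d._object_to_tensor(obj, device=device)
    return c10d._object_to_tensor(obj)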

@@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
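
For orientation, the OneForwardOneBackwardSchedule touched here runs a warm-up of forward-only micro-batches, then alternates one forward with one backward by popping the oldest saved input/output pair (the pop(0) lines above), and finally drains the remaining backwards. A schematic, framework-free sketch of that loop; the step callables are placeholders, not the actual schedule API:

def one_forward_one_backward(num_microbatches, num_warmup, forward_step, backward_step, recv_output_grad):
    input_objs, output_objs = [], []
    # warm-up: forward-only micro-batches fill the pipeline
    for _ in range(num_warmup):
        inp, out = forward_step()
        input_objs.append(inp)
        output_objs.append(out)
    # steady state: each new forward is paired with the backward of the
    # oldest outstanding micro-batch
    for _ in range(num_microbatches - num_warmup):
        inp, out = forward_step()
        grad = recv_output_grad()
        backward_step(input_objs.pop(0), output_objs.pop(0), grad)
        input_objs.append(inp)
        output_objs.append(out)
    # cool-down: drain the backwards that are still outstanding
    while input_objs:
        grad = recv_output_grad()
        backward_step(input_objs.pop(0), output_objs.pop(0), grad)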

@@ -42,6 +42,8 @@ _POLICY_LIST = {
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
"transformers.models.bert.modeling_bert.BertForQuestionAnswering":
PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel":

File diff suppressed because it is too large.

@@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
llama_model = self.model.model
if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
# tie weights
return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
}]
return []
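
The get_shared_params change above reports the tied input-embedding / lm_head pair as a parameter shared between the first and last pipeline stages. A standalone sketch of that pattern for a HuggingFace-style causal LM follows; the helper name is hypothetical and is not part of the ColossalAI API:

from typing import Dict, List

from torch import Tensor


def tied_embedding_shared_params(model, num_stages: int) -> List[Dict[int, Tensor]]:
    """Map pipeline stage -> tensor for weights that must stay synchronized."""
    embed = model.get_input_embeddings()    # e.g. LlamaModel.embed_tokens
    head = model.get_output_embeddings()    # e.g. LlamaForCausalLM.lm_head
    if head is not None and id(embed.weight) == id(head.weight):
        # weight tying: stage 0 owns the embedding, the last stage owns the head,
        # so the pair forms one shared-parameter group
        return [{0: embed.weight, num_stages - 1: head.weight}]
    return []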

@@ -1 +1 @@
from .torchrec import *
#from .torchrec import *

@@ -87,6 +87,17 @@ def data_gen_for_mcq():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
def data_gen_for_qa():
# generating data for question answering
# no labels needed; use start and end positions instead
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
return data
# define output transform function
output_transform_fn = lambda x: x
@@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_question_answering',
model_fn=lambda: transformers.BertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
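
As a sanity check on the new data_gen_for_qa, a batch carrying start_positions/end_positions should drive BertForQuestionAnswering end to end and produce a scalar span-classification loss plus per-token start/end logits. A small usage sketch; the token ids below are arbitrary placeholders rather than the output of data_gen:

import torch
import transformers

config = transformers.BertConfig()
model = transformers.BertForQuestionAnswering(config)

# mirrors data_gen_for_qa(): token ids plus answer-span labels instead of `labels`
data = dict(
    input_ids=torch.randint(0, config.vocab_size, (1, 8), dtype=torch.int64),
    attention_mask=torch.ones(1, 8, dtype=torch.int64),
    start_positions=torch.tensor([0], dtype=torch.int64),
    end_positions=torch.tensor([1], dtype=torch.int64),
)

outputs = model(**data)
assert outputs.loss.dim() == 0                 # scalar loss over start/end positions
assert outputs.start_logits.shape == (1, 8)    # one start logit per token
assert outputs.end_logits.shape == (1, 8)      # one end logit per token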

@@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_for_pretraining_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output['hidden_states'].shape)
output = bert_for_pretraining_forward(
self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index,
)
assert output['hidden_states'].shape == (2, 3, 768)
else:
@@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward():
output = bert_for_pretraining_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
stage_manager=stage_manager,
stage_index=stage_index)
assert output[0].shape == (2, 3, 30522)
# assert output[1].shape == (2, 768)
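
The updated tests derive stage_index from Policy.distribute_layers and Policy.get_stage_index before calling the staged forward. Roughly, these helpers split the encoder layers across pipeline stages and return the [start, end) layer range a stage owns. A simplified standalone sketch (not the ColossalAI implementation, which may balance any remainder differently):

from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Spread layers evenly over stages; earlier stages absorb the remainder."""
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Return the [start, end) layer range owned by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


# BERT-base has 12 encoder layers; with 2 stages each stage owns 6 of them
assert distribute_layers(12, 2) == [6, 6]
assert get_stage_index([6, 6], 0) == (0, 6)
assert get_stage_index([6, 6], 1) == (6, 12)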

@@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_lmhead_forward():
def check_bert_lm_head_model_forward():
configuration = BertConfig()
model = BertLMHeadModel(configuration)
DP_DIM, PP_DIM = 0, 1
@@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_lmhead_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager)
output = bert_lm_head_model_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 768)
else:
attention_mask = torch.ones((2, 3))
output = bert_lmhead_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
output = bert_lm_head_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
print(output[0].shape)
assert output[0].shape == (2, 3, 30522)
@@ -93,7 +98,7 @@ def check_bert_lmhead_policy():
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_lmhead_forward()
check_bert_lm_head_model_forward()
def run_dist_policy(rank, world_size, port):
@@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_lmhead_forward():
def test_bert_lm_head_model_forward():
spawn(run_dist_model, 4)
@@ -115,5 +120,5 @@ def test_bert_lmhead_policy():
if __name__ == "__main__":
"""test the bert for pretraining model forward and bert for pretraining model policy"""
test_bert_lmhead_forward()
test_bert_lm_head_model_forward()
test_bert_lmhead_policy()

@@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_model_forward():
# this test may fail due to network issues (it downloads pretrained weights)
model = BertModel.from_pretrained('bert-base-uncased')
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
@@ -34,20 +36,25 @@ def check_bert_model_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
print(output['hidden_states'].shape)
output = bert_model_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
assert output['hidden_states'].shape == (2, 3, 768)
else:
attention_mask = torch.ones((2, 3))
output = bert_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
stage_manager=stage_manager,
stage_index=stage_index)
print(output[0].shape)
assert output[0].shape == (2, 3, 768)
@@ -112,4 +119,3 @@ if __name__ == "__main__":
"""test the bert model forward and bert model policy"""
#test_bert_model_forward()
test_bert_model_policy()
# this test needs a config to run

@@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# switch to train mode
original_model.train()
sharded_model.train()

@@ -0,0 +1,164 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
class PipelineOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module):
super().__init__(optim)
params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
params = [p for p in group['params'] if p in params]
new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
# TODO: support amp
class PipelinedModel(ModelWrapper):
def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
self.shared_param_process_groups = []
super().__init__(module)
def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
sampler = DistributedSampler(
dataset,
#rank=self.pg_mesh.coordinate(DP_AXIS),
shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
)
def execute_pipeline(
data_iter: Iterator,
model: PipelinedModel,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: PipelineOptimizer,
return_loss: bool = True,
return_outputs: bool = False,
schedule: OneForwardOneBackwardSchedule = None,
) -> dict:
# return loss or outputs if needed
outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
return outputs
class data_iter():
def __getitem__(self, x):
return torch.randint(0, 100, (4, 128)).cuda()
def loss(x, y):
return (x[0].float().mean() - y[0].float().mean())
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
from datasets import load_dataset
#dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi")
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
num_microbatches = 2
org_model = model_fn().cuda()
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
#dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
data_it = iter(data_iter())
results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
if stage_manager.is_last_stage():
assert results['loss'] is not None
assert results['outputs'] is None
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 2)
if __name__ == "__main__":
test_llama()

@@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name == 'transformers_bert':
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if name == 'transformers_bert_for_mcq':
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
if stage_manager.stage == 0:
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (6, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
output = sharded_model(input_ids=x,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape == (2, 3)
else:
x = torch.randint(0, 1000, (2, 3)).cuda()
# one batch of 2 single sentences, each with 3 tokens
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
# print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
# print(output[0].shape)
assert output[0].shape == (2, 3, 128)
assert output[0].shape[0] == 2
torch.cuda.empty_cache()
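
All of these tests exercise the same staged-forward contract: the first pipeline stage consumes input_ids, later stages consume the previous stage's hidden_states, non-final stages return {'hidden_states': ...} for the next stage, and only the last stage returns the task outputs. A minimal toy illustration of that contract (hypothetical class, not the ColossalAI/shardformer implementation):

import torch
from torch import nn


class StagedToyModel(nn.Module):
    """Toy module illustrating the staged forward used throughout these tests."""

    def __init__(self, stage: int, num_stages: int, vocab: int = 1000, hidden: int = 128):
        super().__init__()
        self.stage, self.num_stages = stage, num_stages
        self.embed = nn.Embedding(vocab, hidden) if stage == 0 else None
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(2))
        self.head = nn.Linear(hidden, vocab) if stage == num_stages - 1 else None

    def forward(self, input_ids=None, hidden_states=None, attention_mask=None):
        # first stage embeds token ids; later stages receive activations instead
        x = self.embed(input_ids) if self.stage == 0 else hidden_states
        for layer in self.layers:          # attention_mask is ignored in this toy
            x = torch.relu(layer(x))
        if self.head is None:
            return {'hidden_states': x}    # intermediate output for the next stage
        return (self.head(x),)             # last stage: task logits


# stage 0: (2, 3) token ids -> intermediate hidden states of shape (2, 3, 128)
out0 = StagedToyModel(stage=0, num_stages=2)(input_ids=torch.randint(0, 1000, (2, 3)))
assert out0['hidden_states'].shape == (2, 3, 128)

# last stage: hidden states -> logits of shape (2, 3, 1000)
out1 = StagedToyModel(stage=1, num_stages=2)(hidden_states=out0['hidden_states'])
assert out1[0].shape == (2, 3, 1000)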