mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
[pipeline] add bloom model pipeline (#4210)
* bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * finish bloom model * test shard gpt2 * clear cache
This commit is contained in:
parent
31bcf867ae
commit
37d22f6878
@ -1,11 +1,26 @@
|
|||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import CrossEntropyLoss, Module
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.bloom import build_bloom_alibi_tensor_fn
|
from ..modeling.bloom import build_bloom_alibi_tensor_fn
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BloomPolicy(Policy):
|
class BloomPolicy(Policy):
|
||||||
|
|
||||||
@ -110,7 +125,46 @@ class BloomPolicy(Policy):
|
|||||||
|
|
||||||
|
|
||||||
class BloomModelPolicy(BloomPolicy):
|
class BloomModelPolicy(BloomPolicy):
|
||||||
pass
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
policy = super().module_policy()
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
|
||||||
|
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
policy[BloomModel] = ModulePolicyDescription(method_replacement={
|
||||||
|
"forward":
|
||||||
|
partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager, stage_index=stage_index)
|
||||||
|
})
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""
|
||||||
|
get pipeline layers for current stage
|
||||||
|
"""
|
||||||
|
module = self.model
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
held_layers = []
|
||||||
|
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.word_embeddings)
|
||||||
|
held_layers.append(module.word_embeddings_layernorm)
|
||||||
|
|
||||||
|
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
held_layers.extend(module.h[start_idx:end_idx])
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.ln_f)
|
||||||
|
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
'''no shared params in bloommodel'''
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class BloomForCausalLMPolicy(BloomPolicy):
|
class BloomForCausalLMPolicy(BloomPolicy):
|
||||||
@ -181,3 +235,174 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||||||
class BloomForQuestionAnsweringPolicy(BloomPolicy):
|
class BloomForQuestionAnsweringPolicy(BloomPolicy):
|
||||||
# No head sharding as the output features is only 2
|
# No head sharding as the output features is only 2
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def bloom_model_forward(
|
||||||
|
self: BloomModel,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
**deprecated_arguments,
|
||||||
|
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||||
|
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||||
|
warnings.warn(
|
||||||
|
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||||
|
" passing `position_ids`.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if len(deprecated_arguments) > 0:
|
||||||
|
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# add warnings here
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||||
|
output_hidden_states = False
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||||
|
use_cache = False
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape batch_size x num_heads x N x N
|
||||||
|
|
||||||
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
|
# case: First stage of training
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
# check input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
# initialize in the first stage and then pass to the next stage
|
||||||
|
else:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
# extra recording tensor should be generated in the first stage
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
if past_key_values[0] is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] # source_len
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
|
||||||
|
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
|
||||||
|
# causal_mask is constructed every stage and its input is passed through different stages
|
||||||
|
causal_mask = self._prepare_attn_mask(
|
||||||
|
attention_mask,
|
||||||
|
input_shape=(batch_size, seq_length),
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx])):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
alibi,
|
||||||
|
causal_mask,
|
||||||
|
layer_past,
|
||||||
|
head_mask[i],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
alibi=alibi,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + \
|
||||||
|
(outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
|
# TODO: deal with all_hidden_states, all_self_attentions, presents
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||||
|
|
||||||
|
# attention_mask is not returned ; presents = past_key_values
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# always return dict for imediate stage
|
||||||
|
return {'hidden_states': hidden_states}
|
||||||
|
@ -51,15 +51,17 @@ output_transform_fn = lambda x: x
|
|||||||
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
|
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
|
||||||
loss_fn = lambda x: x.loss
|
loss_fn = lambda x: x.loss
|
||||||
|
|
||||||
config = transformers.GPT2Config(n_layer=2,
|
config = transformers.GPT2Config(
|
||||||
n_head=4,
|
n_layer=2,
|
||||||
vocab_size=50258,
|
n_head=4,
|
||||||
attn_pdrop=0,
|
#n_embd=128,
|
||||||
embd_pdrop=0,
|
vocab_size=50258,
|
||||||
resid_pdrop=0,
|
attn_pdrop=0,
|
||||||
summary_first_dropout=0,
|
embd_pdrop=0,
|
||||||
hidden_dropout=0,
|
resid_pdrop=0,
|
||||||
problem_type="single_label_classification")
|
summary_first_dropout=0,
|
||||||
|
hidden_dropout=0,
|
||||||
|
problem_type="single_label_classification")
|
||||||
|
|
||||||
# register the following models
|
# register the following models
|
||||||
model_zoo.register(name='transformers_gpt',
|
model_zoo.register(name='transformers_gpt',
|
||||||
|
@ -0,0 +1,84 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.shardformer.policies.base_policy import Policy
|
||||||
|
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||||
|
from colossalai.testing import (
|
||||||
|
assert_hf_output_close,
|
||||||
|
clear_cache_before_run,
|
||||||
|
parameterize,
|
||||||
|
rerun_if_address_is_in_use,
|
||||||
|
spawn,
|
||||||
|
)
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||||
|
# check forward
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('enable_fused_normalization', [False])
|
||||||
|
@parameterize('enable_tensor_parallelism', [False])
|
||||||
|
@parameterize('use_lazy_init', [False])
|
||||||
|
#TODO: merge this into test_shard_bloom
|
||||||
|
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||||
|
DP_DIM, PP_DIM = 0, 1
|
||||||
|
DP_SIZE, PP_SIZE = 2, 2
|
||||||
|
RANK_TO_COORDINATE = {
|
||||||
|
0: (0, 0),
|
||||||
|
1: (0, 1),
|
||||||
|
2: (1, 0),
|
||||||
|
3: (1, 1),
|
||||||
|
}
|
||||||
|
PP_RANKS_IN_GROUP = {
|
||||||
|
0: [0, 1],
|
||||||
|
1: [0, 1],
|
||||||
|
2: [2, 3],
|
||||||
|
3: [2, 3],
|
||||||
|
}
|
||||||
|
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||||
|
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||||
|
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||||
|
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||||
|
hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda()
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
if name == 'transformers_bloom':
|
||||||
|
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||||
|
enable_tensor_parallelism, use_lazy_init)
|
||||||
|
if stage_manager.stage == 0:
|
||||||
|
attention_mask = torch.ones_like(x).cuda()
|
||||||
|
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||||
|
assert output['hidden_states'].shape == (2, 3, 64)
|
||||||
|
else:
|
||||||
|
attention_mask = torch.ones((2, 3)).cuda()
|
||||||
|
output = sharded_model(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
assert output[0].shape == (2, 3, 64)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def check_bloom(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
run_bloom_test()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def test_bloom():
|
||||||
|
spawn(check_bloom, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_bloom()
|
@ -70,6 +70,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
@parameterize('enable_fused_normalization', [True, False])
|
@parameterize('enable_fused_normalization', [True, False])
|
||||||
@parameterize('enable_tensor_parallelism', [True, False])
|
@parameterize('enable_tensor_parallelism', [True, False])
|
||||||
@parameterize('use_lazy_init', [False, True])
|
@parameterize('use_lazy_init', [False, True])
|
||||||
|
@clear_cache_before_run()
|
||||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
Loading…
Reference in New Issue
Block a user