Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline] add pipeline forward for variants of gpt2 (#4238)
* add forward for GPT2LMHeadModel
* add test for gpt_lm
* reorganize the get_held_layers method
* reorganize forward replacement
* add forward for GPT2ForTokenClassification
* add forward for GPT2ForSequenceClassification
* fix test_shard_gpt2.py
* add GPT2DoubleHeadsModel & fix bugs
* add id checking in get_shared_params
Committed by: Hongxin Liu
Parent: 7e4de520e1
Commit: a14d352088
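The forwards added by this commit all follow the stage-conditional pattern that the test below exercises: the first pipeline stage consumes input_ids, intermediate stages consume the hidden_states handed over by the previous stage, and only the last stage runs the task head and computes a loss. The following is a minimal sketch of that pattern; the embed/blocks/head submodule names are hypothetical placeholders, not ColossalAI's actual forward replacements:

import torch.nn.functional as F


def pipeline_forward_sketch(model, stage_manager, input_ids=None, hidden_states=None, labels=None):
    # Only the first stage embeds token ids; later stages are fed activations.
    if stage_manager.is_first_stage():
        hidden_states = model.embed(input_ids)    # hypothetical embedding submodule
    # Each stage runs only the transformer blocks it holds (cf. get_held_layers).
    for block in model.blocks:                    # hypothetical per-stage block list
        hidden_states = block(hidden_states)
    # Intermediate stages pass activations on to the next stage.
    if not stage_manager.is_last_stage():
        return {'hidden_states': hidden_states}
    # The last stage applies the task head (LM, token/sequence classification, double heads).
    logits = model.head(hidden_states)            # hypothetical head submodule
    loss = None
    if labels is not None:
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    return {'logits': logits, 'loss': loss}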
@@ -5,15 +5,9 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
+from tests.test_shardformer.test_model._utils import build_pipeline_model


 def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
     pass


-@parameterize('enable_fused_normalization', [False])
 @parameterize('enable_tensor_parallelism', [False])
+@parameterize('enable_fused_normalization', [False])
 @parameterize('use_lazy_init', [False])
+#TODO: merge this into test_shard_gpt2
 def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
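Aside: the stacked parameterize decorators above expand the test over every combination of the listed values (here a single combination, since each list holds only [False]). A rough, hypothetical stand-in for how such a decorator can work, not ColossalAI's actual implementation:

from functools import wraps


def parameterize_sketch(arg_name, values):
    # Call the wrapped function once per value, passed as a keyword argument.
    # Stacking several of these decorators yields the Cartesian product.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


@parameterize_sketch('enable_fused_normalization', [False])
@parameterize_sketch('use_lazy_init', [False, True])
def demo(enable_fused_normalization, use_lazy_init):
    print(enable_fused_normalization, use_lazy_init)


demo()    # runs the body twice: (False, False) and (False, True)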
@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name != "transformers_gpt":
-            continue
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
         inputs = data_gen_fn()
         inputs = {k: v.cuda() for k, v in inputs.items()}
         input_ids, _ = inputs['input_ids'], inputs['attention_mask']
         batch_size, seq_len = input_ids.shape
         hidden_size = 768
         hidden_state_shape = (batch_size, seq_len, hidden_size)

-        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                        enable_tensor_parallelism, use_lazy_init)
-        org_model.train()
-        org_output = org_model(**inputs)
-        hidden_state_shape = org_output['last_hidden_state'].shape
-
-        if stage_manager.is_first_stage():
-            output = sharded_model(**inputs)
-            assert output['hidden_states'].shape == hidden_state_shape
-        else:
-            attention_mask = inputs['attention_mask']
+        if not stage_manager.is_first_stage():
+            # change inputs if not the first stage
             hidden_states = torch.zeros(*hidden_state_shape).cuda()
-            output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
-            if stage_manager.is_last_stage():
-                assert output['last_hidden_state'].shape == hidden_state_shape
-            else:
-                assert output['hidden_states'].shape == hidden_state_shape
+            inputs['input_ids'] = None
+            inputs['hidden_states'] = hidden_states
+
+        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                enable_tensor_parallelism, use_lazy_init)
+        sharded_model.train()
+        output = sharded_model(**inputs)
+        if stage_manager.is_last_stage():
+            if name != 'transformers_gpt':
+                assert output.loss is not None
+        else:
+            assert output['hidden_states'].shape == hidden_state_shape

     torch.cuda.empty_cache()