[example] add gpt2 benchmark example script (#5295)

* benchmark gpt2

* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)

* [workflow] fixed build CI (#5240)

* [ci] fixed booster test (#5251)

* [ci] fixed ddp test (#5254)

* fix typo in applications/ColossalEval/README.md (#5250)

* [ci] fix shardformer tests (#5255)

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* [doc] fix doc typo (#5256)

* [doc] fix annotation display

* [doc] fix llama2 doc

* [hotfix]: add pp sanity check and fix mbs arg (#5268)

* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check

* [workflow] fixed incomplete bash command (#5272)

* [workflow] fixed oom tests (#5275)

* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* [shardformer] HybridParallelPlugin supports gradient accumulation (#5246)

* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

* fix auto loading gpt2 tokenizer (#5279)

* [doc] add llama2-13B display (#5285)

* Update README.md

* fix 13b typo

* fix llama pretrain (#5287)

* Update shardformer.py

---------

Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
commit 29695cf70c (parent 4b8312c08e)
Author: flybird11111
Date: 2024-03-04 16:18:13 +08:00
Committed via GitHub
22 changed files with 421 additions and 48 deletions

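The headline change is a GPT-2 benchmark example script. For orientation only, here is a minimal sketch of the shape such a benchmark takes on top of ColossalAI's Booster API with HybridParallelPlugin; it is not the script added by this commit, and the plugin arguments, synthetic data, and launch call are assumptions.

# Hedged sketch of a GPT-2 throughput benchmark on ColossalAI's Booster API.
# This is NOT the script added by this commit; plugin arguments, the synthetic
# dataset, and the launch call are illustrative assumptions.
import time

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import GPT2Config, GPT2LMHeadModel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def main():
    colossalai.launch_from_torch(config={})  # newer releases drop the config argument

    # assumption: 2-way tensor parallelism, no pipeline, bf16 mixed precision
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16")
    booster = Booster(plugin=plugin)

    config = GPT2Config(n_layer=12, n_head=12, n_embd=768)
    model = GPT2LMHeadModel(config)
    optimizer = AdamW(model.parameters(), lr=1e-4)

    # synthetic data: random token ids, labels == inputs (GPT-2 shifts them internally)
    seq_len, batch_size, steps = 1024, 4, 20
    ids = torch.randint(0, config.vocab_size, (batch_size * steps, seq_len))
    loader = DataLoader(TensorDataset(ids), batch_size=batch_size)

    model, optimizer, _, loader, _ = booster.boost(model, optimizer, dataloader=loader)

    model.train()
    start, tokens = time.time(), 0
    for (batch,) in loader:
        batch = batch.cuda()
        outputs = model(input_ids=batch, labels=batch)
        booster.backward(outputs.loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
        tokens += batch.numel()
    print(f"throughput: {tokens / (time.time() - start):.1f} tokens/s per process")


if __name__ == "__main__":
    main()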

@@ -24,6 +24,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
+from ..layer import cross_entropy_1d
+
 
 class GPT2PipelineForwards:
     """
@@ -326,7 +328,15 @@ class GPT2PipelineForwards:
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+                shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+                shift_labels = shift_labels.view(-1)
+                if shard_config.enable_tensor_parallelism:
+                    loss = cross_entropy_1d(
+                        shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                    )
+                else:
+                    loss = loss_fct(shift_logits, shift_labels)
+
             if not return_dict:
                 output = (lm_logits,) + outputs[1:]
                 return ((loss,) + output) if loss is not None else output
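
The hunk above is the core of the change: when the language-model head is tensor-parallel, each rank holds only a slice of the vocabulary dimension of the logits, so the loss is routed through cross_entropy_1d instead of a plain CrossEntropyLoss. A single-process sketch of the math such a vocab-parallel cross entropy computes (the shard loop stands in for per-rank work and the max/sum over shards for all-reduces; names are illustrative, not ColossalAI's kernel):

# Single-process sketch of vocab-parallel cross entropy: the shard loop stands in
# for per-rank work, and the max/sum over shards stand in for all-reduce ops.
# Function and variable names are illustrative, not ColossalAI's cross_entropy_1d.
import torch
import torch.nn.functional as F


def vocab_parallel_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, num_shards: int) -> torch.Tensor:
    n, vocab = logits.shape
    shard_size = vocab // num_shards
    shards = logits.split(shard_size, dim=-1)  # each "rank" owns one vocab slice

    # 1) global max over the vocab dim (an all-reduce MAX in the distributed kernel)
    global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values

    # 2) global softmax denominator (an all-reduce SUM)
    sum_exp = sum(torch.exp(s - global_max[:, None]).sum(dim=-1) for s in shards)

    # 3) logit of the target token: only the shard that owns the label contributes
    target_logit = torch.zeros(n, dtype=logits.dtype)
    for rank, s in enumerate(shards):
        lo = rank * shard_size
        mask = (labels >= lo) & (labels < lo + shard_size)
        local = (labels - lo).clamp(0, shard_size - 1)
        target_logit += torch.where(mask, s.gather(1, local[:, None]).squeeze(1), torch.zeros(n))

    # per-sample loss: log(sum_exp) - (x_target - max); mean over the batch
    return (torch.log(sum_exp) - (target_logit - global_max)).mean()


# sanity check against the dense reference
torch.manual_seed(0)
logits = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
assert torch.allclose(
    vocab_parallel_cross_entropy(logits, labels, num_shards=4),
    F.cross_entropy(logits, labels),
    atol=1e-5,
)
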
@@ -1006,3 +1016,84 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import GPT2LMHeadModel
+
+    def forward(
+        self: GPT2LMHeadModel,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+            shift_labels = shift_labels.view(-1)
+            if shard_config.enable_tensor_parallelism:
+                loss = cross_entropy_1d(
+                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                )
+            else:
+                loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
+        )
+
+    return forward
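
A hedged usage sketch of the factory above: in ColossalAI the returned closure is normally swapped in for GPT2LMHeadModel.forward by the GPT-2 ShardFormer policy, but it can be bound by hand for illustration. The SimpleNamespace below only mimics the two ShardConfig fields the forward reads, and the import path assumes the module shown in this diff (colossalai/shardformer/modeling/gpt2.py).

# Hedged usage sketch: bind the closure returned above onto a plain GPT2LMHeadModel.
# In ColossalAI this substitution is normally done by the GPT-2 ShardFormer policy;
# the SimpleNamespace only mimics the two ShardConfig fields the forward reads.
from types import MethodType, SimpleNamespace

import torch
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.shardformer.modeling.gpt2 import get_lm_forward_with_dist_cross_entropy

# stand-in for a real ShardConfig: tensor parallelism disabled, so the plain
# CrossEntropyLoss branch is exercised and no process group is needed
shard_config = SimpleNamespace(enable_tensor_parallelism=False, tensor_parallel_process_group=None)

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
model.forward = MethodType(get_lm_forward_with_dist_cross_entropy(shard_config), model)

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)  # falls through to the plain loss branch
print(out.loss)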