[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
This commit is contained in:
Hongxin Liu
2025-02-14 14:48:54 +08:00
committed by GitHub
parent ec73f1b5e2
commit 014837e725
21 changed files with 478 additions and 91 deletions

View File

@@ -4,9 +4,10 @@ import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
@@ -16,6 +17,7 @@ from colossalai.moe._operation import (
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -167,6 +169,9 @@ def deepseek_v3_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
hidden_states_internal: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -203,8 +208,11 @@ def deepseek_v3_model_forward(
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if stage_manager is None or stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
else:
inputs_embeds = hidden_states_internal
if self._use_flash_attention_2:
# 2d mask is passed through the layers
@@ -226,7 +234,11 @@ def deepseek_v3_model_forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers):
if stage_index is not None:
start_idx, end_idx = stage_index
else:
start_idx, end_idx = 0, len(self.layers)
for i, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -258,7 +270,8 @@ def deepseek_v3_model_forward(
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if stage_manager is None or stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -267,6 +280,10 @@ def deepseek_v3_model_forward(
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if stage_manager is not None and not stage_manager.is_last_stage():
return {
"hidden_states_internal": hidden_states,
}
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@@ -275,3 +292,94 @@ def deepseek_v3_model_forward(
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def deepseek_v3_for_causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
hidden_states_internal: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
>>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = deepseek_v3_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
stage_index=stage_index,
hidden_states_internal=hidden_states_internal,
)
if stage_manager is not None and not stage_manager.is_last_stage():
return outputs
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)