support qwen model

This commit is contained in:
Wang Binluo
2024-04-09 11:50:35 +08:00
committed by アマデウス
parent 32e642bf40
commit 4c69e2dc91
3 changed files with 23 additions and 38 deletions

View File

@@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
)-> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -82,14 +82,18 @@ class Qwen2PipelineForwards:
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
@@ -123,18 +127,11 @@ class Qwen2PipelineForwards:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@@ -148,20 +145,14 @@ class Qwen2PipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -315,7 +306,7 @@ class Qwen2PipelineForwards:
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_sequence_classification_forward(
self: Qwen2ForSequenceClassification,