mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent 2f615a49fd
commit c6291be1b1
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.distributed
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
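Note: the only change in this hunk is the removal of import torch.nn.functional as F, presumably because the alias F is no longer referenced in the module. Below is a minimal sketch of reproducing that kind of cleanup with autoflake's Python API; whether this repository's pre-commit configuration actually uses autoflake (rather than, say, ruff) is an assumption here, and the snippet is illustrative only.

import autoflake  # assumed hook; install with: pip install autoflake

src = (
    "import torch\n"
    "import torch.nn.functional as F\n"  # bound name F is never used below
    "x = torch.zeros(2)\n"
)

# remove_all_unused_imports also strips unused third-party imports, not just stdlib ones.
cleaned = autoflake.fix_code(src, remove_all_unused_imports=True)
print(cleaned)  # the F import is dropped; the used "import torch" line stays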
@@ -94,7 +93,6 @@ class LlamaPipelineForwards:
             batch_size, seq_length = input_shape
             device = hidden_states.device
 
-
         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
@@ -113,7 +111,6 @@ class LlamaPipelineForwards:
                 raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
-
         seq_length_with_past = seq_length + past_seen_tokens
 
         if output_attentions:
@@ -143,7 +140,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values
+            )
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
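The substantive change above is purely mechanical: the overlong _update_causal_mask call is wrapped so the argument list sits on its own indented line. A minimal sketch of reproducing that wrapping with Black's Python API follows; treating the formatter as Black, with its default 88-character limit, is an assumption, since the pre-commit configuration itself is not part of this diff.

import black  # assumed formatter; install with: pip install black

# A standalone statement long enough to overflow the assumed 88-character limit.
src = (
    "attn_kwargs = self._update_causal_mask("
    "attention_mask, hidden_states, cache_position, past_key_values)\n"
)

# Black splits the call at the outer parentheses: the arguments move to one indented
# line and the closing parenthesis gets a line of its own, as in the hunk above.
print(black.format_str(src, mode=black.Mode()))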
@@ -207,7 +206,7 @@ class LlamaPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -218,7 +217,7 @@ class LlamaPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )
             hidden_states = layer_outputs[0]
 
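These two hunks only add a trailing comma after the final argument of each multi-line call. That is the usual behaviour of a Black-style formatter: when a call cannot fit on one line, it is exploded to one argument per line and a comma is appended after the last argument (the "magic trailing comma"), which also keeps later one-argument additions to a single changed line. A small sketch, again assuming Black is the formatter in use:

import black

# Too long to fit on one line even with the arguments grouped, so Black will
# emit one keyword argument per line and add a trailing comma after the last.
src = (
    "layer_outputs = decoder_layer(hidden_states, attention_mask=attn_kwargs, "
    "position_ids=position_ids, past_key_value=past_key_values, "
    "output_attentions=output_attentions, use_cache=use_cache, "
    "cache_position=cache_position, position_embeddings=position_embeddings)\n"
)
print(black.format_str(src, mode=black.Mode()))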
@@ -33,11 +33,7 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaModel,
-        )
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
 
         policy = {}
         embedding_cls = None
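In the second file the parenthesised import block is collapsed onto a single line once it fits within the line-length limit. Below is a minimal sketch of the same transformation using isort's Python API; the choice of isort and the 120-character limit are assumptions, since the repository's actual pre-commit configuration is not shown in this diff.

import isort  # assumed hook; install with: pip install isort

src = (
    "from transformers.models.llama.modeling_llama import (\n"
    "    LlamaAttention,\n"
    "    LlamaDecoderLayer,\n"
    "    LlamaModel,\n"
    ")\n"
)

# With a generous line length the three names fit on one line, so isort joins them,
# matching the + side of the hunk above (indentation aside).
print(isort.code(src, line_length=120))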