diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 3ea4db9e2..aa39bf40c 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -128,7 +128,7 @@ class OPTPipelineForwards:
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values_length + seq_length
         # embed positions
-        if self.decoder._use_flash_attention_2:
+        if self.decoder.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
             attention_mask = (
@@ -542,6 +542,9 @@ class OPTPipelineForwards:
 def get_opt_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.opt.modeling_opt import OPTAttention
 
+    def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
+        return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self: OPTAttention,
         hidden_states: torch.Tensor,
@@ -568,30 +571,20 @@ def get_opt_flash_attention_forward(shard_config: ShardConfig):
             value_states = past_key_value[1]
         elif is_cross_attention:
             # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+            key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
         elif past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
         else:
             # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
 
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
+        query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim)
 
         dropout_p = self.dropout if self.training else 0.0
         attn_output = ColoAttention.attention(
@@ -630,6 +623,8 @@ def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py
index 2da94a4fc..5ffc227f9 100644
--- a/tests/kit/model_zoo/transformers/opt.py
+++ b/tests/kit/model_zoo/transformers/opt.py
@@ -53,6 +53,7 @@ config = transformers.OPTConfig(
     num_hidden_layers=2,
     num_attention_heads=4,
     dropout=0,
+    attn_implementation="eager",
 )
 
 # register the following models
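For reference, the module-level `_shape` helper introduced above performs the same reshape the bound `OPTAttention._shape` method used to, but with `num_heads` and `head_dim` passed in explicitly so the wrapper no longer depends on that method being present. A minimal, self-contained sketch with made-up tensor sizes (not part of the patch):

```python
# Illustrative only: exercises the same reshape the new _shape helper performs.
import torch


def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
    # [bsz, seq_len, num_heads * head_dim] -> [bsz, num_heads, seq_len, head_dim]
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()


bsz, seq_len, num_heads, head_dim = 2, 8, 4, 16
hidden = torch.randn(bsz, seq_len, num_heads * head_dim)
out = _shape(hidden, -1, bsz, num_heads, head_dim)  # seq_len=-1 lets view() infer it
assert out.shape == (bsz, num_heads, seq_len, head_dim)
```

Taking the head dimensions as arguments avoids relying on `OPTAttention` still exposing a bound `_shape` method, which newer `transformers` releases appear to drop.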