Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)
[shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection (see the sketch below)
* skip tests for incompatible models
* fix test for gemini and low_level_zero plugin
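"Forward injection" here means replacing the bound forward method of a Hugging Face Blip2 attention module with the function returned by forward_fn() in the file below. A minimal sketch of that pattern, with assumed names (attn_module stands in for an attention submodule of the model being sharded; the real substitution is driven by the shardformer policy, not done by hand like this):

    from types import MethodType

    # Build the replacement forward and bind it to the module instance.
    # Illustrative only: in ColossalAI this rebinding is performed by the
    # Blip2 shardformer policy rather than by user code.
    new_forward = forward_fn()
    attn_module.forward = MethodType(new_forward, attn_module)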
colossalai/shardformer/modeling/blip2.py (new file, 60 lines)
@@ -0,0 +1,60 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def forward_fn():

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from original code, which is:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        # to:
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward
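The only functional change relative to the upstream Hugging Face implementation is the reshape of mixed_qkv. A plausible reading, not stated in the diff itself: when shardformer applies tensor parallelism, the qkv projection and self.num_heads are sharded across ranks while hidden_states keeps the full embed_dim, so embed_dim // self.num_heads no longer equals the per-head size; inferring the last dimension with -1 gives the right shape in both the sharded and unsharded cases. A toy sketch with assumed sizes:

    import torch

    # Assumed numbers, not taken from the diff: full hidden size 768, 12 heads,
    # tensor-parallel world size 2.
    bsz, tgt_len, embed_dim = 2, 5, 768          # embed_dim comes from the (unsharded) hidden_states
    tp_size = 2
    num_heads = 12 // tp_size                    # per-rank head count after sharding: 6
    local_qkv = torch.randn(bsz, tgt_len, 3 * embed_dim // tp_size)  # per-rank qkv output: 3 * 384

    # The original formula mixes the full embed_dim with the sharded head count:
    wrong_head_dim = embed_dim // num_heads      # 128, but the real per-head size is 64

    # Letting reshape infer the last dimension recovers the correct per-rank head_dim:
    mixed_qkv = local_qkv.reshape(bsz, tgt_len, 3, num_heads, -1)
    print(mixed_qkv.shape)                       # torch.Size([2, 5, 3, 6, 64])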