mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-25 19:21:17 +00:00
* fix for async io * test for upgrading transformers * add ci machine * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_fp16_torch.py * Update build_on_pr.yml * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fiux * fix * fix * fix * upgrade llama * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * upgrade_bert * upgrade_bloom * [upgrade] upgrade gpt2 (#6291) * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * upgrade command * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * add explanation * [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [upgrade]Upgrade qwen2 (#6302) * upgrade qwen2 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * update_bloom * fix * add explantion * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade_sam * add the explanation * upgrade_t * fix * fix * fix * upgrade_gptj * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [upgrade]upgrade opt (#6307) * upgrade opt * fix * [upgrade]Upgrade mixtral (#6317) * upgrade mixtral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade infer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade drafter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade lazy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade mixtral --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]Upgrade vit (#6308) * fix * fix * fix rotate embedding test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]upgrade mistral (#6296) * upgrade mistral * fix * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix falcon * fix * Update test_shard_deepseek.py * Update build_on_pr.yml * Update requirements.txt * fix (#6327) * fix (#6328) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update bert.py * fix (#6329) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hanks <hangxu0304@gmail.com> Co-authored-by: wangbluo <2538539015@qq.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
# Same as the SamVisionAttention forward method in the v4.51.3 transformers
|
|
def forward_fn():
    """Build a replacement ``forward`` for a SAM vision attention module.

    The returned function takes the attention module as ``self`` and reads
    ``qkv``, ``proj``, ``num_attention_heads``, ``scale``, ``use_rel_pos``,
    ``dropout`` and ``training`` from it (plus ``rel_pos_h``, ``rel_pos_w``
    and ``get_decomposed_rel_pos`` when ``use_rel_pos`` is truthy).

    Returns:
        A ``forward(self, hidden_states, output_attentions=False)`` callable.
    """

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        """Run windowless multi-head self-attention over a 4-D feature map.

        Args:
            hidden_states: Tensor unpacked as (batch_size, height, width, channels).
            output_attentions: When truthy, the post-softmax attention weights
                (before dropout) are returned as the second tuple element.

        Returns:
            A 2-tuple ``(attn_output, attn_weights_or_None)``. The return
            annotation is left as in the original signature even though a
            tuple is produced.
        """
        batch_size, height, width, _ = hidden_states.shape
        num_heads = self.num_attention_heads
        num_tokens = height * width

        # Fused projection laid out as (3, batch_size, num_heads, num_tokens, channel).
        fused = self.qkv(hidden_states)
        fused = fused.reshape(batch_size, num_tokens, 3, num_heads, -1)
        fused = fused.permute(2, 0, 3, 1, 4)

        # Fold heads into the batch axis: q, k, v are each
        # (batch_size * num_heads, num_tokens, channel).
        folded = fused.reshape(3, batch_size * num_heads, num_tokens, -1)
        query, key, value = torch.unbind(folded, 0)

        attn_weights = torch.matmul(query * self.scale, key.transpose(-2, -1))

        if self.use_rel_pos:
            spatial_size = (height, width)
            rel_pos_bias = self.get_decomposed_rel_pos(
                query, self.rel_pos_h, self.rel_pos_w, spatial_size, spatial_size
            )
            attn_weights = attn_weights + rel_pos_bias.reshape_as(attn_weights)

        # Softmax is evaluated in float32 and then cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Un-fold the heads and merge them back into the channel axis.
        context = torch.matmul(attn_probs, value)
        context = context.reshape(batch_size, num_heads, height, width, -1)
        context = context.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(context)

        return attn_output, (attn_weights if output_attentions else None)

    return forward
|