Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
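Most of the churn in the diff below is mechanical string-quote normalization from re-running the updated hooks. As a hedged illustration (assuming the hooks include black, which the commit message does not state explicitly), black's quote normalization alone reproduces the single-quote to double-quote changes seen in the hunks:

# Hedged sketch: black rewrites single-quoted strings to double quotes,
# which is exactly the kind of change this commit applies across the file.
import black

src = "scores = torch.einsum('bthd,bshd->bhts', q * scale, k)\n"
print(black.format_str(src, mode=black.Mode()), end="")
# prints: scores = torch.einsum("bthd,bshd->bhts", q * scale, k)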
@@ -23,30 +23,30 @@ def attention_ref(q, k, v, attn_mask=None, causal=False):
     seqlen_q, seqlen_k = q.shape[1], k.shape[1]
     d = q.shape[-1]
     scale = 1.0 / math.sqrt(d)
-    scores = torch.einsum('bthd,bshd->bhts', q * scale, k)
+    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
 
     if attn_mask is not None:
-        scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
     if causal:
         causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        scores.masked_fill_(causal_mask, float("-inf"))
     attention = torch.softmax(scores, dim=-1)
 
-    output = torch.einsum('bhts,bshd->bthd', attention, v)
+    output = torch.einsum("bhts,bshd->bthd", attention, v)
     output = rearrange(output, "b s h d -> b s (h d)")
 
     # Modify the data at the positions of the mask to 0
     if attn_mask is not None:
-        output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)
+        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)
 
     return output.to(dtype=dtype_og)
 
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_gpt(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
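The hunk above only re-quotes strings inside the attention_ref helper; the computation itself (scaled dot-product scores, optional padding and causal masks, softmax, weighted sum over the values) is unchanged. A minimal self-contained sketch of that math, using illustrative shapes and random inputs rather than the test's fixtures:

# Standalone sketch of the reference attention computed by attention_ref above.
# Shapes are illustrative (batch=2, seqlen=4, heads=2, head_dim=8), not the test's.
import math

import torch
from einops import rearrange

B, S, H, D_HEAD = 2, 4, 2, 8
q = torch.randn(B, S, H, D_HEAD)
k = torch.randn(B, S, H, D_HEAD)
v = torch.randn(B, S, H, D_HEAD)

scale = 1.0 / math.sqrt(D_HEAD)
# (b, t, h, d) x (b, s, h, d) -> (b, h, t, s): scaled dot-product scores
scores = torch.einsum("bthd,bshd->bhts", q * scale, k)

# causal mask: query position t may only attend to key positions s <= t
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), 1)
scores.masked_fill_(causal_mask, float("-inf"))

attention = torch.softmax(scores, dim=-1)
output = torch.einsum("bhts,bshd->bthd", attention, v)  # weighted sum over values
output = rearrange(output, "b s h d -> b s (h d)")      # merge heads into one dim
print(output.shape)  # torch.Size([2, 4, 16])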
@@ -78,9 +78,9 @@ def test_attention_gpt(proj_shape, dtype, dropout):
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_bert(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
@@ -111,9 +111,9 @@ def test_attention_bert(proj_shape, dtype, dropout):
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_no_mask(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
@@ -141,9 +141,9 @@ def test_attention_no_mask(proj_shape, dtype, dropout):
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_cross_attention(proj_shape, dtype, dropout):
     (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
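The last three hunks apply the same quote normalization to the @parameterize decorators of the GPT, BERT, no-mask, and cross-attention tests; the swept values themselves are untouched. As a hypothetical illustration of how such a parameter sweep expands (not ColossalAI's actual decorator implementation, and the dtype list below is a stand-in for the file's DTYPE constant), each combination of the listed values yields one test invocation:

# Hypothetical expansion of a @parameterize-style sweep; the proj_shape value
# mirrors the GPT/BERT tests above, and dtypes is a placeholder for DTYPE.
from itertools import product

import torch

proj_shapes = [(6, 8, 4, 16)]              # (B, S, H, D_HEAD)
dtypes = [torch.float16, torch.bfloat16]   # placeholder for DTYPE
dropouts = [0.0]

def run_attention_case(proj_shape, dtype, dropout):
    B, S, H, D_HEAD = proj_shape
    D = H * D_HEAD
    x = torch.randn(B, S, D, dtype=dtype)  # stand-in for the test's projected input
    print(f"proj_shape={proj_shape} dtype={dtype} dropout={dropout} x={tuple(x.shape)}")

for proj_shape, dtype, dropout in product(proj_shapes, dtypes, dropouts):
    run_attention_case(proj_shape, dtype, dropout)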