[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committer: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
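The repo-wide single-quote to double-quote churn in the diff below is characteristic of a formatter (e.g. Black) running as a pre-commit hook over every file, which matches the commit message above. As a minimal sketch of reproducing such a run programmatically, assuming the standard pre-commit CLI is installed and configured by the repository's .pre-commit-config.yaml (not shown here):

import subprocess

# Run every configured pre-commit hook against all tracked files, not just staged ones.
# pre-commit exits non-zero when hooks modify files, so check=False avoids raising
# on a "successful" reformat.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)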


@@ -23,30 +23,30 @@ def attention_ref(q, k, v, attn_mask=None, causal=False):
     seqlen_q, seqlen_k = q.shape[1], k.shape[1]
     d = q.shape[-1]
     scale = 1.0 / math.sqrt(d)
-    scores = torch.einsum('bthd,bshd->bhts', q * scale, k)
+    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
     if attn_mask is not None:
-        scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
     if causal:
         causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        scores.masked_fill_(causal_mask, float("-inf"))
     attention = torch.softmax(scores, dim=-1)
-    output = torch.einsum('bhts,bshd->bthd', attention, v)
+    output = torch.einsum("bhts,bshd->bthd", attention, v)
     output = rearrange(output, "b s h d -> b s (h d)")
     # Modify the data at the positions of the mask to 0
     if attn_mask is not None:
-        output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)
+        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)
     return output.to(dtype=dtype_og)


 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_gpt(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
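For context, attention_ref in the hunk above is the plain PyTorch reference: softmax(q·kᵀ/√d) with optional padding and causal masks, followed by multiplication with v and a merge of the heads into the last dimension. A minimal usage sketch, assuming attention_ref and the test file's imports (torch, math, einops.rearrange) are in scope; the shapes mirror the (6, 8, 4, 16) proj_shape used by these tests:

import torch

# Shapes follow the einsum patterns above:
# q, k, v: (batch, seq_len, n_heads, head_dim); attn_mask: (batch, seq_len), True = attend.
B, S, H, D_HEAD = 6, 8, 4, 16
q = torch.randn(B, S, H, D_HEAD)
k = torch.randn(B, S, H, D_HEAD)
v = torch.randn(B, S, H, D_HEAD)
attn_mask = torch.ones(B, S, dtype=torch.bool)

# Heads are merged by the final rearrange: output is (batch, seq_len, n_heads * head_dim).
out = attention_ref(q, k, v, attn_mask=attn_mask, causal=True)
assert out.shape == (B, S, H * D_HEAD)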
@@ -78,9 +78,9 @@ def test_attention_gpt(proj_shape, dtype, dropout):
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_bert(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
@@ -111,9 +111,9 @@ def test_attention_bert(proj_shape, dtype, dropout):
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_attention_no_mask(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
@@ -141,9 +141,9 @@ def test_attention_no_mask(proj_shape, dtype, dropout):
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
-@parameterize('dtype', DTYPE)
-@parameterize('dropout', [0.0])
+@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
+@parameterize("dtype", DTYPE)
+@parameterize("dropout", [0.0])
 def test_cross_attention(proj_shape, dtype, dropout):
     (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
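The stacked @parameterize decorators in these hunks come from ColossalAI's testing utilities; they run the decorated test once per listed value (and over the Cartesian product when stacked), similar in spirit to pytest.mark.parametrize. A hypothetical sketch of that general pattern, not the library's actual implementation:

from functools import wraps

def parameterize_sketch(arg_name, values):
    # Hypothetical stand-in for a @parameterize-style decorator: call the wrapped
    # function once per value, passing it as a keyword argument. Stacking these
    # decorators therefore enumerates the Cartesian product of all listed values.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator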