diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
index 156e60333..2b7b999d4 100644
--- a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
+++ b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
@@ -12,15 +12,10 @@ def check_selfattention():
     BATCH = 4
     HIDDEN_SIZE = 16
 
-    layer = TransformerSelfAttentionRing(
-        16,
-        8,
-        8,
-        0.1
-    )
+    layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
     layer = layer.to(get_current_device())
 
     hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
-    attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
-        get_current_device())
+    attention_mask = torch.randint(low=0, high=2,
+                                   size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
     out = layer(hidden_states, attention_mask)