[test] fixed rerun_on_exception and adapted test cases (#487)

Author: Frank Lee
Date: 2022-03-25 17:25:12 +08:00
Committed by: GitHub
Parent: 4d322b79da
Commit: 3601b2bad0

31 changed files with 143 additions and 135 deletions


@@ -7,14 +7,10 @@ import pytest
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from colossalai.testing import rerun_on_exception
 from functools import partial

-CONFIG = dict(
-    parallel=dict(
-        tensor=dict(size=4, mode='sequence')
-    )
-)
+CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))


 def check_ring_qk(rank, world_size):
@@ -26,14 +22,14 @@ def check_ring_qk(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
-    q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
-    k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
+    k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
-    sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     q.requires_grad = True
@@ -53,7 +49,7 @@ def check_ring_qk(rank, world_size):
     sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)

     # check master and distributed attention scores
-    sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)

     # run master backward
@@ -61,11 +57,11 @@ def check_ring_qk(rank, world_size):
     a.mean().backward()

     # run distributed backward
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_a, partial_master_a_grad)

     # check master and distributed grads
-    partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
         'attention score cannot match'
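
Both check functions follow the same slice-and-compare pattern: every rank holds the full "master" tensors (synchronized via broadcast from rank 0), runs the distributed kernel on its own sub-sequence, and asserts the result matches the slice `[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]` of the single-device computation. A minimal single-process sketch of that pattern, with a plain matmul standing in for the RingQK kernel and illustrative sizes rather than the test's exact values:

    import torch

    batch_size, num_heads, seq_length, attention_head_size = 2, 4, 8, 16
    world_size = 4
    sub_seq_length = seq_length // world_size

    q = torch.rand(batch_size * num_heads, seq_length, attention_head_size)
    k = torch.rand(batch_size * num_heads, seq_length, attention_head_size)
    a = torch.matmul(q, k.transpose(1, 2))    # full attention scores QK^T

    for rank in range(world_size):
        # the rows of the score matrix this rank owns under sequence parallelism
        sub_q = q[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
        sub_a = torch.matmul(sub_q, k.transpose(1, 2))    # what the ring kernel should produce
        sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
        assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
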
@@ -79,14 +75,14 @@ def check_ring_av(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
-    a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
-    v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
+    v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
-    sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     a.requires_grad = True
@@ -108,7 +104,7 @@ def check_ring_av(rank, world_size):
     # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')

     # check master and distributed output
-    sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)

     # # run master backward
@@ -116,23 +112,17 @@ def check_ring_av(rank, world_size):
     out.mean().backward()

     # # run distributed backward
-    partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_out, partial_master_out_grad)

     # # check master and distributed grads
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
         'attention output cannot match'


 def run_test(rank, world_size):
-    colossalai.launch(
-        rank=rank,
-        world_size=world_size,
-        config=CONFIG,
-        host='localhost',
-        port=29500
-    )
+    colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500)

     # check_ring_qk(rank, world_size)
     check_ring_av(rank, world_size)
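
`run_test` is the per-process entry point: `colossalai.launch` joins each worker to the same rendezvous at `localhost:29500`, which is exactly why a stale process still holding that port triggers the `Address already in use` failure this commit guards against. An entry point like this is conventionally driven by `torch.multiprocessing`, roughly as sketched below (a generic harness for illustration; the file's actual remainder is not shown in this diff):

    import torch.multiprocessing as mp
    from functools import partial

    def main():
        world_size = 4
        # mp.spawn calls run_func(rank) for rank in 0..world_size-1;
        # world_size is pre-bound so run_test receives both arguments
        run_func = partial(run_test, world_size=world_size)
        mp.spawn(run_func, nprocs=world_size)
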
@@ -142,6 +132,7 @@ def run_test(rank, world_size):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)
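
The decorator added in the hunk above is the fix named in the commit title: when a worker fails, `mp.spawn` surfaces the error on the parent as a `torch.multiprocessing.ProcessRaisedException` whose message embeds the worker's traceback, so matching `.*Address already in use.*` against that message makes the test rerun only on port clashes rather than on genuine failures. ColossalAI's real implementation lives in `colossalai.testing`; the sketch below is a from-scratch illustration of the idea, and the `max_try` parameter is an assumed name, not the library's API:

    import functools
    import re

    def rerun_on_exception(exception_type=Exception, pattern=None, max_try=5):
        # Retry the wrapped callable when it raises `exception_type` with a
        # message matching `pattern`; re-raise any other failure immediately.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for attempt in range(max_try):
                    try:
                        return func(*args, **kwargs)
                    except exception_type as e:
                        if pattern is not None and re.search(pattern, str(e)) is None:
                            raise    # unrelated error: do not mask it with a rerun
                        if attempt == max_try - 1:
                            raise    # retries exhausted
            return wrapper
        return decorator
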