[test] fixed rerun_on_exception and adapted test cases (#487)
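The change revolves around the rerun_on_exception decorator exported by colossalai.testing: it re-runs a test when the raised exception matches a given type and message pattern, which makes distributed tests resilient to the flaky "Address already in use" error from a rendezvous port that has not been released yet. The snippet below is only a minimal sketch of such a retry decorator, assuming a bounded retry loop and a hypothetical max_try parameter; it is not the actual colossalai.testing implementation.

# Illustrative sketch only -- not the real colossalai.testing implementation.
# Re-runs the wrapped function when it raises `exception_type` whose message
# matches `pattern`, up to `max_try` attempts (hypothetical parameter name).
import functools
import re


def rerun_on_exception_sketch(exception_type=Exception, pattern=None, max_try=5):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    # only retry if the message matches the given pattern
                    if pattern is not None and not re.match(pattern, str(e)):
                        raise
                    if attempt == max_try - 1:
                        raise
        return wrapper
    return decorator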
@@ -7,14 +7,10 @@ import pytest
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from colossalai.testing import rerun_on_exception
 from functools import partial
 
 
-CONFIG = dict(
-    parallel=dict(
-        tensor=dict(size=4, mode='sequence')
-    )
-)
+CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
 
 
 def check_ring_qk(rank, world_size):
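The reformatting of CONFIG is purely cosmetic; the one-liner builds exactly the same nested dict as the old multi-line spelling, which a quick standalone check confirms:

# Sanity check: both spellings of CONFIG construct an identical nested dict.
old_style = dict(
    parallel=dict(
        tensor=dict(size=4, mode='sequence')
    )
)
new_style = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
assert old_style == new_style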
@@ -26,14 +22,14 @@ def check_ring_qk(rank, world_size):
     sub_seq_length = seq_length // world_size
 
     # create master tensors
-    q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
-    k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
+    k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
 
     # create distributed tensors
-    sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
 
     # set autograd attributes
     q.requires_grad = True
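The pattern being tested here is sequence parallelism: the master q/k tensors are broadcast to every rank, and each rank keeps only its contiguous chunk of sub_seq_length = seq_length // world_size positions along the sequence dimension. A minimal single-process sketch (with made-up sizes and no distributed setup) shows that the per-rank slices tile the master tensor exactly:

# Single-process illustration of the per-rank sequence slicing used above
# (sizes are made up; no distributed setup is needed for the slicing itself).
import torch

world_size = 4
batch_size, num_heads, seq_length, attention_head_size = 2, 3, 8, 16
sub_seq_length = seq_length // world_size

q = torch.rand(batch_size * num_heads, seq_length, attention_head_size)
chunks = [q[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] for rank in range(world_size)]

# concatenating the rank-local slices along the sequence dim recovers the master tensor
assert torch.equal(torch.cat(chunks, dim=1), q)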
@@ -53,7 +49,7 @@ def check_ring_qk(rank, world_size):
     sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
 
     # check master and distributed attetion scores
-    sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
 
     # run master backward
@@ -61,11 +57,11 @@ def check_ring_qk(rank, world_size):
     a.mean().backward()
 
     # run distributed backward
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_a, partial_master_a_grad)
 
     # check master and distributed grads
-    partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
         'attention score cannot match'
 
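The backward check follows the same slicing idea: after a.mean().backward() populates the master gradients, the rank-local output is driven with just its slice of the master upstream gradient via torch.autograd.backward, and the resulting local input gradient must match the corresponding slice of the master one. Below is a small single-tensor sketch of this pattern, using a plain elementwise op instead of ring attention:

# Minimal sketch of the gradient-check pattern: drive the rank-local output
# with the matching slice of the master upstream gradient.
import torch

seq_length, world_size, rank = 8, 4, 1
sub_seq_length = seq_length // world_size

# master path: full-sequence input and output
x = torch.rand(2, seq_length, requires_grad=True)
y = 2 * x
y.retain_grad()                      # keep dL/dy so we can slice it below
y.mean().backward()

# "distributed" path: one rank's slice of the same computation
sub_x = x.detach()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].clone().requires_grad_()
sub_y = 2 * sub_x

# feed the rank-local slice of the master upstream gradient into backward
partial_y_grad = y.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_y, partial_y_grad)

# the rank-local input gradient matches the corresponding slice of the master one
assert torch.allclose(sub_x.grad, x.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length])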
@@ -79,14 +75,14 @@ def check_ring_av(rank, world_size):
     sub_seq_length = seq_length // world_size
 
     # create master tensors
-    a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
-    v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
+    v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
 
     # create distributed tensors
-    sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
 
     # set autograd attributes
     a.requires_grad = True
@@ -108,7 +104,7 @@ def check_ring_av(rank, world_size):
     # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
 
     # check master and distributed output
-    sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
 
     # # run master backward
@@ -116,23 +112,17 @@ def check_ring_av(rank, world_size):
     out.mean().backward()
 
     # # run distributed backward
-    partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_out, partial_master_out_grad)
 
     # # check master and distributed grads
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
         'attention output cannot match'
 
 
 def run_test(rank, world_size):
-    colossalai.launch(
-        rank=rank,
-        world_size=world_size,
-        config=CONFIG,
-        host='localhost',
-        port=29500
-    )
+    colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500)
 
     # check_ring_qk(rank, world_size)
     check_ring_av(rank, world_size)
@@ -142,6 +132,7 @@ def run_test(rank, world_size):
 
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)
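The added decorator matches mp.ProcessRaisedException, the exception that torch.multiprocessing.spawn re-raises when a worker process fails, so a leftover binding of port 29500 from a previous run (surfacing as "Address already in use") now triggers a rerun instead of a hard failure. The spawn call itself sits outside the visible hunk; the sketch below shows how such a run_func partial is typically driven and is an assumption rather than the exact code in this test (the launch_workers helper is hypothetical):

# Rough sketch of how the partial is usually driven (assumed, not shown in this hunk).
# torch.multiprocessing.spawn re-raises worker exceptions as ProcessRaisedException,
# which is exactly the type the decorator matches against the
# ".*Address already in use.*" pattern.
import torch.multiprocessing as mp
from functools import partial


def launch_workers(run_test, world_size=4):
    # each spawned process receives its rank as the first argument
    run_func = partial(run_test, world_size=world_size)
    mp.spawn(run_func, nprocs=world_size)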